forked from cardosofelipe/fast-next-template
feat(memory): #87 project setup & core architecture
Implements Sub-Issue #87 of Issue #62 (Agent Memory System).

Core infrastructure:
- memory/types.py: Type definitions for all memory types (Working, Episodic, Semantic, Procedural) with enums for MemoryType, ScopeLevel, Outcome
- memory/config.py: MemorySettings with MEM_ env prefix, thread-safe singleton
- memory/exceptions.py: Comprehensive exception hierarchy for memory operations
- memory/manager.py: MemoryManager facade with placeholder methods

Directory structure:
- working/: Working memory (Redis/in-memory) - to be implemented in #89
- episodic/: Episodic memory (experiences) - to be implemented in #90
- semantic/: Semantic memory (facts) - to be implemented in #91
- procedural/: Procedural memory (skills) - to be implemented in #92
- scoping/: Scope management - to be implemented in #93
- indexing/: Vector indexing - to be implemented in #94
- consolidation/: Memory consolidation - to be implemented in #95

Tests: 71 unit tests for config, types, and exceptions
Docs: Comprehensive implementation plan at docs/architecture/memory-system-plan.md

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
backend/app/services/memory/__init__.py (new file, 136 lines)
@@ -0,0 +1,136 @@
"""
Agent Memory System

Multi-tier cognitive memory for AI agents, providing:
- Working Memory: Session-scoped ephemeral state (Redis/In-memory)
- Episodic Memory: Experiential records of past tasks (PostgreSQL)
- Semantic Memory: Learned facts and knowledge (PostgreSQL + pgvector)
- Procedural Memory: Learned skills and procedures (PostgreSQL)

Usage:
    from app.services.memory import (
        MemoryManager,
        MemorySettings,
        get_memory_settings,
        MemoryType,
        ScopeLevel,
    )

    # Create a manager for a session
    manager = MemoryManager.for_session(
        session_id="sess-123",
        project_id=uuid,
    )

    async with manager:
        # Working memory
        await manager.set_working("key", {"data": "value"})
        value = await manager.get_working("key")

        # Episodic memory
        episode = await manager.record_episode(episode_data)
        similar = await manager.search_episodes("query")

        # Semantic memory
        fact = await manager.store_fact(fact_data)
        facts = await manager.search_facts("query")

        # Procedural memory
        procedure = await manager.record_procedure(procedure_data)
        procedures = await manager.find_procedures("context")
"""

# Configuration
from .config import (
    MemorySettings,
    get_default_settings,
    get_memory_settings,
    reset_memory_settings,
)

# Exceptions
from .exceptions import (
    CheckpointError,
    EmbeddingError,
    MemoryCapacityError,
    MemoryConflictError,
    MemoryConsolidationError,
    MemoryError,
    MemoryExpiredError,
    MemoryNotFoundError,
    MemoryRetrievalError,
    MemoryScopeError,
    MemorySerializationError,
    MemoryStorageError,
)

# Manager
from .manager import MemoryManager

# Types
from .types import (
    ConsolidationStatus,
    ConsolidationType,
    Episode,
    EpisodeCreate,
    Fact,
    FactCreate,
    MemoryItem,
    MemoryStats,
    MemoryStore,
    MemoryType,
    Outcome,
    Procedure,
    ProcedureCreate,
    RetrievalResult,
    ScopeContext,
    ScopeLevel,
    Step,
    TaskState,
    WorkingMemoryItem,
)

__all__ = [
    "CheckpointError",
    "ConsolidationStatus",
    "ConsolidationType",
    "EmbeddingError",
    "Episode",
    "EpisodeCreate",
    "Fact",
    "FactCreate",
    "MemoryCapacityError",
    "MemoryConflictError",
    "MemoryConsolidationError",
    # Exceptions
    "MemoryError",
    "MemoryExpiredError",
    "MemoryItem",
    # Manager
    "MemoryManager",
    "MemoryNotFoundError",
    "MemoryRetrievalError",
    "MemoryScopeError",
    "MemorySerializationError",
    # Configuration
    "MemorySettings",
    "MemoryStats",
    "MemoryStorageError",
    # Types - Abstract
    "MemoryStore",
    # Types - Enums
    "MemoryType",
    "Outcome",
    "Procedure",
    "ProcedureCreate",
    "RetrievalResult",
    # Types - Data Classes
    "ScopeContext",
    "ScopeLevel",
    "Step",
    "TaskState",
    "WorkingMemoryItem",
    "get_default_settings",
    "get_memory_settings",
    "reset_memory_settings",
]
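A quick illustrative sketch (not part of the commit) of what the re-exported enums give callers; because MemoryType and ScopeLevel subclass str, members compare and serialize as plain strings:

    from app.services.memory import MemoryType, ScopeLevel

    # str-backed enums: convenient for JSON payloads and database columns
    assert MemoryType.WORKING.value == "working"
    assert MemoryType.WORKING == "working"        # str subclass comparison
    assert ScopeLevel.SESSION.value == "session"
    assert [m.value for m in MemoryType] == [
        "working", "episodic", "semantic", "procedural"
    ]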
backend/app/services/memory/config.py (new file, 410 lines)
@@ -0,0 +1,410 @@
"""
Memory System Configuration.

Provides Pydantic settings for the Agent Memory System,
including storage backends, capacity limits, and consolidation policies.
"""

import threading
from functools import lru_cache
from typing import Any

from pydantic import Field, field_validator, model_validator
from pydantic_settings import BaseSettings


class MemorySettings(BaseSettings):
    """
    Configuration for the Agent Memory System.

    All settings can be overridden via environment variables
    with the MEM_ prefix.
    """

    # Working Memory Settings
    working_memory_backend: str = Field(
        default="redis",
        description="Backend for working memory: 'redis' or 'memory'",
    )
    working_memory_default_ttl_seconds: int = Field(
        default=3600,
        ge=60,
        le=86400,
        description="Default TTL for working memory items (1 hour default)",
    )
    working_memory_max_items_per_session: int = Field(
        default=1000,
        ge=100,
        le=100000,
        description="Maximum items per session in working memory",
    )
    working_memory_max_value_size_bytes: int = Field(
        default=1048576,  # 1MB
        ge=1024,
        le=104857600,  # 100MB
        description="Maximum size of a single value in working memory",
    )
    working_memory_checkpoint_enabled: bool = Field(
        default=True,
        description="Enable checkpointing for working memory recovery",
    )

    # Redis Settings (for working memory)
    redis_url: str = Field(
        default="redis://localhost:6379/0",
        description="Redis connection URL",
    )
    redis_prefix: str = Field(
        default="mem",
        description="Redis key prefix for memory items",
    )
    redis_connection_timeout_seconds: int = Field(
        default=5,
        ge=1,
        le=60,
        description="Redis connection timeout",
    )

    # Episodic Memory Settings
    episodic_max_episodes_per_project: int = Field(
        default=10000,
        ge=100,
        le=1000000,
        description="Maximum episodes to retain per project",
    )
    episodic_default_importance: float = Field(
        default=0.5,
        ge=0.0,
        le=1.0,
        description="Default importance score for new episodes",
    )
    episodic_retention_days: int = Field(
        default=365,
        ge=7,
        le=3650,
        description="Days to retain episodes before archival",
    )

    # Semantic Memory Settings
    semantic_max_facts_per_project: int = Field(
        default=50000,
        ge=1000,
        le=10000000,
        description="Maximum facts to retain per project",
    )
    semantic_confidence_decay_days: int = Field(
        default=90,
        ge=7,
        le=365,
        description="Days until confidence decays by 50%",
    )
    semantic_min_confidence: float = Field(
        default=0.1,
        ge=0.0,
        le=1.0,
        description="Minimum confidence before fact is pruned",
    )

    # Procedural Memory Settings
    procedural_max_procedures_per_project: int = Field(
        default=1000,
        ge=10,
        le=100000,
        description="Maximum procedures per project",
    )
    procedural_min_success_rate: float = Field(
        default=0.3,
        ge=0.0,
        le=1.0,
        description="Minimum success rate before procedure is pruned",
    )
    procedural_min_uses_before_suggest: int = Field(
        default=3,
        ge=1,
        le=100,
        description="Minimum uses before procedure is suggested",
    )

    # Embedding Settings
    embedding_model: str = Field(
        default="text-embedding-3-small",
        description="Model to use for embeddings",
    )
    embedding_dimensions: int = Field(
        default=1536,
        ge=256,
        le=4096,
        description="Embedding vector dimensions",
    )
    embedding_batch_size: int = Field(
        default=100,
        ge=1,
        le=1000,
        description="Batch size for embedding generation",
    )
    embedding_cache_enabled: bool = Field(
        default=True,
        description="Enable caching of embeddings",
    )

    # Retrieval Settings
    retrieval_default_limit: int = Field(
        default=10,
        ge=1,
        le=100,
        description="Default limit for retrieval queries",
    )
    retrieval_max_limit: int = Field(
        default=100,
        ge=10,
        le=1000,
        description="Maximum limit for retrieval queries",
    )
    retrieval_min_similarity: float = Field(
        default=0.5,
        ge=0.0,
        le=1.0,
        description="Minimum similarity score for retrieval",
    )

    # Consolidation Settings
    consolidation_enabled: bool = Field(
        default=True,
        description="Enable automatic memory consolidation",
    )
    consolidation_batch_size: int = Field(
        default=100,
        ge=10,
        le=1000,
        description="Batch size for consolidation jobs",
    )
    consolidation_schedule_cron: str = Field(
        default="0 3 * * *",
        description="Cron expression for nightly consolidation (3 AM)",
    )
    consolidation_working_to_episodic_delay_minutes: int = Field(
        default=30,
        ge=5,
        le=1440,
        description="Minutes after session end before consolidating to episodic",
    )

    # Pruning Settings
    pruning_enabled: bool = Field(
        default=True,
        description="Enable automatic memory pruning",
    )
    pruning_min_age_days: int = Field(
        default=7,
        ge=1,
        le=365,
        description="Minimum age before memory can be pruned",
    )
    pruning_importance_threshold: float = Field(
        default=0.2,
        ge=0.0,
        le=1.0,
        description="Importance threshold below which memory can be pruned",
    )

    # Caching Settings
    cache_enabled: bool = Field(
        default=True,
        description="Enable caching for memory retrieval",
    )
    cache_ttl_seconds: int = Field(
        default=300,
        ge=10,
        le=3600,
        description="Cache TTL for retrieval results",
    )
    cache_max_items: int = Field(
        default=10000,
        ge=100,
        le=1000000,
        description="Maximum items in memory cache",
    )

    # Performance Settings
    max_retrieval_time_ms: int = Field(
        default=100,
        ge=10,
        le=5000,
        description="Target maximum retrieval time in milliseconds",
    )
    parallel_retrieval: bool = Field(
        default=True,
        description="Enable parallel retrieval from multiple memory types",
    )
    max_parallel_retrievals: int = Field(
        default=4,
        ge=1,
        le=10,
        description="Maximum concurrent retrieval operations",
    )

    @field_validator("working_memory_backend")
    @classmethod
    def validate_backend(cls, v: str) -> str:
        """Validate working memory backend."""
        valid_backends = {"redis", "memory"}
        if v not in valid_backends:
            raise ValueError(f"backend must be one of: {valid_backends}")
        return v

    @field_validator("embedding_model")
    @classmethod
    def validate_embedding_model(cls, v: str) -> str:
        """Validate embedding model name."""
        valid_models = {
            "text-embedding-3-small",
            "text-embedding-3-large",
            "text-embedding-ada-002",
        }
        if v not in valid_models:
            raise ValueError(f"embedding_model must be one of: {valid_models}")
        return v

    @model_validator(mode="after")
    def validate_limits(self) -> "MemorySettings":
        """Validate that limits are consistent."""
        if self.retrieval_default_limit > self.retrieval_max_limit:
            raise ValueError(
                f"retrieval_default_limit ({self.retrieval_default_limit}) "
                f"cannot exceed retrieval_max_limit ({self.retrieval_max_limit})"
            )
        return self

    def get_working_memory_config(self) -> dict[str, Any]:
        """Get working memory configuration as a dictionary."""
        return {
            "backend": self.working_memory_backend,
            "default_ttl_seconds": self.working_memory_default_ttl_seconds,
            "max_items_per_session": self.working_memory_max_items_per_session,
            "max_value_size_bytes": self.working_memory_max_value_size_bytes,
            "checkpoint_enabled": self.working_memory_checkpoint_enabled,
        }

    def get_redis_config(self) -> dict[str, Any]:
        """Get Redis configuration as a dictionary."""
        return {
            "url": self.redis_url,
            "prefix": self.redis_prefix,
            "connection_timeout_seconds": self.redis_connection_timeout_seconds,
        }

    def get_embedding_config(self) -> dict[str, Any]:
        """Get embedding configuration as a dictionary."""
        return {
            "model": self.embedding_model,
            "dimensions": self.embedding_dimensions,
            "batch_size": self.embedding_batch_size,
            "cache_enabled": self.embedding_cache_enabled,
        }

    def get_consolidation_config(self) -> dict[str, Any]:
        """Get consolidation configuration as a dictionary."""
        return {
            "enabled": self.consolidation_enabled,
            "batch_size": self.consolidation_batch_size,
            "schedule_cron": self.consolidation_schedule_cron,
            "working_to_episodic_delay_minutes": (
                self.consolidation_working_to_episodic_delay_minutes
            ),
        }

    def to_dict(self) -> dict[str, Any]:
        """Convert settings to dictionary for logging/debugging."""
        return {
            "working_memory": self.get_working_memory_config(),
            "redis": self.get_redis_config(),
            "episodic": {
                "max_episodes_per_project": self.episodic_max_episodes_per_project,
                "default_importance": self.episodic_default_importance,
                "retention_days": self.episodic_retention_days,
            },
            "semantic": {
                "max_facts_per_project": self.semantic_max_facts_per_project,
                "confidence_decay_days": self.semantic_confidence_decay_days,
                "min_confidence": self.semantic_min_confidence,
            },
            "procedural": {
                "max_procedures_per_project": self.procedural_max_procedures_per_project,
                "min_success_rate": self.procedural_min_success_rate,
                "min_uses_before_suggest": self.procedural_min_uses_before_suggest,
            },
            "embedding": self.get_embedding_config(),
            "retrieval": {
                "default_limit": self.retrieval_default_limit,
                "max_limit": self.retrieval_max_limit,
                "min_similarity": self.retrieval_min_similarity,
            },
            "consolidation": self.get_consolidation_config(),
            "pruning": {
                "enabled": self.pruning_enabled,
                "min_age_days": self.pruning_min_age_days,
                "importance_threshold": self.pruning_importance_threshold,
            },
            "cache": {
                "enabled": self.cache_enabled,
                "ttl_seconds": self.cache_ttl_seconds,
                "max_items": self.cache_max_items,
            },
            "performance": {
                "max_retrieval_time_ms": self.max_retrieval_time_ms,
                "parallel_retrieval": self.parallel_retrieval,
                "max_parallel_retrievals": self.max_parallel_retrievals,
            },
        }

    model_config = {
        "env_prefix": "MEM_",
        "env_file": ".env",
        "env_file_encoding": "utf-8",
        "case_sensitive": False,
        "extra": "ignore",
    }


# Thread-safe singleton pattern
_settings: MemorySettings | None = None
_settings_lock = threading.Lock()


def get_memory_settings() -> MemorySettings:
    """
    Get the global MemorySettings instance.

    Thread-safe with double-checked locking pattern.

    Returns:
        MemorySettings instance
    """
    global _settings
    if _settings is None:
        with _settings_lock:
            if _settings is None:
                _settings = MemorySettings()
    return _settings


def reset_memory_settings() -> None:
    """
    Reset the global settings instance.

    Primarily used for testing.
    """
    global _settings
    with _settings_lock:
        _settings = None


@lru_cache(maxsize=1)
def get_default_settings() -> MemorySettings:
    """
    Get default settings (cached).

    Use this for read-only access to defaults.
    For mutable access, use get_memory_settings().
    """
    return MemorySettings()
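Since model_config sets env_prefix to "MEM_" with case_sensitive False, any field above can be overridden from the environment; a minimal sketch of the override-and-reset flow (the variable values are illustrative):

    import os

    from app.services.memory.config import get_memory_settings, reset_memory_settings

    os.environ["MEM_WORKING_MEMORY_BACKEND"] = "memory"  # maps to working_memory_backend
    os.environ["MEM_RETRIEVAL_DEFAULT_LIMIT"] = "20"

    reset_memory_settings()           # drop any cached singleton (e.g., between tests)
    settings = get_memory_settings()  # the fresh instance reads the environment
    assert settings.working_memory_backend == "memory"
    assert settings.retrieval_default_limit == 20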
backend/app/services/memory/consolidation/__init__.py (new file, 10 lines)
@@ -0,0 +1,10 @@
"""
Memory Consolidation

Transfers and extracts knowledge between memory tiers:
- Working -> Episodic
- Episodic -> Semantic
- Episodic -> Procedural
"""

# Will be populated in #95
backend/app/services/memory/episodic/__init__.py (new file, 8 lines)
@@ -0,0 +1,8 @@
"""
Episodic Memory

Experiential memory storing past task completions,
failures, and learnings.
"""

# Will be populated in #90
backend/app/services/memory/exceptions.py (new file, 206 lines)
@@ -0,0 +1,206 @@
"""
Memory System Exceptions

Custom exception classes for the Agent Memory System.
"""

from typing import Any
from uuid import UUID


class MemoryError(Exception):
    """Base exception for all memory-related errors."""

    def __init__(
        self,
        message: str,
        *,
        memory_type: str | None = None,
        scope_type: str | None = None,
        scope_id: str | None = None,
        details: dict[str, Any] | None = None,
    ) -> None:
        super().__init__(message)
        self.message = message
        self.memory_type = memory_type
        self.scope_type = scope_type
        self.scope_id = scope_id
        self.details = details or {}


class MemoryNotFoundError(MemoryError):
    """Raised when a memory item is not found."""

    def __init__(
        self,
        message: str = "Memory not found",
        *,
        memory_id: UUID | str | None = None,
        key: str | None = None,
        **kwargs: Any,
    ) -> None:
        super().__init__(message, **kwargs)
        self.memory_id = memory_id
        self.key = key


class MemoryCapacityError(MemoryError):
    """Raised when memory capacity limits are exceeded."""

    def __init__(
        self,
        message: str = "Memory capacity exceeded",
        *,
        current_size: int = 0,
        max_size: int = 0,
        item_count: int = 0,
        **kwargs: Any,
    ) -> None:
        super().__init__(message, **kwargs)
        self.current_size = current_size
        self.max_size = max_size
        self.item_count = item_count


class MemoryExpiredError(MemoryError):
    """Raised when attempting to access expired memory."""

    def __init__(
        self,
        message: str = "Memory has expired",
        *,
        key: str | None = None,
        expired_at: str | None = None,
        **kwargs: Any,
    ) -> None:
        super().__init__(message, **kwargs)
        self.key = key
        self.expired_at = expired_at


class MemoryStorageError(MemoryError):
    """Raised when memory storage operations fail."""

    def __init__(
        self,
        message: str = "Memory storage operation failed",
        *,
        operation: str | None = None,
        backend: str | None = None,
        **kwargs: Any,
    ) -> None:
        super().__init__(message, **kwargs)
        self.operation = operation
        self.backend = backend


class MemorySerializationError(MemoryError):
    """Raised when memory serialization/deserialization fails."""

    def __init__(
        self,
        message: str = "Memory serialization failed",
        *,
        content_type: str | None = None,
        **kwargs: Any,
    ) -> None:
        super().__init__(message, **kwargs)
        self.content_type = content_type


class MemoryScopeError(MemoryError):
    """Raised when memory scope operations fail."""

    def __init__(
        self,
        message: str = "Memory scope error",
        *,
        requested_scope: str | None = None,
        allowed_scopes: list[str] | None = None,
        **kwargs: Any,
    ) -> None:
        super().__init__(message, **kwargs)
        self.requested_scope = requested_scope
        self.allowed_scopes = allowed_scopes or []


class MemoryConsolidationError(MemoryError):
    """Raised when memory consolidation fails."""

    def __init__(
        self,
        message: str = "Memory consolidation failed",
        *,
        source_type: str | None = None,
        target_type: str | None = None,
        items_processed: int = 0,
        **kwargs: Any,
    ) -> None:
        super().__init__(message, **kwargs)
        self.source_type = source_type
        self.target_type = target_type
        self.items_processed = items_processed


class MemoryRetrievalError(MemoryError):
    """Raised when memory retrieval fails."""

    def __init__(
        self,
        message: str = "Memory retrieval failed",
        *,
        query: str | None = None,
        retrieval_type: str | None = None,
        **kwargs: Any,
    ) -> None:
        super().__init__(message, **kwargs)
        self.query = query
        self.retrieval_type = retrieval_type


class EmbeddingError(MemoryError):
    """Raised when embedding generation fails."""

    def __init__(
        self,
        message: str = "Embedding generation failed",
        *,
        content_length: int = 0,
        model: str | None = None,
        **kwargs: Any,
    ) -> None:
        super().__init__(message, **kwargs)
        self.content_length = content_length
        self.model = model


class CheckpointError(MemoryError):
    """Raised when checkpoint operations fail."""

    def __init__(
        self,
        message: str = "Checkpoint operation failed",
        *,
        checkpoint_id: str | None = None,
        operation: str | None = None,
        **kwargs: Any,
    ) -> None:
        super().__init__(message, **kwargs)
        self.checkpoint_id = checkpoint_id
        self.operation = operation


class MemoryConflictError(MemoryError):
    """Raised when there's a conflict in memory (e.g., contradictory facts)."""

    def __init__(
        self,
        message: str = "Memory conflict detected",
        *,
        conflicting_ids: list[str | UUID] | None = None,
        conflict_type: str | None = None,
        **kwargs: Any,
    ) -> None:
        super().__init__(message, **kwargs)
        self.conflicting_ids = conflicting_ids or []
        self.conflict_type = conflict_type
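Every class above derives from the module's MemoryError base (which, within this package, shadows Python's built-in MemoryError), so callers can handle failures narrowly or broadly; a minimal sketch:

    from app.services.memory.exceptions import MemoryError, MemoryNotFoundError

    try:
        raise MemoryNotFoundError(key="task:123", memory_type="working")
    except MemoryNotFoundError as exc:
        print(f"missing key: {exc.key}")          # most specific handler first
    except MemoryError as exc:
        print(f"memory failure: {exc.message}")   # catch-all for the whole hierarchy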
backend/app/services/memory/indexing/__init__.py (new file, 7 lines)
@@ -0,0 +1,7 @@
"""
Memory Indexing

Vector embeddings and retrieval engine for memory search.
"""

# Will be populated in #94
backend/app/services/memory/manager.py (new file, 606 lines)
@@ -0,0 +1,606 @@
"""
Memory Manager

Facade for the Agent Memory System providing unified access
to all memory types and operations.
"""

import logging
from typing import Any
from uuid import UUID

from .config import MemorySettings, get_memory_settings
from .types import (
    Episode,
    EpisodeCreate,
    Fact,
    FactCreate,
    MemoryStats,
    MemoryType,
    Outcome,
    Procedure,
    ProcedureCreate,
    RetrievalResult,
    ScopeContext,
    ScopeLevel,
    TaskState,
)

logger = logging.getLogger(__name__)


class MemoryManager:
    """
    Unified facade for the Agent Memory System.

    Provides a single entry point for all memory operations across
    working, episodic, semantic, and procedural memory types.

    Usage:
        manager = MemoryManager.create()

        # Working memory
        await manager.set_working("key", {"data": "value"})
        value = await manager.get_working("key")

        # Episodic memory
        episode = await manager.record_episode(episode_data)
        similar = await manager.search_episodes("query")

        # Semantic memory
        fact = await manager.store_fact(fact_data)
        facts = await manager.search_facts("query")

        # Procedural memory
        procedure = await manager.record_procedure(procedure_data)
        procedures = await manager.find_procedures("context")
    """

    def __init__(
        self,
        settings: MemorySettings,
        scope: ScopeContext,
    ) -> None:
        """
        Initialize the MemoryManager.

        Args:
            settings: Memory configuration settings
            scope: The scope context for this manager instance
        """
        self._settings = settings
        self._scope = scope
        self._initialized = False

        # These will be initialized when the respective sub-modules are implemented
        self._working_memory: Any | None = None
        self._episodic_memory: Any | None = None
        self._semantic_memory: Any | None = None
        self._procedural_memory: Any | None = None

        logger.debug(
            "MemoryManager created for scope %s:%s",
            scope.scope_type.value,
            scope.scope_id,
        )

    @classmethod
    def create(
        cls,
        scope_type: ScopeLevel = ScopeLevel.SESSION,
        scope_id: str = "default",
        parent_scope: ScopeContext | None = None,
        settings: MemorySettings | None = None,
    ) -> "MemoryManager":
        """
        Create a new MemoryManager instance.

        Args:
            scope_type: The scope level for this manager
            scope_id: The scope identifier
            parent_scope: Optional parent scope for inheritance
            settings: Optional custom settings (uses global if not provided)

        Returns:
            A new MemoryManager instance
        """
        if settings is None:
            settings = get_memory_settings()

        scope = ScopeContext(
            scope_type=scope_type,
            scope_id=scope_id,
            parent=parent_scope,
        )

        return cls(settings=settings, scope=scope)

    @classmethod
    def for_session(
        cls,
        session_id: str,
        agent_instance_id: UUID | None = None,
        project_id: UUID | None = None,
    ) -> "MemoryManager":
        """
        Create a MemoryManager for a specific session.

        Builds the appropriate scope hierarchy based on provided IDs.

        Args:
            session_id: The session identifier
            agent_instance_id: Optional agent instance ID
            project_id: Optional project ID

        Returns:
            A MemoryManager configured for the session scope
        """
        settings = get_memory_settings()

        # Build scope hierarchy
        parent: ScopeContext | None = None

        if project_id:
            parent = ScopeContext(
                scope_type=ScopeLevel.PROJECT,
                scope_id=str(project_id),
                parent=ScopeContext(
                    scope_type=ScopeLevel.GLOBAL,
                    scope_id="global",
                ),
            )

        if agent_instance_id:
            parent = ScopeContext(
                scope_type=ScopeLevel.AGENT_INSTANCE,
                scope_id=str(agent_instance_id),
                parent=parent,
            )

        scope = ScopeContext(
            scope_type=ScopeLevel.SESSION,
            scope_id=session_id,
            parent=parent,
        )

        return cls(settings=settings, scope=scope)

    @property
    def scope(self) -> ScopeContext:
        """Get the current scope context."""
        return self._scope

    @property
    def settings(self) -> MemorySettings:
        """Get the memory settings."""
        return self._settings

    # =========================================================================
    # Working Memory Operations
    # =========================================================================

    async def set_working(
        self,
        key: str,
        value: Any,
        ttl_seconds: int | None = None,
    ) -> None:
        """
        Set a value in working memory.

        Args:
            key: The key to store the value under
            value: The value to store (must be JSON serializable)
            ttl_seconds: Optional TTL (uses default if not provided)
        """
        # Placeholder - will be implemented in #89
        logger.debug("set_working called for key=%s (not yet implemented)", key)
        raise NotImplementedError("Working memory not yet implemented")

    async def get_working(
        self,
        key: str,
        default: Any = None,
    ) -> Any:
        """
        Get a value from working memory.

        Args:
            key: The key to retrieve
            default: Default value if key not found

        Returns:
            The stored value or default
        """
        # Placeholder - will be implemented in #89
        logger.debug("get_working called for key=%s (not yet implemented)", key)
        raise NotImplementedError("Working memory not yet implemented")

    async def delete_working(self, key: str) -> bool:
        """
        Delete a value from working memory.

        Args:
            key: The key to delete

        Returns:
            True if the key was deleted, False if not found
        """
        # Placeholder - will be implemented in #89
        logger.debug("delete_working called for key=%s (not yet implemented)", key)
        raise NotImplementedError("Working memory not yet implemented")

    async def set_task_state(self, state: TaskState) -> None:
        """
        Set the current task state in working memory.

        Args:
            state: The task state to store
        """
        # Placeholder - will be implemented in #89
        logger.debug(
            "set_task_state called for task=%s (not yet implemented)",
            state.task_id,
        )
        raise NotImplementedError("Working memory not yet implemented")

    async def get_task_state(self) -> TaskState | None:
        """
        Get the current task state from working memory.

        Returns:
            The current task state or None
        """
        # Placeholder - will be implemented in #89
        logger.debug("get_task_state called (not yet implemented)")
        raise NotImplementedError("Working memory not yet implemented")

    async def create_checkpoint(self) -> str:
        """
        Create a checkpoint of the current working memory state.

        Returns:
            The checkpoint ID
        """
        # Placeholder - will be implemented in #89
        logger.debug("create_checkpoint called (not yet implemented)")
        raise NotImplementedError("Working memory not yet implemented")

    async def restore_checkpoint(self, checkpoint_id: str) -> None:
        """
        Restore working memory from a checkpoint.

        Args:
            checkpoint_id: The checkpoint to restore from
        """
        # Placeholder - will be implemented in #89
        logger.debug(
            "restore_checkpoint called for id=%s (not yet implemented)",
            checkpoint_id,
        )
        raise NotImplementedError("Working memory not yet implemented")

    # =========================================================================
    # Episodic Memory Operations
    # =========================================================================

    async def record_episode(self, episode: EpisodeCreate) -> Episode:
        """
        Record a new episode in episodic memory.

        Args:
            episode: The episode data to record

        Returns:
            The created episode with ID
        """
        # Placeholder - will be implemented in #90
        logger.debug(
            "record_episode called for task=%s (not yet implemented)",
            episode.task_type,
        )
        raise NotImplementedError("Episodic memory not yet implemented")

    async def search_episodes(
        self,
        query: str,
        limit: int | None = None,
    ) -> RetrievalResult[Episode]:
        """
        Search for similar episodes.

        Args:
            query: The search query
            limit: Maximum results to return

        Returns:
            Retrieval result with matching episodes
        """
        # Placeholder - will be implemented in #90
        logger.debug(
            "search_episodes called for query=%s (not yet implemented)",
            query[:50],
        )
        raise NotImplementedError("Episodic memory not yet implemented")

    async def get_recent_episodes(
        self,
        limit: int = 10,
    ) -> list[Episode]:
        """
        Get the most recent episodes.

        Args:
            limit: Maximum episodes to return

        Returns:
            List of recent episodes
        """
        # Placeholder - will be implemented in #90
        logger.debug("get_recent_episodes called (not yet implemented)")
        raise NotImplementedError("Episodic memory not yet implemented")

    async def get_episodes_by_outcome(
        self,
        outcome: Outcome,
        limit: int = 10,
    ) -> list[Episode]:
        """
        Get episodes by outcome.

        Args:
            outcome: The outcome to filter by
            limit: Maximum episodes to return

        Returns:
            List of episodes with the specified outcome
        """
        # Placeholder - will be implemented in #90
        logger.debug(
            "get_episodes_by_outcome called for outcome=%s (not yet implemented)",
            outcome.value,
        )
        raise NotImplementedError("Episodic memory not yet implemented")

    # =========================================================================
    # Semantic Memory Operations
    # =========================================================================

    async def store_fact(self, fact: FactCreate) -> Fact:
        """
        Store a new fact in semantic memory.

        Args:
            fact: The fact data to store

        Returns:
            The created fact with ID
        """
        # Placeholder - will be implemented in #91
        logger.debug(
            "store_fact called for %s %s %s (not yet implemented)",
            fact.subject,
            fact.predicate,
            fact.object,
        )
        raise NotImplementedError("Semantic memory not yet implemented")

    async def search_facts(
        self,
        query: str,
        limit: int | None = None,
    ) -> RetrievalResult[Fact]:
        """
        Search for facts matching a query.

        Args:
            query: The search query
            limit: Maximum results to return

        Returns:
            Retrieval result with matching facts
        """
        # Placeholder - will be implemented in #91
        logger.debug(
            "search_facts called for query=%s (not yet implemented)",
            query[:50],
        )
        raise NotImplementedError("Semantic memory not yet implemented")

    async def get_facts_by_entity(
        self,
        entity: str,
        limit: int = 20,
    ) -> list[Fact]:
        """
        Get facts related to an entity.

        Args:
            entity: The entity to search for
            limit: Maximum facts to return

        Returns:
            List of facts mentioning the entity
        """
        # Placeholder - will be implemented in #91
        logger.debug(
            "get_facts_by_entity called for entity=%s (not yet implemented)",
            entity,
        )
        raise NotImplementedError("Semantic memory not yet implemented")

    async def reinforce_fact(self, fact_id: UUID) -> Fact:
        """
        Reinforce a fact (increase confidence from repeated learning).

        Args:
            fact_id: The fact to reinforce

        Returns:
            The updated fact
        """
        # Placeholder - will be implemented in #91
        logger.debug(
            "reinforce_fact called for id=%s (not yet implemented)",
            fact_id,
        )
        raise NotImplementedError("Semantic memory not yet implemented")

    # =========================================================================
    # Procedural Memory Operations
    # =========================================================================

    async def record_procedure(self, procedure: ProcedureCreate) -> Procedure:
        """
        Record a new procedure.

        Args:
            procedure: The procedure data to record

        Returns:
            The created procedure with ID
        """
        # Placeholder - will be implemented in #92
        logger.debug(
            "record_procedure called for name=%s (not yet implemented)",
            procedure.name,
        )
        raise NotImplementedError("Procedural memory not yet implemented")

    async def find_procedures(
        self,
        context: str,
        limit: int = 5,
    ) -> list[Procedure]:
        """
        Find procedures matching the current context.

        Args:
            context: The context to match against
            limit: Maximum procedures to return

        Returns:
            List of matching procedures sorted by success rate
        """
        # Placeholder - will be implemented in #92
        logger.debug(
            "find_procedures called for context=%s (not yet implemented)",
            context[:50],
        )
        raise NotImplementedError("Procedural memory not yet implemented")

    async def record_procedure_outcome(
        self,
        procedure_id: UUID,
        success: bool,
    ) -> None:
        """
        Record the outcome of using a procedure.

        Args:
            procedure_id: The procedure that was used
            success: Whether the procedure succeeded
        """
        # Placeholder - will be implemented in #92
        logger.debug(
            "record_procedure_outcome called for id=%s success=%s (not yet implemented)",
            procedure_id,
            success,
        )
        raise NotImplementedError("Procedural memory not yet implemented")

    # =========================================================================
    # Cross-Memory Operations
    # =========================================================================

    async def recall(
        self,
        query: str,
        memory_types: list[MemoryType] | None = None,
        limit: int = 10,
    ) -> dict[MemoryType, list[Any]]:
        """
        Recall memories across multiple memory types.

        Args:
            query: The search query
            memory_types: Memory types to search (all if not specified)
            limit: Maximum results per type

        Returns:
            Dictionary mapping memory types to results
        """
        # Placeholder - will be implemented in #97 (Component Integration)
        logger.debug("recall called for query=%s (not yet implemented)", query[:50])
        raise NotImplementedError("Cross-memory recall not yet implemented")

    async def get_stats(
        self,
        memory_type: MemoryType | None = None,
    ) -> list[MemoryStats]:
        """
        Get memory statistics.

        Args:
            memory_type: Specific type or all if not specified

        Returns:
            List of statistics for requested memory types
        """
        # Placeholder - will be implemented in #100 (Metrics & Observability)
        logger.debug("get_stats called (not yet implemented)")
        raise NotImplementedError("Memory stats not yet implemented")

    # =========================================================================
    # Lifecycle Operations
    # =========================================================================

    async def initialize(self) -> None:
        """
        Initialize the memory manager and its backends.

        Should be called before using the manager.
        """
        if self._initialized:
            logger.debug("MemoryManager already initialized")
            return

        logger.info(
            "Initializing MemoryManager for scope %s:%s",
            self._scope.scope_type.value,
            self._scope.scope_id,
        )

        # TODO: Initialize backends when implemented

        self._initialized = True
        logger.info("MemoryManager initialized successfully")

    async def close(self) -> None:
        """
        Close the memory manager and release resources.

        Should be called when done using the manager.
        """
        if not self._initialized:
            return

        logger.info(
            "Closing MemoryManager for scope %s:%s",
            self._scope.scope_type.value,
            self._scope.scope_id,
        )

        # TODO: Close backends when implemented

        self._initialized = False
        logger.info("MemoryManager closed successfully")

    async def __aenter__(self) -> "MemoryManager":
        """Async context manager entry."""
        await self.initialize()
        return self

    async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
        """Async context manager exit."""
        await self.close()
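A sketch of the intended lifecycle (the session ID is illustrative); because the backends land in later sub-issues, every memory operation currently raises NotImplementedError, which the example surfaces explicitly:

    import asyncio
    from uuid import uuid4

    from app.services.memory import MemoryManager


    async def main() -> None:
        # for_session() builds the global -> project -> session scope chain internally
        manager = MemoryManager.for_session(session_id="sess-123", project_id=uuid4())
        async with manager:  # initialize() on entry, close() on exit
            try:
                await manager.set_working("key", {"data": "value"})
            except NotImplementedError:
                print("working memory arrives in #89")


    asyncio.run(main())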
backend/app/services/memory/procedural/__init__.py (new file, 7 lines)
@@ -0,0 +1,7 @@
"""
Procedural Memory

Learned skills and procedures from successful task patterns.
"""

# Will be populated in #92
backend/app/services/memory/scoping/__init__.py (new file, 8 lines)
@@ -0,0 +1,8 @@
"""
Memory Scoping

Hierarchical scoping for memory with inheritance:
Global -> Project -> Agent Type -> Agent Instance -> Session
"""

# Will be populated in #93
backend/app/services/memory/semantic/__init__.py (new file, 8 lines)
@@ -0,0 +1,8 @@
"""
Semantic Memory

Fact storage with triple format (subject, predicate, object)
and semantic search capabilities.
"""

# Will be populated in #91
backend/app/services/memory/types.py (new file, 322 lines)
@@ -0,0 +1,322 @@
"""
Memory System Types

Core type definitions and interfaces for the Agent Memory System.
"""

from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from datetime import datetime
from enum import Enum
from typing import Any
from uuid import UUID


class MemoryType(str, Enum):
    """Types of memory in the agent memory system."""

    WORKING = "working"
    EPISODIC = "episodic"
    SEMANTIC = "semantic"
    PROCEDURAL = "procedural"


class ScopeLevel(str, Enum):
    """Hierarchical scoping levels for memory."""

    GLOBAL = "global"
    PROJECT = "project"
    AGENT_TYPE = "agent_type"
    AGENT_INSTANCE = "agent_instance"
    SESSION = "session"


class Outcome(str, Enum):
    """Outcome of a task or episode."""

    SUCCESS = "success"
    FAILURE = "failure"
    PARTIAL = "partial"


class ConsolidationStatus(str, Enum):
    """Status of a memory consolidation job."""

    PENDING = "pending"
    RUNNING = "running"
    COMPLETED = "completed"
    FAILED = "failed"


class ConsolidationType(str, Enum):
    """Types of memory consolidation."""

    WORKING_TO_EPISODIC = "working_to_episodic"
    EPISODIC_TO_SEMANTIC = "episodic_to_semantic"
    EPISODIC_TO_PROCEDURAL = "episodic_to_procedural"
    PRUNING = "pruning"


@dataclass
class ScopeContext:
    """Represents a memory scope with its hierarchy."""

    scope_type: ScopeLevel
    scope_id: str
    parent: "ScopeContext | None" = None

    def get_hierarchy(self) -> list["ScopeContext"]:
        """Get the full scope hierarchy from root to this scope."""
        hierarchy: list[ScopeContext] = []
        current: ScopeContext | None = self
        while current is not None:
            hierarchy.insert(0, current)
            current = current.parent
        return hierarchy

    def to_key_prefix(self) -> str:
        """Convert scope to a key prefix for storage."""
        return f"{self.scope_type.value}:{self.scope_id}"


@dataclass
class MemoryItem:
    """Base class for all memory items."""

    id: UUID
    memory_type: MemoryType
    scope_type: ScopeLevel
    scope_id: str
    created_at: datetime
    updated_at: datetime
    metadata: dict[str, Any] = field(default_factory=dict)

    def get_age_seconds(self) -> float:
        """Get the age of this memory item in seconds."""
        return (datetime.now() - self.created_at).total_seconds()


@dataclass
class WorkingMemoryItem:
    """A key-value item in working memory."""

    id: UUID
    scope_type: ScopeLevel
    scope_id: str
    key: str
    value: Any
    expires_at: datetime | None = None
    created_at: datetime = field(default_factory=datetime.now)
    updated_at: datetime = field(default_factory=datetime.now)

    def is_expired(self) -> bool:
        """Check if this item has expired."""
        if self.expires_at is None:
            return False
        return datetime.now() > self.expires_at


@dataclass
class TaskState:
    """Current state of a task in working memory."""

    task_id: str
    task_type: str
    description: str
    status: str = "in_progress"
    current_step: int = 0
    total_steps: int = 0
    progress_percent: float = 0.0
    context: dict[str, Any] = field(default_factory=dict)
    started_at: datetime = field(default_factory=datetime.now)
    updated_at: datetime = field(default_factory=datetime.now)


@dataclass
class Episode:
    """An episodic memory - a recorded experience."""

    id: UUID
    project_id: UUID
    agent_instance_id: UUID | None
    agent_type_id: UUID | None
    session_id: str
    task_type: str
    task_description: str
    actions: list[dict[str, Any]]
    context_summary: str
    outcome: Outcome
    outcome_details: str
    duration_seconds: float
    tokens_used: int
    lessons_learned: list[str]
    importance_score: float
    embedding: list[float] | None
    occurred_at: datetime
    created_at: datetime
    updated_at: datetime


@dataclass
class EpisodeCreate:
    """Data required to create a new episode."""

    project_id: UUID
    session_id: str
    task_type: str
    task_description: str
    actions: list[dict[str, Any]]
    context_summary: str
    outcome: Outcome
    outcome_details: str
    duration_seconds: float
    tokens_used: int
    lessons_learned: list[str] = field(default_factory=list)
    importance_score: float = 0.5
    agent_instance_id: UUID | None = None
    agent_type_id: UUID | None = None


@dataclass
class Fact:
    """A semantic memory fact - a piece of knowledge."""

    id: UUID
    project_id: UUID | None  # None for global facts
    subject: str
    predicate: str
    object: str
    confidence: float
    source_episode_ids: list[UUID]
    first_learned: datetime
    last_reinforced: datetime
    reinforcement_count: int
    embedding: list[float] | None
    created_at: datetime
    updated_at: datetime


@dataclass
class FactCreate:
    """Data required to create a new fact."""

    subject: str
    predicate: str
    object: str
    confidence: float = 0.8
    project_id: UUID | None = None
    source_episode_ids: list[UUID] = field(default_factory=list)


@dataclass
class Procedure:
    """A procedural memory - a learned skill or procedure."""

    id: UUID
    project_id: UUID | None
    agent_type_id: UUID | None
    name: str
    trigger_pattern: str
    steps: list[dict[str, Any]]
    success_count: int
    failure_count: int
    last_used: datetime | None
    embedding: list[float] | None
    created_at: datetime
    updated_at: datetime

    @property
    def success_rate(self) -> float:
        """Calculate the success rate of this procedure."""
        total = self.success_count + self.failure_count
        if total == 0:
            return 0.0
        return self.success_count / total


@dataclass
class ProcedureCreate:
    """Data required to create a new procedure."""

    name: str
    trigger_pattern: str
    steps: list[dict[str, Any]]
    project_id: UUID | None = None
    agent_type_id: UUID | None = None


@dataclass
class Step:
    """A single step in a procedure."""

    order: int
    action: str
    parameters: dict[str, Any] = field(default_factory=dict)
    expected_outcome: str = ""
    fallback_action: str | None = None


class MemoryStore[T: MemoryItem](ABC):
    """Abstract base class for memory storage backends."""

    @abstractmethod
    async def store(self, item: T) -> T:
        """Store a memory item."""
        ...

    @abstractmethod
    async def get(self, item_id: UUID) -> T | None:
        """Get a memory item by ID."""
        ...

    @abstractmethod
    async def delete(self, item_id: UUID) -> bool:
        """Delete a memory item."""
        ...

    @abstractmethod
    async def list(
        self,
        scope_type: ScopeLevel | None = None,
        scope_id: str | None = None,
        limit: int = 100,
        offset: int = 0,
    ) -> list[T]:
        """List memory items with optional scope filtering."""
        ...

    @abstractmethod
    async def count(
        self,
        scope_type: ScopeLevel | None = None,
        scope_id: str | None = None,
    ) -> int:
        """Count memory items with optional scope filtering."""
        ...


@dataclass
class RetrievalResult[T]:
    """Result of a memory retrieval operation."""

    items: list[T]
    total_count: int
    query: str
    retrieval_type: str
    latency_ms: float
    metadata: dict[str, Any] = field(default_factory=dict)


@dataclass
class MemoryStats:
    """Statistics about memory usage."""

    memory_type: MemoryType
    scope_type: ScopeLevel | None
    scope_id: str | None
    item_count: int
    total_size_bytes: int
    oldest_item_age_seconds: float
    newest_item_age_seconds: float
    avg_item_size_bytes: float
    metadata: dict[str, Any] = field(default_factory=dict)
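Most of these dataclasses are passive records, but ScopeContext already carries behavior; a small sketch exercising the hierarchy walk (the IDs are illustrative):

    from app.services.memory.types import ScopeContext, ScopeLevel

    project = ScopeContext(scope_type=ScopeLevel.PROJECT, scope_id="proj-1")
    session = ScopeContext(scope_type=ScopeLevel.SESSION, scope_id="sess-9", parent=project)

    # get_hierarchy() walks the parent chain and returns root-first order
    assert [s.to_key_prefix() for s in session.get_hierarchy()] == [
        "project:proj-1",
        "session:sess-9",
    ]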
backend/app/services/memory/working/__init__.py (new file, 8 lines)
@@ -0,0 +1,8 @@
"""
Working Memory

Session-scoped ephemeral memory for current task state,
variables, and scratchpad.
"""

# Will be populated in #89
backend/tests/unit/services/memory/__init__.py (new file, 1 line)
@@ -0,0 +1 @@
"""Tests for the Agent Memory System."""
243
backend/tests/unit/services/memory/test_config.py
Normal file
243
backend/tests/unit/services/memory/test_config.py
Normal file
@@ -0,0 +1,243 @@
"""
Tests for Memory System Configuration.
"""

import pytest
from pydantic import ValidationError

from app.services.memory.config import (
    MemorySettings,
    get_default_settings,
    get_memory_settings,
    reset_memory_settings,
)


class TestMemorySettings:
    """Tests for MemorySettings class."""

    def test_default_settings(self) -> None:
        """Test that default settings are valid."""
        settings = MemorySettings()

        # Working memory defaults
        assert settings.working_memory_backend == "redis"
        assert settings.working_memory_default_ttl_seconds == 3600
        assert settings.working_memory_max_items_per_session == 1000

        # Redis defaults
        assert settings.redis_url == "redis://localhost:6379/0"
        assert settings.redis_prefix == "mem"

        # Episodic defaults
        assert settings.episodic_max_episodes_per_project == 10000
        assert settings.episodic_default_importance == 0.5

        # Semantic defaults
        assert settings.semantic_max_facts_per_project == 50000
        assert settings.semantic_min_confidence == 0.1

        # Procedural defaults
        assert settings.procedural_max_procedures_per_project == 1000
        assert settings.procedural_min_success_rate == 0.3

        # Embedding defaults
        assert settings.embedding_model == "text-embedding-3-small"
        assert settings.embedding_dimensions == 1536

        # Retrieval defaults
        assert settings.retrieval_default_limit == 10
        assert settings.retrieval_max_limit == 100

    def test_invalid_backend(self) -> None:
        """Test that invalid backend raises error."""
        with pytest.raises(ValidationError) as exc_info:
            MemorySettings(working_memory_backend="invalid")

        assert "backend must be one of" in str(exc_info.value)

    def test_valid_backends(self) -> None:
        """Test valid backend values."""
        redis_settings = MemorySettings(working_memory_backend="redis")
        assert redis_settings.working_memory_backend == "redis"

        memory_settings = MemorySettings(working_memory_backend="memory")
        assert memory_settings.working_memory_backend == "memory"

    def test_invalid_embedding_model(self) -> None:
        """Test that invalid embedding model raises error."""
        with pytest.raises(ValidationError) as exc_info:
            MemorySettings(embedding_model="invalid-model")

        assert "embedding_model must be one of" in str(exc_info.value)

    def test_valid_embedding_models(self) -> None:
        """Test valid embedding model values."""
        for model in [
            "text-embedding-3-small",
            "text-embedding-3-large",
            "text-embedding-ada-002",
        ]:
            settings = MemorySettings(embedding_model=model)
            assert settings.embedding_model == model

    def test_retrieval_limit_validation(self) -> None:
        """Test that default limit cannot exceed max limit."""
        with pytest.raises(ValidationError) as exc_info:
            MemorySettings(
                retrieval_default_limit=50,
                retrieval_max_limit=25,
            )

        assert "cannot exceed retrieval_max_limit" in str(exc_info.value)

    def test_valid_retrieval_limits(self) -> None:
        """Test valid retrieval limit combinations."""
        settings = MemorySettings(
            retrieval_default_limit=10,
            retrieval_max_limit=50,
        )
        assert settings.retrieval_default_limit == 10
        assert settings.retrieval_max_limit == 50

        # Equal limits should be valid
        settings = MemorySettings(
            retrieval_default_limit=25,
            retrieval_max_limit=25,
        )
        assert settings.retrieval_default_limit == 25
        assert settings.retrieval_max_limit == 25

    def test_ttl_bounds(self) -> None:
        """Test TTL setting bounds."""
        # Valid TTL
        settings = MemorySettings(working_memory_default_ttl_seconds=1800)
        assert settings.working_memory_default_ttl_seconds == 1800

        # Too low
        with pytest.raises(ValidationError):
            MemorySettings(working_memory_default_ttl_seconds=30)

        # Too high
        with pytest.raises(ValidationError):
            MemorySettings(working_memory_default_ttl_seconds=100000)

    def test_confidence_bounds(self) -> None:
        """Test confidence score bounds."""
        # Valid confidence
        settings = MemorySettings(semantic_min_confidence=0.5)
        assert settings.semantic_min_confidence == 0.5

        # Bounds
        settings = MemorySettings(semantic_min_confidence=0.0)
        assert settings.semantic_min_confidence == 0.0

        settings = MemorySettings(semantic_min_confidence=1.0)
        assert settings.semantic_min_confidence == 1.0

        # Out of bounds
        with pytest.raises(ValidationError):
            MemorySettings(semantic_min_confidence=-0.1)

        with pytest.raises(ValidationError):
            MemorySettings(semantic_min_confidence=1.1)

    def test_get_working_memory_config(self) -> None:
        """Test working memory config dictionary."""
        settings = MemorySettings()
        config = settings.get_working_memory_config()

        assert config["backend"] == "redis"
        assert config["default_ttl_seconds"] == 3600
        assert config["max_items_per_session"] == 1000
        assert config["max_value_size_bytes"] == 1048576
        assert config["checkpoint_enabled"] is True

    def test_get_redis_config(self) -> None:
        """Test Redis config dictionary."""
        settings = MemorySettings()
        config = settings.get_redis_config()

        assert config["url"] == "redis://localhost:6379/0"
        assert config["prefix"] == "mem"
        assert config["connection_timeout_seconds"] == 5

    def test_get_embedding_config(self) -> None:
        """Test embedding config dictionary."""
        settings = MemorySettings()
        config = settings.get_embedding_config()

        assert config["model"] == "text-embedding-3-small"
        assert config["dimensions"] == 1536
        assert config["batch_size"] == 100
        assert config["cache_enabled"] is True

    def test_get_consolidation_config(self) -> None:
        """Test consolidation config dictionary."""
        settings = MemorySettings()
        config = settings.get_consolidation_config()

        assert config["enabled"] is True
        assert config["batch_size"] == 100
        assert config["schedule_cron"] == "0 3 * * *"
        assert config["working_to_episodic_delay_minutes"] == 30

    def test_to_dict(self) -> None:
        """Test full settings to dictionary."""
        settings = MemorySettings()
        config = settings.to_dict()

        assert "working_memory" in config
        assert "redis" in config
        assert "episodic" in config
        assert "semantic" in config
        assert "procedural" in config
        assert "embedding" in config
        assert "retrieval" in config
        assert "consolidation" in config
        assert "pruning" in config
        assert "cache" in config
        assert "performance" in config


class TestMemorySettingsSingleton:
    """Tests for MemorySettings singleton functions."""

    def setup_method(self) -> None:
        """Reset singleton before each test."""
        reset_memory_settings()

    def teardown_method(self) -> None:
        """Reset singleton after each test."""
        reset_memory_settings()

    def test_get_memory_settings_singleton(self) -> None:
        """Test that get_memory_settings returns same instance."""
        settings1 = get_memory_settings()
        settings2 = get_memory_settings()

        assert settings1 is settings2

    def test_reset_memory_settings(self) -> None:
        """Test that reset creates new instance."""
        settings1 = get_memory_settings()
        reset_memory_settings()
        settings2 = get_memory_settings()

        assert settings1 is not settings2

    def test_get_default_settings_cached(self) -> None:
        """Test that get_default_settings is cached."""
        # Clear the lru_cache first
        get_default_settings.cache_clear()

        settings1 = get_default_settings()
        settings2 = get_default_settings()

        assert settings1 is settings2

    def test_default_settings_immutable_pattern(self) -> None:
        """Test that default settings provide consistent values."""
        defaults = get_default_settings()
        assert defaults.working_memory_backend == "redis"
        assert defaults.embedding_model == "text-embedding-3-small"

325
backend/tests/unit/services/memory/test_exceptions.py
Normal file
@@ -0,0 +1,325 @@
"""
Tests for Memory System Exceptions.
"""

from uuid import uuid4

import pytest

from app.services.memory.exceptions import (
    CheckpointError,
    EmbeddingError,
    MemoryCapacityError,
    MemoryConflictError,
    MemoryConsolidationError,
    MemoryError,
    MemoryExpiredError,
    MemoryNotFoundError,
    MemoryRetrievalError,
    MemoryScopeError,
    MemorySerializationError,
    MemoryStorageError,
)


class TestMemoryError:
    """Tests for base MemoryError class."""

    def test_basic_error(self) -> None:
        """Test creating a basic memory error."""
        error = MemoryError("Something went wrong")

        assert str(error) == "Something went wrong"
        assert error.message == "Something went wrong"
        assert error.memory_type is None
        assert error.scope_type is None
        assert error.scope_id is None
        assert error.details == {}

    def test_error_with_context(self) -> None:
        """Test creating an error with context."""
        error = MemoryError(
            "Operation failed",
            memory_type="episodic",
            scope_type="project",
            scope_id="proj-123",
            details={"operation": "search"},
        )

        assert error.memory_type == "episodic"
        assert error.scope_type == "project"
        assert error.scope_id == "proj-123"
        assert error.details == {"operation": "search"}

    def test_error_inheritance(self) -> None:
        """Test that MemoryError inherits from Exception."""
        error = MemoryError("test")
        assert isinstance(error, Exception)


class TestMemoryNotFoundError:
    """Tests for MemoryNotFoundError class."""

    def test_default_message(self) -> None:
        """Test default error message."""
        error = MemoryNotFoundError()
        assert error.message == "Memory not found"

    def test_with_memory_id(self) -> None:
        """Test error with memory ID."""
        memory_id = uuid4()
        error = MemoryNotFoundError(
            f"Memory {memory_id} not found",
            memory_id=memory_id,
        )

        assert error.memory_id == memory_id

    def test_with_key(self) -> None:
        """Test error with key."""
        error = MemoryNotFoundError(
            "Key not found",
            key="my_key",
        )

        assert error.key == "my_key"


class TestMemoryCapacityError:
    """Tests for MemoryCapacityError class."""

    def test_default_message(self) -> None:
        """Test default error message."""
        error = MemoryCapacityError()
        assert error.message == "Memory capacity exceeded"

    def test_with_sizes(self) -> None:
        """Test error with size information."""
        error = MemoryCapacityError(
            "Working memory full",
            current_size=1048576,
            max_size=1000000,
            item_count=500,
        )

        assert error.current_size == 1048576
        assert error.max_size == 1000000
        assert error.item_count == 500


class TestMemoryExpiredError:
    """Tests for MemoryExpiredError class."""

    def test_default_message(self) -> None:
        """Test default error message."""
        error = MemoryExpiredError()
        assert error.message == "Memory has expired"

    def test_with_expiry_info(self) -> None:
        """Test error with expiry information."""
        error = MemoryExpiredError(
            "Key expired",
            key="session_data",
            expired_at="2025-01-05T00:00:00Z",
        )

        assert error.key == "session_data"
        assert error.expired_at == "2025-01-05T00:00:00Z"


class TestMemoryStorageError:
    """Tests for MemoryStorageError class."""

    def test_default_message(self) -> None:
        """Test default error message."""
        error = MemoryStorageError()
        assert error.message == "Memory storage operation failed"

    def test_with_operation_info(self) -> None:
        """Test error with operation information."""
        error = MemoryStorageError(
            "Redis write failed",
            operation="set",
            backend="redis",
        )

        assert error.operation == "set"
        assert error.backend == "redis"


class TestMemorySerializationError:
    """Tests for MemorySerializationError class."""

    def test_default_message(self) -> None:
        """Test default error message."""
        error = MemorySerializationError()
        assert error.message == "Memory serialization failed"

    def test_with_content_type(self) -> None:
        """Test error with content type."""
        error = MemorySerializationError(
            "Cannot serialize function",
            content_type="function",
        )

        assert error.content_type == "function"


class TestMemoryScopeError:
    """Tests for MemoryScopeError class."""

    def test_default_message(self) -> None:
        """Test default error message."""
        error = MemoryScopeError()
        assert error.message == "Memory scope error"

    def test_with_scope_info(self) -> None:
        """Test error with scope information."""
        error = MemoryScopeError(
            "Scope access denied",
            requested_scope="global",
            allowed_scopes=["project", "session"],
        )

        assert error.requested_scope == "global"
        assert error.allowed_scopes == ["project", "session"]


class TestMemoryConsolidationError:
    """Tests for MemoryConsolidationError class."""

    def test_default_message(self) -> None:
        """Test default error message."""
        error = MemoryConsolidationError()
        assert error.message == "Memory consolidation failed"

    def test_with_consolidation_info(self) -> None:
        """Test error with consolidation information."""
        error = MemoryConsolidationError(
            "Transfer failed",
            source_type="working",
            target_type="episodic",
            items_processed=50,
        )

        assert error.source_type == "working"
        assert error.target_type == "episodic"
        assert error.items_processed == 50


class TestMemoryRetrievalError:
    """Tests for MemoryRetrievalError class."""

    def test_default_message(self) -> None:
        """Test default error message."""
        error = MemoryRetrievalError()
        assert error.message == "Memory retrieval failed"

    def test_with_query_info(self) -> None:
        """Test error with query information."""
        error = MemoryRetrievalError(
            "Search timeout",
            query="complex search query",
            retrieval_type="semantic",
        )

        assert error.query == "complex search query"
        assert error.retrieval_type == "semantic"


class TestEmbeddingError:
    """Tests for EmbeddingError class."""

    def test_default_message(self) -> None:
        """Test default error message."""
        error = EmbeddingError()
        assert error.message == "Embedding generation failed"

    def test_with_embedding_info(self) -> None:
        """Test error with embedding information."""
        error = EmbeddingError(
            "Content too long",
            content_length=100000,
            model="text-embedding-3-small",
        )

        assert error.content_length == 100000
        assert error.model == "text-embedding-3-small"


class TestCheckpointError:
    """Tests for CheckpointError class."""

    def test_default_message(self) -> None:
        """Test default error message."""
        error = CheckpointError()
        assert error.message == "Checkpoint operation failed"

    def test_with_checkpoint_info(self) -> None:
        """Test error with checkpoint information."""
        error = CheckpointError(
            "Restore failed",
            checkpoint_id="chk-123",
            operation="restore",
        )

        assert error.checkpoint_id == "chk-123"
        assert error.operation == "restore"


class TestMemoryConflictError:
    """Tests for MemoryConflictError class."""

    def test_default_message(self) -> None:
        """Test default error message."""
        error = MemoryConflictError()
        assert error.message == "Memory conflict detected"

    def test_with_conflict_info(self) -> None:
        """Test error with conflict information."""
        id1 = uuid4()
        id2 = uuid4()
        error = MemoryConflictError(
            "Contradictory facts detected",
            conflicting_ids=[id1, id2],
            conflict_type="semantic",
        )

        assert len(error.conflicting_ids) == 2
        assert error.conflict_type == "semantic"


class TestExceptionHierarchy:
    """Tests for exception inheritance hierarchy."""

    def test_all_exceptions_inherit_from_memory_error(self) -> None:
        """Test that all exceptions inherit from MemoryError."""
        exceptions = [
            MemoryNotFoundError(),
            MemoryCapacityError(),
            MemoryExpiredError(),
            MemoryStorageError(),
            MemorySerializationError(),
            MemoryScopeError(),
            MemoryConsolidationError(),
            MemoryRetrievalError(),
            EmbeddingError(),
            CheckpointError(),
            MemoryConflictError(),
        ]

        for exc in exceptions:
            assert isinstance(exc, MemoryError)
            assert isinstance(exc, Exception)

    def test_can_catch_base_error(self) -> None:
        """Test that catching MemoryError catches all subclasses."""
        exceptions = [
            MemoryNotFoundError("not found"),
            MemoryCapacityError("capacity"),
            MemoryStorageError("storage"),
        ]

        for exc in exceptions:
            with pytest.raises(MemoryError):
                raise exc

411
backend/tests/unit/services/memory/test_types.py
Normal file
@@ -0,0 +1,411 @@
"""
Tests for Memory System Types.
"""

from datetime import datetime, timedelta
from uuid import uuid4

from app.services.memory.types import (
    ConsolidationStatus,
    ConsolidationType,
    EpisodeCreate,
    Fact,
    FactCreate,
    MemoryItem,
    MemoryStats,
    MemoryType,
    Outcome,
    Procedure,
    ProcedureCreate,
    RetrievalResult,
    ScopeContext,
    ScopeLevel,
    Step,
    TaskState,
    WorkingMemoryItem,
)


class TestEnums:
    """Tests for memory enums."""

    def test_memory_type_values(self) -> None:
        """Test MemoryType enum values."""
        assert MemoryType.WORKING == "working"
        assert MemoryType.EPISODIC == "episodic"
        assert MemoryType.SEMANTIC == "semantic"
        assert MemoryType.PROCEDURAL == "procedural"

    def test_scope_level_values(self) -> None:
        """Test ScopeLevel enum values."""
        assert ScopeLevel.GLOBAL == "global"
        assert ScopeLevel.PROJECT == "project"
        assert ScopeLevel.AGENT_TYPE == "agent_type"
        assert ScopeLevel.AGENT_INSTANCE == "agent_instance"
        assert ScopeLevel.SESSION == "session"

    def test_outcome_values(self) -> None:
        """Test Outcome enum values."""
        assert Outcome.SUCCESS == "success"
        assert Outcome.FAILURE == "failure"
        assert Outcome.PARTIAL == "partial"

    def test_consolidation_status_values(self) -> None:
        """Test ConsolidationStatus enum values."""
        assert ConsolidationStatus.PENDING == "pending"
        assert ConsolidationStatus.RUNNING == "running"
        assert ConsolidationStatus.COMPLETED == "completed"
        assert ConsolidationStatus.FAILED == "failed"

    def test_consolidation_type_values(self) -> None:
        """Test ConsolidationType enum values."""
        assert ConsolidationType.WORKING_TO_EPISODIC == "working_to_episodic"
        assert ConsolidationType.EPISODIC_TO_SEMANTIC == "episodic_to_semantic"
        assert ConsolidationType.EPISODIC_TO_PROCEDURAL == "episodic_to_procedural"
        assert ConsolidationType.PRUNING == "pruning"


class TestScopeContext:
    """Tests for ScopeContext dataclass."""

    def test_create_scope_context(self) -> None:
        """Test creating a scope context."""
        scope = ScopeContext(
            scope_type=ScopeLevel.SESSION,
            scope_id="sess-123",
        )

        assert scope.scope_type == ScopeLevel.SESSION
        assert scope.scope_id == "sess-123"
        assert scope.parent is None

    def test_scope_with_parent(self) -> None:
        """Test creating a scope with parent."""
        parent = ScopeContext(
            scope_type=ScopeLevel.PROJECT,
            scope_id="proj-123",
        )
        child = ScopeContext(
            scope_type=ScopeLevel.SESSION,
            scope_id="sess-456",
            parent=parent,
        )

        assert child.parent is parent
        assert child.parent.scope_type == ScopeLevel.PROJECT

    def test_get_hierarchy(self) -> None:
        """Test getting scope hierarchy."""
        global_scope = ScopeContext(
            scope_type=ScopeLevel.GLOBAL,
            scope_id="global",
        )
        project_scope = ScopeContext(
            scope_type=ScopeLevel.PROJECT,
            scope_id="proj-123",
            parent=global_scope,
        )
        session_scope = ScopeContext(
            scope_type=ScopeLevel.SESSION,
            scope_id="sess-456",
            parent=project_scope,
        )

        hierarchy = session_scope.get_hierarchy()

        assert len(hierarchy) == 3
        assert hierarchy[0].scope_type == ScopeLevel.GLOBAL
        assert hierarchy[1].scope_type == ScopeLevel.PROJECT
        assert hierarchy[2].scope_type == ScopeLevel.SESSION

    def test_to_key_prefix(self) -> None:
        """Test converting scope to key prefix."""
        scope = ScopeContext(
            scope_type=ScopeLevel.SESSION,
            scope_id="sess-123",
        )

        prefix = scope.to_key_prefix()
        assert prefix == "session:sess-123"


class TestMemoryItem:
    """Tests for MemoryItem dataclass."""

    def test_create_memory_item(self) -> None:
        """Test creating a memory item."""
        now = datetime.now()
        item = MemoryItem(
            id=uuid4(),
            memory_type=MemoryType.EPISODIC,
            scope_type=ScopeLevel.PROJECT,
            scope_id="proj-123",
            created_at=now,
            updated_at=now,
        )

        assert item.memory_type == MemoryType.EPISODIC
        assert item.scope_type == ScopeLevel.PROJECT
        assert item.metadata == {}

    def test_get_age_seconds(self) -> None:
        """Test getting item age."""
        past = datetime.now() - timedelta(seconds=100)
        item = MemoryItem(
            id=uuid4(),
            memory_type=MemoryType.SEMANTIC,
            scope_type=ScopeLevel.GLOBAL,
            scope_id="global",
            created_at=past,
            updated_at=past,
        )

        age = item.get_age_seconds()
        assert age >= 100
        assert age < 105  # Allow small margin


class TestWorkingMemoryItem:
    """Tests for WorkingMemoryItem dataclass."""

    def test_create_working_memory_item(self) -> None:
        """Test creating a working memory item."""
        item = WorkingMemoryItem(
            id=uuid4(),
            scope_type=ScopeLevel.SESSION,
            scope_id="sess-123",
            key="my_key",
            value={"data": "value"},
        )

        assert item.key == "my_key"
        assert item.value == {"data": "value"}
        assert item.expires_at is None

    def test_is_expired_no_expiry(self) -> None:
        """Test is_expired with no expiry set."""
        item = WorkingMemoryItem(
            id=uuid4(),
            scope_type=ScopeLevel.SESSION,
            scope_id="sess-123",
            key="my_key",
            value="value",
        )

        assert item.is_expired() is False

    def test_is_expired_future(self) -> None:
        """Test is_expired with future expiry."""
        item = WorkingMemoryItem(
            id=uuid4(),
            scope_type=ScopeLevel.SESSION,
            scope_id="sess-123",
            key="my_key",
            value="value",
            expires_at=datetime.now() + timedelta(hours=1),
        )

        assert item.is_expired() is False

    def test_is_expired_past(self) -> None:
        """Test is_expired with past expiry."""
        item = WorkingMemoryItem(
            id=uuid4(),
            scope_type=ScopeLevel.SESSION,
            scope_id="sess-123",
            key="my_key",
            value="value",
            expires_at=datetime.now() - timedelta(hours=1),
        )

        assert item.is_expired() is True


class TestTaskState:
    """Tests for TaskState dataclass."""

    def test_create_task_state(self) -> None:
        """Test creating a task state."""
        state = TaskState(
            task_id="task-123",
            task_type="code_review",
            description="Review PR #42",
        )

        assert state.task_id == "task-123"
        assert state.task_type == "code_review"
        assert state.status == "in_progress"
        assert state.current_step == 0
        assert state.progress_percent == 0.0

    def test_task_state_with_progress(self) -> None:
        """Test task state with progress."""
        state = TaskState(
            task_id="task-123",
            task_type="implementation",
            description="Implement feature X",
            current_step=3,
            total_steps=5,
            progress_percent=60.0,
        )

        assert state.current_step == 3
        assert state.total_steps == 5
        assert state.progress_percent == 60.0


class TestEpisode:
    """Tests for Episode and EpisodeCreate dataclasses."""

    def test_create_episode_data(self) -> None:
        """Test creating episode create data."""
        data = EpisodeCreate(
            project_id=uuid4(),
            session_id="sess-123",
            task_type="bug_fix",
            task_description="Fix login bug",
            actions=[{"action": "read_file", "file": "auth.py"}],
            context_summary="User reported login issues",
            outcome=Outcome.SUCCESS,
            outcome_details="Fixed by updating validation",
            duration_seconds=120.5,
            tokens_used=5000,
        )

        assert data.task_type == "bug_fix"
        assert data.outcome == Outcome.SUCCESS
        assert len(data.actions) == 1
        assert data.importance_score == 0.5  # Default


class TestFact:
    """Tests for Fact and FactCreate dataclasses."""

    def test_create_fact_data(self) -> None:
        """Test creating fact create data."""
        data = FactCreate(
            subject="FastAPI",
            predicate="uses",
            object="Starlette framework",
        )

        assert data.subject == "FastAPI"
        assert data.predicate == "uses"
        assert data.object == "Starlette framework"
        assert data.confidence == 0.8  # Default
        assert data.project_id is None  # Global fact


class TestProcedure:
    """Tests for Procedure and ProcedureCreate dataclasses."""

    def test_create_procedure_data(self) -> None:
        """Test creating procedure create data."""
        data = ProcedureCreate(
            name="review_pr",
            trigger_pattern="review pull request",
            steps=[
                {"action": "checkout_branch"},
                {"action": "run_tests"},
                {"action": "review_changes"},
            ],
        )

        assert data.name == "review_pr"
        assert len(data.steps) == 3

    def test_procedure_success_rate(self) -> None:
        """Test procedure success rate calculation."""
        now = datetime.now()
        procedure = Procedure(
            id=uuid4(),
            project_id=None,
            agent_type_id=None,
            name="test_proc",
            trigger_pattern="test",
            steps=[],
            success_count=8,
            failure_count=2,
            last_used=now,
            embedding=None,
            created_at=now,
            updated_at=now,
        )

        assert procedure.success_rate == 0.8

    def test_procedure_success_rate_zero_uses(self) -> None:
        """Test procedure success rate with zero uses."""
        now = datetime.now()
        procedure = Procedure(
            id=uuid4(),
            project_id=None,
            agent_type_id=None,
            name="test_proc",
            trigger_pattern="test",
            steps=[],
            success_count=0,
            failure_count=0,
            last_used=None,
            embedding=None,
            created_at=now,
            updated_at=now,
        )

        assert procedure.success_rate == 0.0


class TestStep:
    """Tests for Step dataclass."""

    def test_create_step(self) -> None:
        """Test creating a step."""
        step = Step(
            order=1,
            action="run_tests",
            parameters={"verbose": True},
            expected_outcome="All tests pass",
        )

        assert step.order == 1
        assert step.action == "run_tests"
        assert step.parameters == {"verbose": True}


class TestRetrievalResult:
    """Tests for RetrievalResult dataclass."""

    def test_create_retrieval_result(self) -> None:
        """Test creating a retrieval result."""
        result: RetrievalResult[Fact] = RetrievalResult(
            items=[],
            total_count=0,
            query="test query",
            retrieval_type="semantic",
            latency_ms=15.5,
        )

        assert result.query == "test query"
        assert result.latency_ms == 15.5
        assert result.metadata == {}


class TestMemoryStats:
    """Tests for MemoryStats dataclass."""

    def test_create_memory_stats(self) -> None:
        """Test creating memory stats."""
        stats = MemoryStats(
            memory_type=MemoryType.EPISODIC,
            scope_type=ScopeLevel.PROJECT,
            scope_id="proj-123",
            item_count=150,
            total_size_bytes=1048576,
            oldest_item_age_seconds=86400,
            newest_item_age_seconds=60,
            avg_item_size_bytes=6990.5,
        )

        assert stats.memory_type == MemoryType.EPISODIC
        assert stats.item_count == 150
        assert stats.total_size_bytes == 1048576

526
docs/architecture/memory-system-plan.md
Normal file
@@ -0,0 +1,526 @@
# Agent Memory System - Implementation Plan

## Issue #62 - Part of Epic #60 (Phase 2: MCP Integration)

**Branch:** `feature/62-agent-memory-system`
**Parent Epic:** #60 [EPIC] Phase 2: MCP Integration
**Dependencies:** #56 (LLM Gateway), #57 (Knowledge Base), #61 (Context Management Engine)

---

## Executive Summary

The Agent Memory System provides multi-tier cognitive memory for AI agents, enabling them to:
- Maintain state across sessions (Working Memory)
- Learn from past experiences (Episodic Memory)
- Store and retrieve facts (Semantic Memory)
- Develop and reuse procedures (Procedural Memory)

### Architecture Overview

```
                         Agent Memory System

  ┌─────────────────┐                      ┌─────────────────┐
  │ Working Memory  │─────────────────────▶│ Episodic Memory │
  │ (Redis/In-Mem)  │     consolidate      │  (PostgreSQL)   │
  │                 │                      │                 │
  │ • Current task  │                      │ • Past sessions │
  │ • Variables     │                      │ • Experiences   │
  │ • Scratchpad    │                      │ • Outcomes      │
  └─────────────────┘                      └────────┬────────┘
                                                    │
                                            extract │
                                                    ▼
  ┌─────────────────┐                      ┌─────────────────┐
  │Procedural Memory│◀─────────────────────│ Semantic Memory │
  │  (PostgreSQL)   │      learn from      │  (PostgreSQL +  │
  │                 │                      │   pgvector)     │
  │ • Procedures    │                      │                 │
  │ • Skills        │                      │ • Facts         │
  │ • Patterns      │                      │ • Entities      │
  └─────────────────┘                      │ • Relationships │
                                           └─────────────────┘
```

### Memory Scoping Hierarchy

```
Global Memory (shared by all)
└── Project Memory (per project)
    └── Agent Type Memory (per agent type)
        └── Agent Instance Memory (per instance)
            └── Session Memory (ephemeral)
```
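
The `ScopeContext` type introduced in this commit models the hierarchy as a parent-linked chain. A minimal sketch of how a session scope resolves, using only the API exercised by this PR's unit tests:

```python
from app.services.memory.types import ScopeContext, ScopeLevel

# Build the chain top-down: global -> project -> session
global_scope = ScopeContext(scope_type=ScopeLevel.GLOBAL, scope_id="global")
project_scope = ScopeContext(
    scope_type=ScopeLevel.PROJECT, scope_id="proj-123", parent=global_scope
)
session_scope = ScopeContext(
    scope_type=ScopeLevel.SESSION, scope_id="sess-456", parent=project_scope
)

# get_hierarchy() walks parents root-first, so a child can "see" ancestor scopes
assert [s.scope_type for s in session_scope.get_hierarchy()] == [
    ScopeLevel.GLOBAL, ScopeLevel.PROJECT, ScopeLevel.SESSION
]

# to_key_prefix() yields the storage namespace for a scope
assert session_scope.to_key_prefix() == "session:sess-456"
```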

---

## Sub-Issue Breakdown

### Phase 1: Foundation (Critical Path)

#### Sub-Issue #62-1: Project Setup & Core Architecture
**Priority:** P0 - Must complete first
**Estimated Complexity:** Medium

**Tasks:**
- [ ] Create `backend/app/services/memory/` directory structure
- [ ] Create `__init__.py` with public API exports
- [ ] Create `config.py` with `MemorySettings` (Pydantic)
- [ ] Define base interfaces in `types.py`:
  - `MemoryItem` - Base class for all memory items
  - `MemoryScope` - Enum for scoping levels
  - `MemoryStore` - Abstract base for storage backends
- [ ] Create `manager.py` with `MemoryManager` class (facade)
- [ ] Create `exceptions.py` with memory-specific errors
- [ ] Write ADR-010 documenting memory architecture decisions
- [ ] Create dependency injection setup
- [ ] Unit tests for configuration and types

**Deliverables:**
- Directory structure matching existing patterns (like `context/`, `safety/`)
- Configuration with `MEM_` env prefix (usage sketched below)
- Type definitions for all memory concepts
- Comprehensive unit tests
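
A usage sketch for the configuration deliverable, using names pinned by `test_config.py`; the exact `MEM_` variable-to-field mapping is an assumption based on the stated prefix:

```python
from app.services.memory.config import MemorySettings, get_memory_settings

# Thread-safe singleton access; defaults as pinned by the unit tests
settings = get_memory_settings()
assert settings.working_memory_backend == "redis"
assert settings.working_memory_default_ttl_seconds == 3600

# With the MEM_ prefix, an env var such as MEM_WORKING_MEMORY_BACKEND=memory
# (assumed mapping) would override the default; the same override works
# explicitly at construction time:
override = MemorySettings(working_memory_backend="memory")
assert override.working_memory_backend == "memory"
```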

---

#### Sub-Issue #62-2: Database Schema & Storage Layer
**Priority:** P0 - Required for all memory types
**Estimated Complexity:** High

**Database Tables:**

1. **`working_memory`** - Ephemeral key-value storage
   - `id` (UUID, PK)
   - `scope_type` (ENUM: global/project/agent_type/agent_instance/session)
   - `scope_id` (VARCHAR - the ID for the scope level)
   - `key` (VARCHAR)
   - `value` (JSONB)
   - `expires_at` (TIMESTAMP WITH TZ)
   - `created_at`, `updated_at`

2. **`episodes`** - Experiential memories
   - `id` (UUID, PK)
   - `project_id` (UUID, FK)
   - `agent_instance_id` (UUID, FK, nullable)
   - `agent_type_id` (UUID, FK, nullable)
   - `session_id` (VARCHAR)
   - `task_type` (VARCHAR)
   - `task_description` (TEXT)
   - `actions` (JSONB)
   - `context_summary` (TEXT)
   - `outcome` (ENUM: success/failure/partial)
   - `outcome_details` (TEXT)
   - `duration_seconds` (FLOAT)
   - `tokens_used` (BIGINT)
   - `lessons_learned` (JSONB - list of strings)
   - `importance_score` (FLOAT, 0-1)
   - `embedding` (VECTOR(1536))
   - `occurred_at` (TIMESTAMP WITH TZ)
   - `created_at`, `updated_at`

3. **`facts`** - Semantic knowledge
   - `id` (UUID, PK)
   - `project_id` (UUID, FK, nullable - null for global)
   - `subject` (VARCHAR)
   - `predicate` (VARCHAR)
   - `object` (TEXT)
   - `confidence` (FLOAT, 0-1)
   - `source_episode_ids` (UUID[])
   - `first_learned` (TIMESTAMP WITH TZ)
   - `last_reinforced` (TIMESTAMP WITH TZ)
   - `reinforcement_count` (INT)
   - `embedding` (VECTOR(1536))
   - `created_at`, `updated_at`

4. **`procedures`** - Learned skills
   - `id` (UUID, PK)
   - `project_id` (UUID, FK, nullable)
   - `agent_type_id` (UUID, FK, nullable)
   - `name` (VARCHAR)
   - `trigger_pattern` (TEXT)
   - `steps` (JSONB)
   - `success_count` (INT)
   - `failure_count` (INT)
   - `last_used` (TIMESTAMP WITH TZ)
   - `embedding` (VECTOR(1536))
   - `created_at`, `updated_at`

5. **`memory_consolidation_log`** - Consolidation tracking
   - `id` (UUID, PK)
   - `consolidation_type` (ENUM)
   - `source_count` (INT)
   - `result_count` (INT)
   - `started_at`, `completed_at`
   - `status` (ENUM: pending/running/completed/failed)
   - `error` (TEXT, nullable)

**Tasks:**
- [ ] Create SQLAlchemy models in `backend/app/models/memory/` (see the sketch below)
- [ ] Create Alembic migration with all tables
- [ ] Add pgvector indexes (HNSW for episodes, facts, procedures)
- [ ] Create repository classes in `backend/app/crud/memory/`
- [ ] Add composite indexes for common query patterns
- [ ] Unit tests for all repositories
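
A minimal sketch of what the `facts` model could look like, assuming SQLAlchemy 2.0 declarative style plus the `pgvector` extension; the base class, column sizes, and the omitted FK/server-default options are illustrative, not this PR's actual code:

```python
import uuid
from datetime import datetime

from pgvector.sqlalchemy import Vector
from sqlalchemy import DateTime, Float, Integer, String, Text
from sqlalchemy.dialects.postgresql import ARRAY, UUID as PG_UUID
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column


class Base(DeclarativeBase):
    pass


class Fact(Base):
    __tablename__ = "facts"

    id: Mapped[uuid.UUID] = mapped_column(
        PG_UUID(as_uuid=True), primary_key=True, default=uuid.uuid4
    )
    # Nullable project FK: NULL means the fact is globally scoped
    project_id: Mapped[uuid.UUID | None] = mapped_column(
        PG_UUID(as_uuid=True), nullable=True
    )
    # Triple format per the schema above
    subject: Mapped[str] = mapped_column(String(255))
    predicate: Mapped[str] = mapped_column(String(255))
    object: Mapped[str] = mapped_column(Text)
    confidence: Mapped[float] = mapped_column(Float, default=0.8)  # 0-1
    source_episode_ids: Mapped[list[uuid.UUID]] = mapped_column(
        ARRAY(PG_UUID(as_uuid=True)), default=list
    )
    reinforcement_count: Mapped[int] = mapped_column(Integer, default=0)
    # 1536 dims matches the text-embedding-3-small default in MemorySettings
    embedding: Mapped[list[float] | None] = mapped_column(Vector(1536), nullable=True)
    first_learned: Mapped[datetime] = mapped_column(DateTime(timezone=True))
    last_reinforced: Mapped[datetime] = mapped_column(DateTime(timezone=True))
```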

---

#### Sub-Issue #62-3: Working Memory Implementation
**Priority:** P0 - Core functionality
**Estimated Complexity:** Medium

**Components:**
- `backend/app/services/memory/working/memory.py` - WorkingMemory class
- `backend/app/services/memory/working/storage.py` - Redis + in-memory backend

**Features:**
- [ ] Session-scoped containers with automatic cleanup
- [ ] Variable storage (get/set/delete)
- [ ] Task state tracking (current step, status, progress)
- [ ] Scratchpad for reasoning steps
- [ ] Configurable capacity limits
- [ ] TTL-based expiration
- [ ] Checkpoint/snapshot support for recovery
- [ ] Redis primary storage with in-memory fallback

**API:**
```python
class WorkingMemory:
    async def set(self, key: str, value: Any, ttl_seconds: int | None = None) -> None
    async def get(self, key: str, default: Any = None) -> Any
    async def delete(self, key: str) -> bool
    async def exists(self, key: str) -> bool
    async def list_keys(self, pattern: str = "*") -> list[str]
    async def get_all(self) -> dict[str, Any]
    async def clear(self) -> int
    async def set_task_state(self, state: TaskState) -> None
    async def get_task_state(self) -> TaskState | None
    async def append_scratchpad(self, content: str) -> None
    async def get_scratchpad(self) -> list[str]
    async def create_checkpoint(self) -> str  # Returns checkpoint ID
    async def restore_checkpoint(self, checkpoint_id: str) -> None
```
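
A hedged sketch of the checkpoint flow above; the class itself lands in #89, so construction and wiring are deliberately elided:

```python
async def run_risky_step(wm: "WorkingMemory") -> None:
    # Stage intermediate state with a TTL so abandoned work expires on its own
    await wm.set("draft", {"step": 1}, ttl_seconds=600)

    # Snapshot before attempting something that may need rollback
    checkpoint_id = await wm.create_checkpoint()
    try:
        await wm.append_scratchpad("trying risky refactor")
        ...  # perform the step
    except Exception:
        # Roll working state back to the snapshot on failure
        await wm.restore_checkpoint(checkpoint_id)
        raise
```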

---

### Phase 2: Memory Types

#### Sub-Issue #62-4: Episodic Memory Implementation
**Priority:** P1
**Estimated Complexity:** High

**Components:**
- `backend/app/services/memory/episodic/memory.py` - EpisodicMemory class
- `backend/app/services/memory/episodic/recorder.py` - Episode recording
- `backend/app/services/memory/episodic/retrieval.py` - Retrieval strategies

**Features:**
- [ ] Episode recording during agent execution
- [ ] Store task completions with context
- [ ] Store failures with error context
- [ ] Retrieval by semantic similarity (vector search)
- [ ] Retrieval by recency
- [ ] Retrieval by outcome (success/failure)
- [ ] Importance scoring based on outcome significance
- [ ] Episode summarization for long-term storage

**API:**
```python
class EpisodicMemory:
    async def record_episode(self, episode: EpisodeCreate) -> Episode
    async def search_similar(self, query: str, limit: int = 10) -> list[Episode]
    async def get_recent(self, limit: int = 10, since: datetime | None = None) -> list[Episode]
    async def get_by_outcome(self, outcome: Outcome, limit: int = 10) -> list[Episode]
    async def get_by_task_type(self, task_type: str, limit: int = 10) -> list[Episode]
    async def update_importance(self, episode_id: UUID, score: float) -> None
    async def summarize_episodes(self, episode_ids: list[UUID]) -> str
```
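
A sketch of how the retrieval strategies above might combine when priming a new task; `episodic` wiring is out of scope here, and the combination policy is illustrative:

```python
from app.services.memory.types import Outcome


async def prime_context(episodic: "EpisodicMemory", task: str) -> list["Episode"]:
    # What worked on similar tasks before (vector search)
    similar = await episodic.search_similar(task, limit=5)
    # What recently went wrong, so the agent can avoid repeating it
    recent_failures = await episodic.get_by_outcome(Outcome.FAILURE, limit=3)
    return similar + recent_failures
```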

---

#### Sub-Issue #62-5: Semantic Memory Implementation
**Priority:** P1
**Estimated Complexity:** High

**Components:**
- `backend/app/services/memory/semantic/memory.py` - SemanticMemory class
- `backend/app/services/memory/semantic/extraction.py` - Fact extraction from episodes
- `backend/app/services/memory/semantic/verification.py` - Fact verification

**Features:**
- [ ] Fact storage with triple format (subject, predicate, object)
- [ ] Confidence scoring and decay
- [ ] Fact extraction from episodic memory
- [ ] Conflict resolution for contradictory facts
- [ ] Retrieval by query (semantic search)
- [ ] Retrieval by entity (subject or object)
- [ ] Source tracking (which episodes contributed)
- [ ] Reinforcement on repeated learning

**API:**
```python
class SemanticMemory:
    async def store_fact(self, fact: FactCreate) -> Fact
    async def search_facts(self, query: str, limit: int = 10) -> list[Fact]
    async def get_by_entity(self, entity: str, limit: int = 20) -> list[Fact]
    async def reinforce_fact(self, fact_id: UUID) -> Fact
    async def deprecate_fact(self, fact_id: UUID, reason: str) -> None
    async def extract_facts_from_episode(self, episode: Episode) -> list[Fact]
    async def resolve_conflict(self, fact_ids: list[UUID]) -> Fact
```
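
A usage sketch for the triple format, grounded in this PR's `FactCreate` tests (confidence defaults to 0.8, and `project_id=None` marks a global fact):

```python
from app.services.memory.types import FactCreate

fact = FactCreate(
    subject="FastAPI",
    predicate="uses",
    object="Starlette framework",
)
assert fact.confidence == 0.8   # default per tests
assert fact.project_id is None  # global fact

# stored = await semantic.store_fact(fact)  # `semantic` is a SemanticMemory instance
```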

---

#### Sub-Issue #62-6: Procedural Memory Implementation
**Priority:** P2
**Estimated Complexity:** Medium

**Components:**
- `backend/app/services/memory/procedural/memory.py` - ProceduralMemory class
- `backend/app/services/memory/procedural/matching.py` - Procedure matching

**Features:**
- [ ] Procedure recording from successful task patterns
- [ ] Trigger pattern matching
- [ ] Step-by-step procedure storage
- [ ] Success/failure rate tracking
- [ ] Procedure suggestion based on context
- [ ] Procedure versioning

**API:**
```python
class ProceduralMemory:
    async def record_procedure(self, procedure: ProcedureCreate) -> Procedure
    async def find_matching(self, context: str, limit: int = 5) -> list[Procedure]
    async def record_outcome(self, procedure_id: UUID, success: bool) -> None
    async def get_best_procedure(self, task_type: str) -> Procedure | None
    async def update_steps(self, procedure_id: UUID, steps: list[Step]) -> Procedure
```
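
The type tests in this PR pin `Procedure.success_rate` to `success_count / (success_count + failure_count)`, falling back to `0.0` when unused. A sketch of how the matching and outcome APIs above would feed that rate; the ranking policy is an assumption:

```python
async def use_best_procedure(procedural: "ProceduralMemory", context: str) -> None:
    candidates = await procedural.find_matching(context, limit=5)
    # Prefer the procedure with the best track record (e.g. 8/10 -> 0.8 per tests)
    best = max(candidates, key=lambda p: p.success_rate, default=None)
    if best is not None:
        succeeded = True  # ... execute best.steps and observe the outcome ...
        await procedural.record_outcome(best.id, success=succeeded)
```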

---

### Phase 3: Advanced Features

#### Sub-Issue #62-7: Memory Scoping
**Priority:** P1
**Estimated Complexity:** Medium

**Components:**
- `backend/app/services/memory/scoping/scope.py` - Scope management
- `backend/app/services/memory/scoping/resolver.py` - Scope resolution

**Features:**
- [ ] Global scope (shared across all)
- [ ] Project scope (per project)
- [ ] Agent type scope (per agent type)
- [ ] Agent instance scope (per instance)
- [ ] Session scope (ephemeral)
- [ ] Scope inheritance (child sees parent memories)
- [ ] Access control policies

---

#### Sub-Issue #62-8: Memory Indexing & Retrieval
**Priority:** P1
**Estimated Complexity:** High

**Components:**
- `backend/app/services/memory/indexing/index.py` - Memory indexer
- `backend/app/services/memory/indexing/retrieval.py` - Retrieval engine

**Features:**
- [ ] Vector embeddings for all memory types
- [ ] Temporal index (by time)
- [ ] Entity index (by entities mentioned)
- [ ] Outcome index (by success/failure)
- [ ] Hybrid retrieval (vector + filters), sketched below
- [ ] Relevance scoring
- [ ] Retrieval caching
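
A sketch of hybrid retrieval with SQLAlchemy and `pgvector`, assuming an `Episode` model analogous to the `Fact` sketch earlier; names and the filter set are illustrative:

```python
from sqlalchemy import select


def similar_successes_stmt(episode_model, project_id, query_embedding, limit=10):
    """Hybrid retrieval: structured filters plus vector ordering."""
    return (
        select(episode_model)
        .where(episode_model.project_id == project_id)   # scope filter
        .where(episode_model.outcome == "success")       # outcome filter
        # pgvector's SQLAlchemy comparator; compiles to `embedding <=> :param`,
        # which the HNSW index from sub-issue #62-2 can serve
        .order_by(episode_model.embedding.cosine_distance(query_embedding))
        .limit(limit)
    )
```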

---

#### Sub-Issue #62-9: Memory Consolidation
**Priority:** P2
**Estimated Complexity:** High

**Components:**
- `backend/app/services/memory/consolidation/service.py` - Consolidation service
- `backend/app/tasks/memory_consolidation.py` - Celery tasks

**Features:**
- [ ] Working → Episodic transfer (session end)
- [ ] Episodic → Semantic extraction (learn facts)
- [ ] Episodic → Procedural extraction (learn procedures)
- [ ] Nightly consolidation Celery tasks (schedule sketched below)
- [ ] Memory pruning (remove low-value memories)
- [ ] Importance-based retention
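
A sketch of the nightly schedule, assuming a Celery app object and a task name inferred from the component list above; the 03:00 cron comes from the `schedule_cron` default ("0 3 * * *") pinned in this PR's config tests:

```python
from celery import Celery
from celery.schedules import crontab

celery_app = Celery("worker")  # assumed app; real wiring lives elsewhere

celery_app.conf.beat_schedule = {
    "memory-consolidation-nightly": {
        # Task path assumed from backend/app/tasks/memory_consolidation.py
        "task": "app.tasks.memory_consolidation.run_consolidation",
        "schedule": crontab(minute=0, hour=3),  # "0 3 * * *" per MemorySettings
    },
}
```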

---

### Phase 4: Integration

#### Sub-Issue #62-10: MCP Tools Definition
**Priority:** P0 - Required for agent usage
**Estimated Complexity:** Medium

**MCP Tools:**

1. **`remember`** - Store in memory
   ```json
   {
     "memory_type": "working|episodic|semantic|procedural",
     "content": "...",
     "importance": 0.8,
     "ttl_seconds": 3600
   }
   ```

2. **`recall`** - Retrieve from memory
   ```json
   {
     "query": "...",
     "memory_types": ["episodic", "semantic"],
     "limit": 10,
     "filters": {"outcome": "success"}
   }
   ```

3. **`forget`** - Remove from memory
   ```json
   {
     "memory_type": "working",
     "key": "temp_calculation"
   }
   ```

4. **`reflect`** - Analyze patterns
   ```json
   {
     "analysis_type": "recent_patterns|success_factors|failure_patterns"
   }
   ```

5. **`get_memory_stats`** - Usage statistics
6. **`search_procedures`** - Find relevant procedures
7. **`record_outcome`** - Record task success/failure

---

#### Sub-Issue #62-11: Component Integration
**Priority:** P1
**Estimated Complexity:** Medium

**Integrations:**
- [ ] Context Engine (#61) - Include relevant memories in context assembly
- [ ] Knowledge Base (#57) - Coordinate with KB to avoid duplication
- [ ] LLM Gateway (#56) - Use for embedding generation
- [ ] Agent lifecycle hooks (spawn, pause, resume, terminate)

---

#### Sub-Issue #62-12: Caching Layer
**Priority:** P2
**Estimated Complexity:** Medium

**Features:**
- [ ] Hot memory caching (frequently accessed)
- [ ] Retrieval result caching
- [ ] Embedding caching
- [ ] Cache invalidation strategies

---

### Phase 5: Intelligence & Quality

#### Sub-Issue #62-13: Memory Reflection
**Priority:** P3
**Estimated Complexity:** High

**Features:**
- [ ] Pattern detection in episodic memory
- [ ] Success/failure factor analysis
- [ ] Anomaly detection
- [ ] Insights generation

---

#### Sub-Issue #62-14: Metrics & Observability
**Priority:** P2
**Estimated Complexity:** Low

**Metrics** (declarations sketched below):
- `memory_size_bytes` by type and scope
- `memory_operations_total` counter
- `memory_retrieval_latency_seconds` histogram
- `memory_consolidation_duration_seconds` histogram
- `procedure_success_rate` gauge
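
The metrics above expressed as `prometheus_client` declarations; the metric names come from the list, while the label sets are assumptions:

```python
from prometheus_client import Counter, Gauge, Histogram

MEMORY_SIZE_BYTES = Gauge(
    "memory_size_bytes", "Memory footprint", ["memory_type", "scope_type"]
)
MEMORY_OPERATIONS_TOTAL = Counter(
    "memory_operations_total", "Memory operations", ["operation", "memory_type"]
)
MEMORY_RETRIEVAL_LATENCY = Histogram(
    "memory_retrieval_latency_seconds", "Retrieval latency", ["memory_type"]
)
MEMORY_CONSOLIDATION_DURATION = Histogram(
    "memory_consolidation_duration_seconds", "Consolidation duration"
)
PROCEDURE_SUCCESS_RATE = Gauge(
    "procedure_success_rate", "Procedure success rate", ["procedure"]
)
```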

---

#### Sub-Issue #62-15: Documentation & Final Testing
**Priority:** P0
**Estimated Complexity:** Medium

**Deliverables:**
- [ ] README with architecture overview
- [ ] API documentation with examples
- [ ] Integration guide
- [ ] E2E tests for full memory lifecycle
- [ ] Achieve >90% code coverage
- [ ] Performance benchmarks

---

## Implementation Order

```
Phase 1 (Foundation) - Sequential
  #62-1 → #62-2 → #62-3

Phase 2 (Memory Types) - Can parallelize after Phase 1
  #62-4, #62-5, #62-6 (parallel after #62-3)

Phase 3 (Advanced) - Sequential within phase
  #62-7 → #62-8 → #62-9

Phase 4 (Integration) - After Phase 2
  #62-10 → #62-11 → #62-12

Phase 5 (Quality) - Final
  #62-13, #62-14, #62-15
```

---

## Performance Targets

| Metric | Target | Notes |
|--------|--------|-------|
| Working memory get/set | <5ms | P95 |
| Episodic memory retrieval | <100ms | P95, as per epic |
| Semantic memory search | <100ms | P95 |
| Procedural memory matching | <50ms | P95 |
| Consolidation batch | <30s | Per 1000 episodes |

---

## Risk Mitigation

1. **Embedding costs** - Use caching aggressively, batch embeddings
2. **Storage growth** - Implement TTL, pruning, and archival policies
3. **Query performance** - HNSW indexes, pagination, query optimization
4. **Scope complexity** - Start simple (instance scope only), add hierarchy later

---

## Review Checkpoints

After each sub-issue:
1. Run `make validate-all`
2. Multi-agent code review
3. Verify E2E stack still works
4. Commit with granular message