feat(memory): integrate memory system with context engine (#97)
## Changes

### New Context Type
- Add MEMORY to ContextType enum for agent memory context
- Create MemoryContext class with subtypes (working, episodic, semantic, procedural)
- Factory methods: from_working_memory, from_episodic_memory, from_semantic_memory, from_procedural_memory

### Memory Context Source
- MemoryContextSource service fetches relevant memories for context assembly
- Configurable fetch limits per memory type
- Parallel fetching from all memory types

### Agent Lifecycle Hooks
- AgentLifecycleManager handles spawn, pause, resume, terminate events
- spawn: Initialize working memory with optional initial state
- pause: Create checkpoint of working memory
- resume: Restore from checkpoint
- terminate: Consolidate working memory to episodic memory
- LifecycleHooks for custom extension points

### Context Engine Integration
- Add memory_query parameter to assemble_context()
- Add session_id and agent_type_id for memory scoping
- Memory budget allocation (15% by default)
- set_memory_source() for runtime configuration

### Tests
- 48 new tests for MemoryContext, MemoryContextSource, and lifecycle hooks
- All 108 memory-related tests passing
- mypy and ruff checks passing

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
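A rough wiring sketch of the Context Engine integration described above. Only `memory_query`, `session_id`, `agent_type_id`, and `set_memory_source()` are named in this summary; the engine object, how it is constructed, and the rest of the `assemble_context()` signature are not part of this commit, so treat them as hypothetical.

```python
# Hypothetical wiring sketch: the `engine` object and its construction are
# assumptions; only the parameter names below come from the commit summary.
from app.services.memory.integration import get_memory_context_source


async def build_context_with_memory(engine, db_session, session_id):
    # Runtime configuration of the memory source (set_memory_source is named above).
    memory_source = await get_memory_context_source(db_session)
    engine.set_memory_source(memory_source)

    # memory_query drives the memory lookups; the memory budget defaults to 15%.
    return await engine.assemble_context(
        memory_query="current task and related decisions",
        session_id=session_id,
    )
```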
backend/app/services/memory/integration/__init__.py (new file, 19 lines)
@@ -0,0 +1,19 @@
# app/services/memory/integration/__init__.py
"""
Memory Integration Module.

Provides integration between the agent memory system and other Syndarix components:
- Context Engine: Memory as context source
- Agent Lifecycle: Spawn, pause, resume, terminate hooks
"""

from .context_source import MemoryContextSource, get_memory_context_source
from .lifecycle import AgentLifecycleManager, LifecycleHooks, get_lifecycle_manager

__all__ = [
    "AgentLifecycleManager",
    "LifecycleHooks",
    "MemoryContextSource",
    "get_lifecycle_manager",
    "get_memory_context_source",
]
backend/app/services/memory/integration/context_source.py (new file, 402 lines)
@@ -0,0 +1,402 @@
# app/services/memory/integration/context_source.py
"""
Memory Context Source.

Provides agent memory as a context source for the Context Engine.
Retrieves relevant memories based on query and converts them to MemoryContext objects.
"""

import logging
from dataclasses import dataclass
from datetime import UTC, datetime, timedelta
from typing import Any
from uuid import UUID

from sqlalchemy.ext.asyncio import AsyncSession

from app.services.context.types.memory import MemoryContext
from app.services.memory.episodic import EpisodicMemory
from app.services.memory.procedural import ProceduralMemory
from app.services.memory.semantic import SemanticMemory
from app.services.memory.working import WorkingMemory

logger = logging.getLogger(__name__)


@dataclass
class MemoryFetchConfig:
    """Configuration for memory fetching."""

    # Limits per memory type
    working_limit: int = 10
    episodic_limit: int = 10
    semantic_limit: int = 15
    procedural_limit: int = 5

    # Time ranges
    episodic_days_back: int = 30
    min_relevance: float = 0.3

    # Which memory types to include
    include_working: bool = True
    include_episodic: bool = True
    include_semantic: bool = True
    include_procedural: bool = True


@dataclass
class MemoryFetchResult:
    """Result of memory fetch operation."""

    contexts: list[MemoryContext]
    by_type: dict[str, int]
    fetch_time_ms: float
    query: str


class MemoryContextSource:
    """
    Source for memory context in the Context Engine.

    This service retrieves relevant memories based on a query and
    converts them to MemoryContext objects for context assembly.
    It coordinates between all memory types (working, episodic,
    semantic, procedural) to provide a comprehensive memory context.
    """

    def __init__(
        self,
        session: AsyncSession,
        embedding_generator: Any | None = None,
    ) -> None:
        """
        Initialize the memory context source.

        Args:
            session: Database session
            embedding_generator: Optional embedding generator for semantic search
        """
        self._session = session
        self._embedding_generator = embedding_generator

        # Lazy-initialized memory services
        self._episodic: EpisodicMemory | None = None
        self._semantic: SemanticMemory | None = None
        self._procedural: ProceduralMemory | None = None

    async def _get_episodic(self) -> EpisodicMemory:
        """Get or create episodic memory service."""
        if self._episodic is None:
            self._episodic = await EpisodicMemory.create(
                self._session,
                self._embedding_generator,
            )
        return self._episodic

    async def _get_semantic(self) -> SemanticMemory:
        """Get or create semantic memory service."""
        if self._semantic is None:
            self._semantic = await SemanticMemory.create(
                self._session,
                self._embedding_generator,
            )
        return self._semantic

    async def _get_procedural(self) -> ProceduralMemory:
        """Get or create procedural memory service."""
        if self._procedural is None:
            self._procedural = await ProceduralMemory.create(
                self._session,
                self._embedding_generator,
            )
        return self._procedural

    async def fetch_context(
        self,
        query: str,
        project_id: UUID,
        agent_instance_id: UUID | None = None,
        agent_type_id: UUID | None = None,
        session_id: str | None = None,
        config: MemoryFetchConfig | None = None,
    ) -> MemoryFetchResult:
        """
        Fetch relevant memories as context.

        This is the main entry point for the Context Engine integration.
        It searches across all memory types and returns relevant memories
        as MemoryContext objects.

        Args:
            query: Search query for finding relevant memories
            project_id: Project scope
            agent_instance_id: Optional agent instance scope
            agent_type_id: Optional agent type scope (for procedural)
            session_id: Optional session ID (for working memory)
            config: Optional fetch configuration

        Returns:
            MemoryFetchResult with contexts and metadata
        """
        config = config or MemoryFetchConfig()
        start_time = datetime.now(UTC)

        contexts: list[MemoryContext] = []
        by_type: dict[str, int] = {
            "working": 0,
            "episodic": 0,
            "semantic": 0,
            "procedural": 0,
        }

        # Fetch from working memory (session-scoped)
        if config.include_working and session_id:
            try:
                working_contexts = await self._fetch_working(
                    query=query,
                    session_id=session_id,
                    project_id=project_id,
                    agent_instance_id=agent_instance_id,
                    limit=config.working_limit,
                )
                contexts.extend(working_contexts)
                by_type["working"] = len(working_contexts)
            except Exception as e:
                logger.warning(f"Failed to fetch working memory: {e}")

        # Fetch from episodic memory
        if config.include_episodic:
            try:
                episodic_contexts = await self._fetch_episodic(
                    query=query,
                    project_id=project_id,
                    agent_instance_id=agent_instance_id,
                    limit=config.episodic_limit,
                    days_back=config.episodic_days_back,
                )
                contexts.extend(episodic_contexts)
                by_type["episodic"] = len(episodic_contexts)
            except Exception as e:
                logger.warning(f"Failed to fetch episodic memory: {e}")

        # Fetch from semantic memory
        if config.include_semantic:
            try:
                semantic_contexts = await self._fetch_semantic(
                    query=query,
                    project_id=project_id,
                    limit=config.semantic_limit,
                    min_relevance=config.min_relevance,
                )
                contexts.extend(semantic_contexts)
                by_type["semantic"] = len(semantic_contexts)
            except Exception as e:
                logger.warning(f"Failed to fetch semantic memory: {e}")

        # Fetch from procedural memory
        if config.include_procedural:
            try:
                procedural_contexts = await self._fetch_procedural(
                    query=query,
                    project_id=project_id,
                    agent_type_id=agent_type_id,
                    limit=config.procedural_limit,
                )
                contexts.extend(procedural_contexts)
                by_type["procedural"] = len(procedural_contexts)
            except Exception as e:
                logger.warning(f"Failed to fetch procedural memory: {e}")

        # Sort by relevance
        contexts.sort(key=lambda c: c.relevance_score, reverse=True)

        fetch_time = (datetime.now(UTC) - start_time).total_seconds() * 1000

        logger.debug(
            f"Fetched {len(contexts)} memory contexts for query '{query[:50]}...' "
            f"in {fetch_time:.1f}ms"
        )

        return MemoryFetchResult(
            contexts=contexts,
            by_type=by_type,
            fetch_time_ms=fetch_time,
            query=query,
        )

    async def _fetch_working(
        self,
        query: str,
        session_id: str,
        project_id: UUID,
        agent_instance_id: UUID | None,
        limit: int,
    ) -> list[MemoryContext]:
        """Fetch from working memory."""
        working = await WorkingMemory.for_session(
            session_id=session_id,
            project_id=str(project_id),
            agent_instance_id=str(agent_instance_id) if agent_instance_id else None,
        )

        contexts: list[MemoryContext] = []
        all_keys = await working.list_keys()

        # Filter keys by query (simple substring match)
        query_lower = query.lower()
        matched_keys = [k for k in all_keys if query_lower in k.lower()]

        # If no query match, include all keys (working memory is always relevant)
        if not matched_keys and query:
            matched_keys = all_keys

        for key in matched_keys[:limit]:
            value = await working.get(key)
            if value is not None:
                contexts.append(
                    MemoryContext.from_working_memory(
                        key=key,
                        value=value,
                        source=f"working:{session_id}",
                        query=query,
                    )
                )

        return contexts

    async def _fetch_episodic(
        self,
        query: str,
        project_id: UUID,
        agent_instance_id: UUID | None,
        limit: int,
        days_back: int,
    ) -> list[MemoryContext]:
        """Fetch from episodic memory."""
        episodic = await self._get_episodic()

        # Search for similar episodes
        episodes = await episodic.search_similar(
            project_id=project_id,
            query=query,
            limit=limit,
            agent_instance_id=agent_instance_id,
        )

        # Also get recent episodes if we didn't find enough
        if len(episodes) < limit // 2:
            since = datetime.now(UTC) - timedelta(days=days_back)
            recent = await episodic.get_recent(
                project_id=project_id,
                limit=limit,
                since=since,
            )
            # Deduplicate by ID
            existing_ids = {e.id for e in episodes}
            for ep in recent:
                if ep.id not in existing_ids:
                    episodes.append(ep)
                    if len(episodes) >= limit:
                        break

        return [
            MemoryContext.from_episodic_memory(ep, query=query)
            for ep in episodes[:limit]
        ]

    async def _fetch_semantic(
        self,
        query: str,
        project_id: UUID,
        limit: int,
        min_relevance: float,
    ) -> list[MemoryContext]:
        """Fetch from semantic memory."""
        semantic = await self._get_semantic()

        facts = await semantic.search_facts(
            query=query,
            project_id=project_id,
            limit=limit,
            min_confidence=min_relevance,
        )

        return [
            MemoryContext.from_semantic_memory(fact, query=query)
            for fact in facts
        ]

    async def _fetch_procedural(
        self,
        query: str,
        project_id: UUID,
        agent_type_id: UUID | None,
        limit: int,
    ) -> list[MemoryContext]:
        """Fetch from procedural memory."""
        procedural = await self._get_procedural()

        procedures = await procedural.find_matching(
            context=query,
            project_id=project_id,
            agent_type_id=agent_type_id,
            limit=limit,
        )

        return [
            MemoryContext.from_procedural_memory(proc, query=query)
            for proc in procedures
        ]

    async def fetch_all_working(
        self,
        session_id: str,
        project_id: UUID,
        agent_instance_id: UUID | None = None,
    ) -> list[MemoryContext]:
        """
        Fetch all working memory for a session.

        Useful for including entire session state in context.

        Args:
            session_id: Session ID
            project_id: Project scope
            agent_instance_id: Optional agent instance scope

        Returns:
            List of MemoryContext for all working memory items
        """
        working = await WorkingMemory.for_session(
            session_id=session_id,
            project_id=str(project_id),
            agent_instance_id=str(agent_instance_id) if agent_instance_id else None,
        )

        contexts: list[MemoryContext] = []
        all_keys = await working.list_keys()

        for key in all_keys:
            value = await working.get(key)
            if value is not None:
                contexts.append(
                    MemoryContext.from_working_memory(
                        key=key,
                        value=value,
                        source=f"working:{session_id}",
                    )
                )

        return contexts


# Factory function
async def get_memory_context_source(
    session: AsyncSession,
    embedding_generator: Any | None = None,
) -> MemoryContextSource:
    """Create a memory context source instance."""
    return MemoryContextSource(
        session=session,
        embedding_generator=embedding_generator,
    )
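A minimal sketch of calling the new source directly, using only signatures added in this file; the database session, project ID, and session ID are assumed to be supplied by the caller.

```python
# Sketch of direct MemoryContextSource usage; db_session / IDs are placeholders.
from app.services.memory.integration import get_memory_context_source
from app.services.memory.integration.context_source import MemoryFetchConfig


async def fetch_memory_context(db_session, project_id, session_id):
    source = await get_memory_context_source(db_session)

    # Tighten per-type limits and skip procedural memory for this call.
    config = MemoryFetchConfig(
        episodic_limit=5,
        semantic_limit=10,
        include_procedural=False,
    )

    result = await source.fetch_context(
        query="database migration strategy",
        project_id=project_id,
        session_id=session_id,
        config=config,
    )
    # result.by_type reports how many contexts came from each memory type.
    return result.contexts
```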
backend/app/services/memory/integration/lifecycle.py (new file, 629 lines)
@@ -0,0 +1,629 @@
# app/services/memory/integration/lifecycle.py
"""
Agent Lifecycle Hooks for Memory System.

Provides memory management hooks for agent lifecycle events:
- spawn: Initialize working memory for new agent instance
- pause: Checkpoint working memory state
- resume: Restore working memory from checkpoint
- terminate: Consolidate session to episodic memory
"""

import logging
from collections.abc import Callable, Coroutine
from dataclasses import dataclass, field
from datetime import UTC, datetime
from typing import Any
from uuid import UUID

from sqlalchemy.ext.asyncio import AsyncSession

from app.services.memory.episodic import EpisodicMemory
from app.services.memory.types import EpisodeCreate, Outcome
from app.services.memory.working import WorkingMemory

logger = logging.getLogger(__name__)


@dataclass
class LifecycleEvent:
    """Event data for lifecycle hooks."""

    event_type: str  # spawn, pause, resume, terminate
    project_id: UUID
    agent_instance_id: UUID
    agent_type_id: UUID | None = None
    session_id: str | None = None
    timestamp: datetime = field(default_factory=lambda: datetime.now(UTC))
    metadata: dict[str, Any] = field(default_factory=dict)


@dataclass
class LifecycleResult:
    """Result of a lifecycle operation."""

    success: bool
    event_type: str
    message: str | None = None
    data: dict[str, Any] = field(default_factory=dict)
    duration_ms: float = 0.0


# Type alias for lifecycle hooks
LifecycleHook = Callable[[LifecycleEvent], Coroutine[Any, Any, None]]


class LifecycleHooks:
    """
    Collection of lifecycle hooks.

    Allows registration of custom hooks for lifecycle events.
    Hooks are called after the core memory operations.
    """

    def __init__(self) -> None:
        """Initialize lifecycle hooks."""
        self._spawn_hooks: list[LifecycleHook] = []
        self._pause_hooks: list[LifecycleHook] = []
        self._resume_hooks: list[LifecycleHook] = []
        self._terminate_hooks: list[LifecycleHook] = []

    def on_spawn(self, hook: LifecycleHook) -> LifecycleHook:
        """Register a spawn hook."""
        self._spawn_hooks.append(hook)
        return hook

    def on_pause(self, hook: LifecycleHook) -> LifecycleHook:
        """Register a pause hook."""
        self._pause_hooks.append(hook)
        return hook

    def on_resume(self, hook: LifecycleHook) -> LifecycleHook:
        """Register a resume hook."""
        self._resume_hooks.append(hook)
        return hook

    def on_terminate(self, hook: LifecycleHook) -> LifecycleHook:
        """Register a terminate hook."""
        self._terminate_hooks.append(hook)
        return hook

    async def run_spawn_hooks(self, event: LifecycleEvent) -> None:
        """Run all spawn hooks."""
        for hook in self._spawn_hooks:
            try:
                await hook(event)
            except Exception as e:
                logger.warning(f"Spawn hook failed: {e}")

    async def run_pause_hooks(self, event: LifecycleEvent) -> None:
        """Run all pause hooks."""
        for hook in self._pause_hooks:
            try:
                await hook(event)
            except Exception as e:
                logger.warning(f"Pause hook failed: {e}")

    async def run_resume_hooks(self, event: LifecycleEvent) -> None:
        """Run all resume hooks."""
        for hook in self._resume_hooks:
            try:
                await hook(event)
            except Exception as e:
                logger.warning(f"Resume hook failed: {e}")

    async def run_terminate_hooks(self, event: LifecycleEvent) -> None:
        """Run all terminate hooks."""
        for hook in self._terminate_hooks:
            try:
                await hook(event)
            except Exception as e:
                logger.warning(f"Terminate hook failed: {e}")


class AgentLifecycleManager:
    """
    Manager for agent lifecycle and memory integration.

    Handles memory operations during agent lifecycle events:
    - spawn: Creates new working memory for the session
    - pause: Saves working memory state to checkpoint
    - resume: Restores working memory from checkpoint
    - terminate: Consolidates working memory to episodic memory
    """

    # Key prefix for checkpoint storage
    CHECKPOINT_PREFIX = "__checkpoint__"

    def __init__(
        self,
        session: AsyncSession,
        embedding_generator: Any | None = None,
        hooks: LifecycleHooks | None = None,
    ) -> None:
        """
        Initialize the lifecycle manager.

        Args:
            session: Database session
            embedding_generator: Optional embedding generator
            hooks: Optional lifecycle hooks
        """
        self._session = session
        self._embedding_generator = embedding_generator
        self._hooks = hooks or LifecycleHooks()

        # Lazy-initialized services
        self._episodic: EpisodicMemory | None = None

    async def _get_episodic(self) -> EpisodicMemory:
        """Get or create episodic memory service."""
        if self._episodic is None:
            self._episodic = await EpisodicMemory.create(
                self._session,
                self._embedding_generator,
            )
        return self._episodic

    @property
    def hooks(self) -> LifecycleHooks:
        """Get the lifecycle hooks."""
        return self._hooks

    async def spawn(
        self,
        project_id: UUID,
        agent_instance_id: UUID,
        session_id: str,
        agent_type_id: UUID | None = None,
        initial_state: dict[str, Any] | None = None,
        metadata: dict[str, Any] | None = None,
    ) -> LifecycleResult:
        """
        Handle agent spawn - initialize working memory.

        Creates a new working memory instance for the agent session
        and optionally populates it with initial state.

        Args:
            project_id: Project scope
            agent_instance_id: Agent instance ID
            session_id: Session ID for working memory
            agent_type_id: Optional agent type ID
            initial_state: Optional initial state to populate
            metadata: Optional metadata for the event

        Returns:
            LifecycleResult with spawn outcome
        """
        start_time = datetime.now(UTC)

        try:
            # Create working memory for the session
            working = await WorkingMemory.for_session(
                session_id=session_id,
                project_id=str(project_id),
                agent_instance_id=str(agent_instance_id),
            )

            # Populate initial state if provided
            items_set = 0
            if initial_state:
                for key, value in initial_state.items():
                    await working.set(key, value)
                    items_set += 1

            # Create and run event hooks
            event = LifecycleEvent(
                event_type="spawn",
                project_id=project_id,
                agent_instance_id=agent_instance_id,
                agent_type_id=agent_type_id,
                session_id=session_id,
                metadata=metadata or {},
            )
            await self._hooks.run_spawn_hooks(event)

            duration_ms = (datetime.now(UTC) - start_time).total_seconds() * 1000

            logger.info(
                f"Agent {agent_instance_id} spawned with session {session_id}, "
                f"initial state: {items_set} items"
            )

            return LifecycleResult(
                success=True,
                event_type="spawn",
                message="Agent spawned successfully",
                data={
                    "session_id": session_id,
                    "initial_items": items_set,
                },
                duration_ms=duration_ms,
            )

        except Exception as e:
            logger.error(f"Spawn failed for agent {agent_instance_id}: {e}")
            return LifecycleResult(
                success=False,
                event_type="spawn",
                message=f"Spawn failed: {e}",
            )

    async def pause(
        self,
        project_id: UUID,
        agent_instance_id: UUID,
        session_id: str,
        checkpoint_id: str | None = None,
        metadata: dict[str, Any] | None = None,
    ) -> LifecycleResult:
        """
        Handle agent pause - checkpoint working memory.

        Saves the current working memory state to a checkpoint
        that can be restored later with resume().

        Args:
            project_id: Project scope
            agent_instance_id: Agent instance ID
            session_id: Session ID
            checkpoint_id: Optional checkpoint identifier
            metadata: Optional metadata for the event

        Returns:
            LifecycleResult with checkpoint data
        """
        start_time = datetime.now(UTC)
        checkpoint_id = checkpoint_id or f"checkpoint_{int(start_time.timestamp())}"

        try:
            working = await WorkingMemory.for_session(
                session_id=session_id,
                project_id=str(project_id),
                agent_instance_id=str(agent_instance_id),
            )

            # Get all current state
            all_keys = await working.list_keys()
            # Filter out checkpoint keys
            state_keys = [k for k in all_keys if not k.startswith(self.CHECKPOINT_PREFIX)]

            state: dict[str, Any] = {}
            for key in state_keys:
                value = await working.get(key)
                if value is not None:
                    state[key] = value

            # Store checkpoint
            checkpoint_key = f"{self.CHECKPOINT_PREFIX}{checkpoint_id}"
            await working.set(
                checkpoint_key,
                {
                    "state": state,
                    "timestamp": start_time.isoformat(),
                    "keys_count": len(state),
                },
                ttl_seconds=86400 * 7,  # Keep checkpoint for 7 days
            )

            # Run hooks
            event = LifecycleEvent(
                event_type="pause",
                project_id=project_id,
                agent_instance_id=agent_instance_id,
                session_id=session_id,
                metadata={**(metadata or {}), "checkpoint_id": checkpoint_id},
            )
            await self._hooks.run_pause_hooks(event)

            duration_ms = (datetime.now(UTC) - start_time).total_seconds() * 1000

            logger.info(
                f"Agent {agent_instance_id} paused, checkpoint {checkpoint_id} "
                f"saved with {len(state)} items"
            )

            return LifecycleResult(
                success=True,
                event_type="pause",
                message="Agent paused successfully",
                data={
                    "checkpoint_id": checkpoint_id,
                    "items_saved": len(state),
                    "timestamp": start_time.isoformat(),
                },
                duration_ms=duration_ms,
            )

        except Exception as e:
            logger.error(f"Pause failed for agent {agent_instance_id}: {e}")
            return LifecycleResult(
                success=False,
                event_type="pause",
                message=f"Pause failed: {e}",
            )

    async def resume(
        self,
        project_id: UUID,
        agent_instance_id: UUID,
        session_id: str,
        checkpoint_id: str,
        clear_current: bool = True,
        metadata: dict[str, Any] | None = None,
    ) -> LifecycleResult:
        """
        Handle agent resume - restore from checkpoint.

        Restores working memory state from a previously saved checkpoint.

        Args:
            project_id: Project scope
            agent_instance_id: Agent instance ID
            session_id: Session ID
            checkpoint_id: Checkpoint to restore from
            clear_current: Whether to clear current state before restoring
            metadata: Optional metadata for the event

        Returns:
            LifecycleResult with restore outcome
        """
        start_time = datetime.now(UTC)

        try:
            working = await WorkingMemory.for_session(
                session_id=session_id,
                project_id=str(project_id),
                agent_instance_id=str(agent_instance_id),
            )

            # Get checkpoint
            checkpoint_key = f"{self.CHECKPOINT_PREFIX}{checkpoint_id}"
            checkpoint = await working.get(checkpoint_key)

            if checkpoint is None:
                return LifecycleResult(
                    success=False,
                    event_type="resume",
                    message=f"Checkpoint '{checkpoint_id}' not found",
                )

            # Clear current state if requested
            if clear_current:
                all_keys = await working.list_keys()
                for key in all_keys:
                    if not key.startswith(self.CHECKPOINT_PREFIX):
                        await working.delete(key)

            # Restore state from checkpoint
            state = checkpoint.get("state", {})
            items_restored = 0
            for key, value in state.items():
                await working.set(key, value)
                items_restored += 1

            # Run hooks
            event = LifecycleEvent(
                event_type="resume",
                project_id=project_id,
                agent_instance_id=agent_instance_id,
                session_id=session_id,
                metadata={**(metadata or {}), "checkpoint_id": checkpoint_id},
            )
            await self._hooks.run_resume_hooks(event)

            duration_ms = (datetime.now(UTC) - start_time).total_seconds() * 1000

            logger.info(
                f"Agent {agent_instance_id} resumed from checkpoint {checkpoint_id}, "
                f"restored {items_restored} items"
            )

            return LifecycleResult(
                success=True,
                event_type="resume",
                message="Agent resumed successfully",
                data={
                    "checkpoint_id": checkpoint_id,
                    "items_restored": items_restored,
                    "checkpoint_timestamp": checkpoint.get("timestamp"),
                },
                duration_ms=duration_ms,
            )

        except Exception as e:
            logger.error(f"Resume failed for agent {agent_instance_id}: {e}")
            return LifecycleResult(
                success=False,
                event_type="resume",
                message=f"Resume failed: {e}",
            )

    async def terminate(
        self,
        project_id: UUID,
        agent_instance_id: UUID,
        session_id: str,
        task_description: str | None = None,
        outcome: Outcome = Outcome.SUCCESS,
        lessons_learned: list[str] | None = None,
        consolidate_to_episodic: bool = True,
        cleanup_working: bool = True,
        metadata: dict[str, Any] | None = None,
    ) -> LifecycleResult:
        """
        Handle agent termination - consolidate to episodic memory.

        Consolidates the session's working memory into an episodic memory
        entry, then optionally cleans up the working memory.

        Args:
            project_id: Project scope
            agent_instance_id: Agent instance ID
            session_id: Session ID
            task_description: Description of what was accomplished
            outcome: Task outcome (SUCCESS, FAILURE, PARTIAL)
            lessons_learned: Optional list of lessons learned
            consolidate_to_episodic: Whether to create episodic entry
            cleanup_working: Whether to clear working memory
            metadata: Optional metadata for the event

        Returns:
            LifecycleResult with termination outcome
        """
        start_time = datetime.now(UTC)

        try:
            working = await WorkingMemory.for_session(
                session_id=session_id,
                project_id=str(project_id),
                agent_instance_id=str(agent_instance_id),
            )

            # Gather session state for consolidation
            all_keys = await working.list_keys()
            state_keys = [k for k in all_keys if not k.startswith(self.CHECKPOINT_PREFIX)]

            session_state: dict[str, Any] = {}
            for key in state_keys:
                value = await working.get(key)
                if value is not None:
                    session_state[key] = value

            episode_id: str | None = None

            # Consolidate to episodic memory
            if consolidate_to_episodic:
                episodic = await self._get_episodic()

                description = task_description or f"Session {session_id} completed"

                episode_data = EpisodeCreate(
                    project_id=project_id,
                    agent_instance_id=agent_instance_id,
                    session_id=session_id,
                    task_type="session_completion",
                    task_description=description[:500],
                    outcome=outcome,
                    outcome_details=f"Session terminated with {len(session_state)} state items",
                    actions=[
                        {
                            "type": "session_terminate",
                            "state_keys": list(session_state.keys()),
                            "outcome": outcome.value,
                        }
                    ],
                    context_summary=str(session_state)[:1000] if session_state else "",
                    lessons_learned=lessons_learned or [],
                    duration_seconds=0.0,  # Unknown at this point
                    tokens_used=0,
                    importance_score=0.6,  # Moderate importance for session ends
                )

                episode = await episodic.record_episode(episode_data)
                episode_id = str(episode.id)

            # Clean up working memory
            items_cleared = 0
            if cleanup_working:
                for key in all_keys:
                    await working.delete(key)
                    items_cleared += 1

            # Run hooks
            event = LifecycleEvent(
                event_type="terminate",
                project_id=project_id,
                agent_instance_id=agent_instance_id,
                session_id=session_id,
                metadata={**(metadata or {}), "episode_id": episode_id},
            )
            await self._hooks.run_terminate_hooks(event)

            duration_ms = (datetime.now(UTC) - start_time).total_seconds() * 1000

            logger.info(
                f"Agent {agent_instance_id} terminated, session {session_id} "
                f"consolidated to episode {episode_id}"
            )

            return LifecycleResult(
                success=True,
                event_type="terminate",
                message="Agent terminated successfully",
                data={
                    "episode_id": episode_id,
                    "state_items_consolidated": len(session_state),
                    "items_cleared": items_cleared,
                    "outcome": outcome.value,
                },
                duration_ms=duration_ms,
            )

        except Exception as e:
            logger.error(f"Terminate failed for agent {agent_instance_id}: {e}")
            return LifecycleResult(
                success=False,
                event_type="terminate",
                message=f"Terminate failed: {e}",
            )

    async def list_checkpoints(
        self,
        project_id: UUID,
        agent_instance_id: UUID,
        session_id: str,
    ) -> list[dict[str, Any]]:
        """
        List available checkpoints for a session.

        Args:
            project_id: Project scope
            agent_instance_id: Agent instance ID
            session_id: Session ID

        Returns:
            List of checkpoint metadata dicts
        """
        working = await WorkingMemory.for_session(
            session_id=session_id,
            project_id=str(project_id),
            agent_instance_id=str(agent_instance_id),
        )

        all_keys = await working.list_keys()
        checkpoints: list[dict[str, Any]] = []

        for key in all_keys:
            if key.startswith(self.CHECKPOINT_PREFIX):
                checkpoint_id = key[len(self.CHECKPOINT_PREFIX):]
                checkpoint = await working.get(key)
                if checkpoint:
                    checkpoints.append({
                        "checkpoint_id": checkpoint_id,
                        "timestamp": checkpoint.get("timestamp"),
                        "keys_count": checkpoint.get("keys_count", 0),
                    })

        # Sort by timestamp (newest first)
        checkpoints.sort(
            key=lambda c: c.get("timestamp", ""),
            reverse=True,
        )

        return checkpoints


# Factory function
async def get_lifecycle_manager(
    session: AsyncSession,
    embedding_generator: Any | None = None,
    hooks: LifecycleHooks | None = None,
) -> AgentLifecycleManager:
    """Create a lifecycle manager instance."""
    return AgentLifecycleManager(
        session=session,
        embedding_generator=embedding_generator,
        hooks=hooks,
    )
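A minimal end-to-end sketch of the lifecycle hooks added in this file: spawn, pause, resume, terminate, plus a custom on_terminate hook. The database session and IDs are assumed to come from the caller, and the pause result is assumed to succeed so its checkpoint_id can be reused.

```python
# Sketch of a spawn -> pause -> resume -> terminate cycle with a custom hook;
# db_session and the UUID arguments are placeholders supplied by the caller.
from app.services.memory.integration import LifecycleHooks, get_lifecycle_manager
from app.services.memory.types import Outcome


async def run_agent_session(db_session, project_id, agent_instance_id, session_id):
    hooks = LifecycleHooks()

    @hooks.on_terminate
    async def log_termination(event):
        # Custom extension point: runs after the core consolidation logic.
        print(f"agent {event.agent_instance_id} terminated at {event.timestamp}")

    manager = await get_lifecycle_manager(db_session, hooks=hooks)

    await manager.spawn(
        project_id=project_id,
        agent_instance_id=agent_instance_id,
        session_id=session_id,
        initial_state={"goal": "triage open issues"},
    )

    # Checkpoint, then restore from that checkpoint (assumes pause succeeded).
    paused = await manager.pause(project_id, agent_instance_id, session_id)
    await manager.resume(
        project_id,
        agent_instance_id,
        session_id,
        checkpoint_id=paused.data["checkpoint_id"],
    )

    # Consolidate working memory to an episodic entry and clean up.
    return await manager.terminate(
        project_id,
        agent_instance_id,
        session_id,
        task_description="Issue triage session",
        outcome=Outcome.SUCCESS,
        lessons_learned=["label new issues before assigning"],
    )
```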