# forked from cardosofelipe/fast-next-template
# Auto-fixed linting errors and formatting issues: removed unused imports
# and variables; applied consistent formatting across memory service files.
# app/services/memory/integration/lifecycle.py
"""
Agent Lifecycle Hooks for Memory System.

Provides memory management hooks for agent lifecycle events:
- spawn: Initialize working memory for new agent instance
- pause: Checkpoint working memory state
- resume: Restore working memory from checkpoint
- terminate: Consolidate session to episodic memory
"""

import logging
from collections.abc import Callable, Coroutine
from dataclasses import dataclass, field
from datetime import UTC, datetime
from typing import Any
from uuid import UUID

from sqlalchemy.ext.asyncio import AsyncSession

from app.services.memory.episodic import EpisodicMemory
from app.services.memory.types import EpisodeCreate, Outcome
from app.services.memory.working import WorkingMemory

logger = logging.getLogger(__name__)


@dataclass
class LifecycleEvent:
    """Event data for lifecycle hooks.

    Built by AgentLifecycleManager and passed to every registered hook
    for the corresponding event type.
    """

    event_type: str  # spawn, pause, resume, terminate
    project_id: UUID
    agent_instance_id: UUID
    agent_type_id: UUID | None = None
    session_id: str | None = None
    # Timezone-aware creation time of the event object.
    timestamp: datetime = field(default_factory=lambda: datetime.now(UTC))
    # Free-form context; lifecycle methods add keys such as
    # "checkpoint_id" (pause/resume) or "episode_id" (terminate).
    metadata: dict[str, Any] = field(default_factory=dict)


@dataclass
class LifecycleResult:
    """Result of a lifecycle operation.

    Returned by AgentLifecycleManager.spawn/pause/resume/terminate. On
    failure, ``success`` is False and ``message`` carries the error text.
    """

    success: bool
    event_type: str  # spawn, pause, resume, terminate
    message: str | None = None
    # Operation-specific payload, e.g. checkpoint_id / items_restored / episode_id.
    data: dict[str, Any] = field(default_factory=dict)
    # Wall-clock duration of the operation; stays 0.0 when the operation failed.
    duration_ms: float = 0.0


# Type alias for lifecycle hooks: an async callable taking the LifecycleEvent
# and returning nothing. Exceptions raised by hooks are logged as warnings
# and swallowed by LifecycleHooks.run_*_hooks.
LifecycleHook = Callable[[LifecycleEvent], Coroutine[Any, Any, None]]


class LifecycleHooks:
    """
    Collection of lifecycle hooks.

    Allows registration of custom hooks for lifecycle events.
    Hooks are called after the core memory operations.

    Hook failures are logged as warnings and never propagate, so a broken
    hook cannot abort the lifecycle operation itself.
    """

    def __init__(self) -> None:
        """Initialize empty hook lists for each lifecycle event."""
        self._spawn_hooks: list[LifecycleHook] = []
        self._pause_hooks: list[LifecycleHook] = []
        self._resume_hooks: list[LifecycleHook] = []
        self._terminate_hooks: list[LifecycleHook] = []

    def on_spawn(self, hook: LifecycleHook) -> LifecycleHook:
        """Register a spawn hook. Returns the hook, so it works as a decorator."""
        self._spawn_hooks.append(hook)
        return hook

    def on_pause(self, hook: LifecycleHook) -> LifecycleHook:
        """Register a pause hook. Returns the hook, so it works as a decorator."""
        self._pause_hooks.append(hook)
        return hook

    def on_resume(self, hook: LifecycleHook) -> LifecycleHook:
        """Register a resume hook. Returns the hook, so it works as a decorator."""
        self._resume_hooks.append(hook)
        return hook

    def on_terminate(self, hook: LifecycleHook) -> LifecycleHook:
        """Register a terminate hook. Returns the hook, so it works as a decorator."""
        self._terminate_hooks.append(hook)
        return hook

    async def _run_hooks(
        self,
        hooks: list[LifecycleHook],
        label: str,
        event: LifecycleEvent,
    ) -> None:
        """
        Run hooks in registration order, logging (not raising) failures.

        Args:
            hooks: Hook list for a single event type
            label: Human-readable event label used in the warning message
            event: Event passed to every hook

        """
        for hook in hooks:
            try:
                await hook(event)
            except Exception as e:
                # A failing hook must not break the lifecycle operation.
                logger.warning(f"{label} hook failed: {e}")

    async def run_spawn_hooks(self, event: LifecycleEvent) -> None:
        """Run all spawn hooks."""
        await self._run_hooks(self._spawn_hooks, "Spawn", event)

    async def run_pause_hooks(self, event: LifecycleEvent) -> None:
        """Run all pause hooks."""
        await self._run_hooks(self._pause_hooks, "Pause", event)

    async def run_resume_hooks(self, event: LifecycleEvent) -> None:
        """Run all resume hooks."""
        await self._run_hooks(self._resume_hooks, "Resume", event)

    async def run_terminate_hooks(self, event: LifecycleEvent) -> None:
        """Run all terminate hooks."""
        await self._run_hooks(self._terminate_hooks, "Terminate", event)


class AgentLifecycleManager:
    """
    Manager for agent lifecycle and memory integration.

    Handles memory operations during agent lifecycle events:
    - spawn: Creates new working memory for the session
    - pause: Saves working memory state to checkpoint
    - resume: Restores working memory from checkpoint
    - terminate: Consolidates working memory to episodic memory

    All public lifecycle methods return a LifecycleResult instead of raising;
    failures are logged and reported via ``success=False``.
    """

    # Key prefix for checkpoint storage. Keys with this prefix are excluded
    # from normal session state when checkpointing or consolidating.
    CHECKPOINT_PREFIX = "__checkpoint__"

    def __init__(
        self,
        session: AsyncSession,
        embedding_generator: Any | None = None,
        hooks: LifecycleHooks | None = None,
    ) -> None:
        """
        Initialize the lifecycle manager.

        Args:
            session: Database session
            embedding_generator: Optional embedding generator
            hooks: Optional lifecycle hooks

        """
        self._session = session
        self._embedding_generator = embedding_generator
        self._hooks = hooks or LifecycleHooks()

        # Lazy-initialized services
        self._episodic: EpisodicMemory | None = None

    async def _get_episodic(self) -> EpisodicMemory:
        """Get or create episodic memory service (created on first use)."""
        if self._episodic is None:
            self._episodic = await EpisodicMemory.create(
                self._session,
                self._embedding_generator,
            )
        return self._episodic

    async def _get_working(
        self,
        project_id: UUID,
        agent_instance_id: UUID,
        session_id: str,
    ) -> WorkingMemory:
        """Open the working-memory handle for an agent session scope."""
        return await WorkingMemory.for_session(
            session_id=session_id,
            project_id=str(project_id),
            agent_instance_id=str(agent_instance_id),
        )

    async def _collect_state(
        self,
        working: WorkingMemory,
    ) -> tuple[list[str], dict[str, Any]]:
        """
        Read session state from working memory.

        Returns:
            Tuple of (all keys, non-checkpoint key/value state). Keys whose
            stored value is None are omitted from the state dict.

        """
        all_keys = await working.list_keys()
        state: dict[str, Any] = {}
        for key in all_keys:
            # Checkpoint entries are bookkeeping, not session state.
            if key.startswith(self.CHECKPOINT_PREFIX):
                continue
            value = await working.get(key)
            if value is not None:
                state[key] = value
        return all_keys, state

    @staticmethod
    def _elapsed_ms(start_time: datetime) -> float:
        """Milliseconds elapsed since start_time (timezone-aware)."""
        return (datetime.now(UTC) - start_time).total_seconds() * 1000

    @property
    def hooks(self) -> LifecycleHooks:
        """Get the lifecycle hooks."""
        return self._hooks

    async def spawn(
        self,
        project_id: UUID,
        agent_instance_id: UUID,
        session_id: str,
        agent_type_id: UUID | None = None,
        initial_state: dict[str, Any] | None = None,
        metadata: dict[str, Any] | None = None,
    ) -> LifecycleResult:
        """
        Handle agent spawn - initialize working memory.

        Creates a new working memory instance for the agent session
        and optionally populates it with initial state.

        Args:
            project_id: Project scope
            agent_instance_id: Agent instance ID
            session_id: Session ID for working memory
            agent_type_id: Optional agent type ID
            initial_state: Optional initial state to populate
            metadata: Optional metadata for the event

        Returns:
            LifecycleResult with spawn outcome

        """
        start_time = datetime.now(UTC)

        try:
            # Create working memory for the session
            working = await self._get_working(
                project_id, agent_instance_id, session_id
            )

            # Populate initial state if provided
            items_set = 0
            if initial_state:
                for key, value in initial_state.items():
                    await working.set(key, value)
                    items_set += 1

            # Create and run event hooks (hook failures are logged, not raised)
            event = LifecycleEvent(
                event_type="spawn",
                project_id=project_id,
                agent_instance_id=agent_instance_id,
                agent_type_id=agent_type_id,
                session_id=session_id,
                metadata=metadata or {},
            )
            await self._hooks.run_spawn_hooks(event)

            duration_ms = self._elapsed_ms(start_time)

            logger.info(
                f"Agent {agent_instance_id} spawned with session {session_id}, "
                f"initial state: {items_set} items"
            )

            return LifecycleResult(
                success=True,
                event_type="spawn",
                message="Agent spawned successfully",
                data={
                    "session_id": session_id,
                    "initial_items": items_set,
                },
                duration_ms=duration_ms,
            )

        except Exception as e:
            logger.error(f"Spawn failed for agent {agent_instance_id}: {e}")
            return LifecycleResult(
                success=False,
                event_type="spawn",
                message=f"Spawn failed: {e}",
            )

    async def pause(
        self,
        project_id: UUID,
        agent_instance_id: UUID,
        session_id: str,
        checkpoint_id: str | None = None,
        metadata: dict[str, Any] | None = None,
    ) -> LifecycleResult:
        """
        Handle agent pause - checkpoint working memory.

        Saves the current working memory state to a checkpoint
        that can be restored later with resume().

        Args:
            project_id: Project scope
            agent_instance_id: Agent instance ID
            session_id: Session ID
            checkpoint_id: Optional checkpoint identifier; defaults to a
                timestamp-derived id
            metadata: Optional metadata for the event

        Returns:
            LifecycleResult with checkpoint data

        """
        start_time = datetime.now(UTC)
        checkpoint_id = checkpoint_id or f"checkpoint_{int(start_time.timestamp())}"

        try:
            working = await self._get_working(
                project_id, agent_instance_id, session_id
            )

            # Gather current non-checkpoint state
            _, state = await self._collect_state(working)

            # Store checkpoint
            checkpoint_key = f"{self.CHECKPOINT_PREFIX}{checkpoint_id}"
            await working.set(
                checkpoint_key,
                {
                    "state": state,
                    "timestamp": start_time.isoformat(),
                    "keys_count": len(state),
                },
                ttl_seconds=86400 * 7,  # Keep checkpoint for 7 days
            )

            # Run hooks
            event = LifecycleEvent(
                event_type="pause",
                project_id=project_id,
                agent_instance_id=agent_instance_id,
                session_id=session_id,
                metadata={**(metadata or {}), "checkpoint_id": checkpoint_id},
            )
            await self._hooks.run_pause_hooks(event)

            duration_ms = self._elapsed_ms(start_time)

            logger.info(
                f"Agent {agent_instance_id} paused, checkpoint {checkpoint_id} "
                f"saved with {len(state)} items"
            )

            return LifecycleResult(
                success=True,
                event_type="pause",
                message="Agent paused successfully",
                data={
                    "checkpoint_id": checkpoint_id,
                    "items_saved": len(state),
                    "timestamp": start_time.isoformat(),
                },
                duration_ms=duration_ms,
            )

        except Exception as e:
            logger.error(f"Pause failed for agent {agent_instance_id}: {e}")
            return LifecycleResult(
                success=False,
                event_type="pause",
                message=f"Pause failed: {e}",
            )

    async def resume(
        self,
        project_id: UUID,
        agent_instance_id: UUID,
        session_id: str,
        checkpoint_id: str,
        clear_current: bool = True,
        metadata: dict[str, Any] | None = None,
    ) -> LifecycleResult:
        """
        Handle agent resume - restore from checkpoint.

        Restores working memory state from a previously saved checkpoint.

        Args:
            project_id: Project scope
            agent_instance_id: Agent instance ID
            session_id: Session ID
            checkpoint_id: Checkpoint to restore from
            clear_current: Whether to clear current state before restoring
            metadata: Optional metadata for the event

        Returns:
            LifecycleResult with restore outcome

        """
        start_time = datetime.now(UTC)

        try:
            working = await self._get_working(
                project_id, agent_instance_id, session_id
            )

            # Get checkpoint
            checkpoint_key = f"{self.CHECKPOINT_PREFIX}{checkpoint_id}"
            checkpoint = await working.get(checkpoint_key)

            if checkpoint is None:
                return LifecycleResult(
                    success=False,
                    event_type="resume",
                    message=f"Checkpoint '{checkpoint_id}' not found",
                )

            # Clear current state if requested (checkpoints themselves are kept)
            if clear_current:
                all_keys = await working.list_keys()
                for key in all_keys:
                    if not key.startswith(self.CHECKPOINT_PREFIX):
                        await working.delete(key)

            # Restore state from checkpoint
            state = checkpoint.get("state", {})
            items_restored = 0
            for key, value in state.items():
                await working.set(key, value)
                items_restored += 1

            # Run hooks
            event = LifecycleEvent(
                event_type="resume",
                project_id=project_id,
                agent_instance_id=agent_instance_id,
                session_id=session_id,
                metadata={**(metadata or {}), "checkpoint_id": checkpoint_id},
            )
            await self._hooks.run_resume_hooks(event)

            duration_ms = self._elapsed_ms(start_time)

            logger.info(
                f"Agent {agent_instance_id} resumed from checkpoint {checkpoint_id}, "
                f"restored {items_restored} items"
            )

            return LifecycleResult(
                success=True,
                event_type="resume",
                message="Agent resumed successfully",
                data={
                    "checkpoint_id": checkpoint_id,
                    "items_restored": items_restored,
                    "checkpoint_timestamp": checkpoint.get("timestamp"),
                },
                duration_ms=duration_ms,
            )

        except Exception as e:
            logger.error(f"Resume failed for agent {agent_instance_id}: {e}")
            return LifecycleResult(
                success=False,
                event_type="resume",
                message=f"Resume failed: {e}",
            )

    async def terminate(
        self,
        project_id: UUID,
        agent_instance_id: UUID,
        session_id: str,
        task_description: str | None = None,
        outcome: Outcome = Outcome.SUCCESS,
        lessons_learned: list[str] | None = None,
        consolidate_to_episodic: bool = True,
        cleanup_working: bool = True,
        metadata: dict[str, Any] | None = None,
    ) -> LifecycleResult:
        """
        Handle agent termination - consolidate to episodic memory.

        Consolidates the session's working memory into an episodic memory
        entry, then optionally cleans up the working memory.

        Args:
            project_id: Project scope
            agent_instance_id: Agent instance ID
            session_id: Session ID
            task_description: Description of what was accomplished
            outcome: Task outcome (SUCCESS, FAILURE, PARTIAL)
            lessons_learned: Optional list of lessons learned
            consolidate_to_episodic: Whether to create episodic entry
            cleanup_working: Whether to clear working memory
            metadata: Optional metadata for the event

        Returns:
            LifecycleResult with termination outcome

        """
        start_time = datetime.now(UTC)

        try:
            working = await self._get_working(
                project_id, agent_instance_id, session_id
            )

            # Gather session state for consolidation
            all_keys, session_state = await self._collect_state(working)

            episode_id: str | None = None

            # Consolidate to episodic memory
            if consolidate_to_episodic:
                episodic = await self._get_episodic()

                description = task_description or f"Session {session_id} completed"

                episode_data = EpisodeCreate(
                    project_id=project_id,
                    agent_instance_id=agent_instance_id,
                    session_id=session_id,
                    task_type="session_completion",
                    task_description=description[:500],
                    outcome=outcome,
                    outcome_details=f"Session terminated with {len(session_state)} state items",
                    actions=[
                        {
                            "type": "session_terminate",
                            "state_keys": list(session_state.keys()),
                            "outcome": outcome.value,
                        }
                    ],
                    context_summary=str(session_state)[:1000] if session_state else "",
                    lessons_learned=lessons_learned or [],
                    duration_seconds=0.0,  # Unknown at this point
                    tokens_used=0,
                    importance_score=0.6,  # Moderate importance for session ends
                )

                episode = await episodic.record_episode(episode_data)
                episode_id = str(episode.id)

            # Clean up working memory (including checkpoint entries)
            items_cleared = 0
            if cleanup_working:
                for key in all_keys:
                    await working.delete(key)
                    items_cleared += 1

            # Run hooks
            event = LifecycleEvent(
                event_type="terminate",
                project_id=project_id,
                agent_instance_id=agent_instance_id,
                session_id=session_id,
                metadata={**(metadata or {}), "episode_id": episode_id},
            )
            await self._hooks.run_terminate_hooks(event)

            duration_ms = self._elapsed_ms(start_time)

            logger.info(
                f"Agent {agent_instance_id} terminated, session {session_id} "
                f"consolidated to episode {episode_id}"
            )

            return LifecycleResult(
                success=True,
                event_type="terminate",
                message="Agent terminated successfully",
                data={
                    "episode_id": episode_id,
                    "state_items_consolidated": len(session_state),
                    "items_cleared": items_cleared,
                    "outcome": outcome.value,
                },
                duration_ms=duration_ms,
            )

        except Exception as e:
            logger.error(f"Terminate failed for agent {agent_instance_id}: {e}")
            return LifecycleResult(
                success=False,
                event_type="terminate",
                message=f"Terminate failed: {e}",
            )

    async def list_checkpoints(
        self,
        project_id: UUID,
        agent_instance_id: UUID,
        session_id: str,
    ) -> list[dict[str, Any]]:
        """
        List available checkpoints for a session.

        Args:
            project_id: Project scope
            agent_instance_id: Agent instance ID
            session_id: Session ID

        Returns:
            List of checkpoint metadata dicts, newest first

        """
        working = await self._get_working(
            project_id, agent_instance_id, session_id
        )

        all_keys = await working.list_keys()
        checkpoints: list[dict[str, Any]] = []

        for key in all_keys:
            if key.startswith(self.CHECKPOINT_PREFIX):
                checkpoint_id = key[len(self.CHECKPOINT_PREFIX) :]
                checkpoint = await working.get(key)
                if checkpoint:
                    checkpoints.append(
                        {
                            "checkpoint_id": checkpoint_id,
                            "timestamp": checkpoint.get("timestamp"),
                            "keys_count": checkpoint.get("keys_count", 0),
                        }
                    )

        # Sort by timestamp (newest first)
        checkpoints.sort(
            key=lambda c: c.get("timestamp", ""),
            reverse=True,
        )

        return checkpoints


# Factory function
async def get_lifecycle_manager(
    session: AsyncSession,
    embedding_generator: Any | None = None,
    hooks: LifecycleHooks | None = None,
) -> AgentLifecycleManager:
    """
    Create a lifecycle manager instance.

    Args:
        session: Database session passed through to the manager
        embedding_generator: Optional embedding generator
        hooks: Optional pre-configured lifecycle hooks

    Returns:
        A new AgentLifecycleManager

    """
    manager = AgentLifecycleManager(
        session=session,
        embedding_generator=embedding_generator,
        hooks=hooks,
    )
    return manager