feat(memory): integrate memory system with context engine (#97)

## Changes

### New Context Type
- Add MEMORY to ContextType enum for agent memory context
- Create MemoryContext class with subtypes (working, episodic, semantic, procedural)
- Factory methods: from_working_memory, from_episodic_memory, from_semantic_memory, from_procedural_memory (usage sketch below)
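
A minimal sketch of the factory API, based on the signatures in this diff (values are placeholders):

```python
from app.services.context.types.memory import MemoryContext, MemorySubtype

# Wrap a working-memory entry as context (key/value/session are illustrative).
ctx = MemoryContext.from_working_memory(
    key="current_task",
    value="Build authentication",
    source="working:sess-123",
    query="task",
)
assert ctx.memory_subtype == MemorySubtype.WORKING
assert ctx.relevance_score == 1.0  # working memory is always treated as relevant
```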

### Memory Context Source
- MemoryContextSource service fetches relevant memories for context assembly
- Configurable fetch limits per memory type
- Fault-tolerant fetching across all memory types: each type is fetched independently, and failures are logged and skipped (see the sketch below)
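
A sketch of a fetch call, assuming an open `AsyncSession` named `session`:

```python
from uuid import uuid4

from app.services.memory.integration import get_memory_context_source
from app.services.memory.integration.context_source import MemoryFetchConfig


async def fetch_memories(session) -> None:
    source = await get_memory_context_source(session)
    # Tune per-type limits; skip working memory when no session is active.
    config = MemoryFetchConfig(episodic_limit=5, include_working=False)
    result = await source.fetch_context(
        query="login flow",
        project_id=uuid4(),  # placeholder project ID
        config=config,
    )
    print(result.by_type, f"{result.fetch_time_ms:.1f}ms")
```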

### Agent Lifecycle Hooks
- AgentLifecycleManager handles spawn, pause, resume, terminate events
- spawn: Initialize working memory with optional initial state
- pause: Create checkpoint of working memory
- resume: Restore from checkpoint
- terminate: Consolidate working memory to episodic memory
- LifecycleHooks for custom extension points (see the sketch below)
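
A sketch of the spawn flow with a custom terminate hook (IDs and the session handle are placeholders):

```python
from uuid import uuid4

from app.services.memory.integration import get_lifecycle_manager
from app.services.memory.integration.lifecycle import LifecycleEvent, LifecycleHooks


async def run_agent_lifecycle(session) -> None:
    hooks = LifecycleHooks()

    @hooks.on_terminate
    async def log_terminate(event: LifecycleEvent) -> None:
        # Custom hooks run after the core memory operations.
        print(f"agent {event.agent_instance_id} terminated")

    manager = await get_lifecycle_manager(session, hooks=hooks)
    result = await manager.spawn(
        project_id=uuid4(),
        agent_instance_id=uuid4(),
        session_id="sess-123",
        initial_state={"current_task": "triage inbox"},
    )
    assert result.success
```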

### Context Engine Integration
- Add memory_query parameter to assemble_context()
- Add session_id and agent_type_id for memory scoping
- Memory budget allocation (15% by default)
- set_memory_source() for runtime configuration (see the sketch below)
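
The 15% memory share is carved out of the old defaults (knowledge 40% → 30%, conversation 20% → 15%), so the default allocation still sums to 100%. A hedged sketch of an assembly call using the new parameters (the import path and the `agent_id` argument are assumed from this diff; IDs are placeholders; other arguments keep their defaults):

```python
# Assumed import path for the factory shown in this diff.
from app.services.context.engine import create_context_engine


async def assemble_with_memory(memory_source) -> None:
    engine = create_context_engine(memory_source=memory_source)
    # Or attach the source after construction:
    # engine.set_memory_source(memory_source)
    result = await engine.assemble_context(
        project_id="11111111-1111-1111-1111-111111111111",  # placeholder UUID
        agent_id="22222222-2222-2222-2222-222222222222",  # placeholder UUID
        memory_query="previous deployment attempts",
        memory_limit=20,
        session_id="sess-123",  # enables working-memory fetch
        agent_type_id=None,  # set this to scope procedural memory
    )
    print(result)
```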

### Tests
- 48 new tests for MemoryContext, MemoryContextSource, and lifecycle hooks
- All 108 memory-related tests passing
- mypy and ruff checks passing

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-05 03:49:22 +01:00
parent 0b24d4c6cc
commit 30e5c68304
13 changed files with 2509 additions and 6 deletions

View File

@@ -114,6 +114,8 @@ from .types import (
ContextType,
ConversationContext,
KnowledgeContext,
MemoryContext,
MemorySubtype,
MessageRole,
SystemContext,
TaskComplexity,
@@ -149,6 +151,8 @@ __all__ = [
"FormattingError",
"InvalidContextError",
"KnowledgeContext",
"MemoryContext",
"MemorySubtype",
"MessageRole",
"ModelAdapter",
"OpenAIAdapter",

View File

@@ -30,6 +30,7 @@ class TokenBudget:
knowledge: int = 0
conversation: int = 0
tools: int = 0
memory: int = 0 # Agent memory (working, episodic, semantic, procedural)
response_reserve: int = 0
buffer: int = 0
@@ -60,6 +61,7 @@ class TokenBudget:
"knowledge": self.knowledge,
"conversation": self.conversation,
"tool": self.tools,
"memory": self.memory,
}
return allocation_map.get(context_type, 0)
@@ -211,6 +213,7 @@ class TokenBudget:
"knowledge": self.knowledge,
"conversation": self.conversation,
"tools": self.tools,
"memory": self.memory,
"response_reserve": self.response_reserve,
"buffer": self.buffer,
},
@@ -264,9 +267,10 @@ class BudgetAllocator:
total=total_tokens,
system=int(total_tokens * alloc.get("system", 0.05)),
task=int(total_tokens * alloc.get("task", 0.10)),
- knowledge=int(total_tokens * alloc.get("knowledge", 0.40)),
- conversation=int(total_tokens * alloc.get("conversation", 0.20)),
+ knowledge=int(total_tokens * alloc.get("knowledge", 0.30)),
+ conversation=int(total_tokens * alloc.get("conversation", 0.15)),
tools=int(total_tokens * alloc.get("tools", 0.05)),
memory=int(total_tokens * alloc.get("memory", 0.15)),
response_reserve=int(total_tokens * alloc.get("response", 0.15)),
buffer=int(total_tokens * alloc.get("buffer", 0.05)),
)
@@ -317,6 +321,8 @@ class BudgetAllocator:
budget.conversation = max(0, budget.conversation + actual_adjustment)
elif context_type == "tool":
budget.tools = max(0, budget.tools + actual_adjustment)
elif context_type == "memory":
budget.memory = max(0, budget.memory + actual_adjustment)
return budget
@@ -338,7 +344,7 @@ class BudgetAllocator:
Rebalanced budget
"""
if prioritize is None:
- prioritize = [ContextType.KNOWLEDGE, ContextType.TASK, ContextType.SYSTEM]
+ prioritize = [ContextType.KNOWLEDGE, ContextType.MEMORY, ContextType.TASK, ContextType.SYSTEM]
# Calculate unused tokens per type
unused: dict[str, int] = {}

View File

@@ -7,6 +7,7 @@ Provides a high-level API for assembling optimized context for LLM requests.
import logging
from typing import TYPE_CHECKING, Any
from uuid import UUID
from .assembly import ContextPipeline
from .budget import BudgetAllocator, TokenBudget, TokenCalculator
@@ -20,6 +21,7 @@ from .types import (
BaseContext,
ConversationContext,
KnowledgeContext,
MemoryContext,
MessageRole,
SystemContext,
TaskContext,
@@ -30,6 +32,7 @@ if TYPE_CHECKING:
from redis.asyncio import Redis
from app.services.mcp.client_manager import MCPClientManager
from app.services.memory.integration import MemoryContextSource
logger = logging.getLogger(__name__)
@@ -64,6 +67,7 @@ class ContextEngine:
mcp_manager: "MCPClientManager | None" = None,
redis: "Redis | None" = None,
settings: ContextSettings | None = None,
memory_source: "MemoryContextSource | None" = None,
) -> None:
"""
Initialize the context engine.
@@ -72,9 +76,11 @@ class ContextEngine:
mcp_manager: MCP client manager for LLM Gateway/Knowledge Base
redis: Redis connection for caching
settings: Context settings
memory_source: Optional memory context source for agent memory
"""
self._mcp = mcp_manager
self._settings = settings or get_context_settings()
self._memory_source = memory_source
# Initialize components
self._calculator = TokenCalculator(mcp_manager=mcp_manager)
@@ -115,6 +121,15 @@ class ContextEngine:
"""
self._cache.set_redis(redis)
def set_memory_source(self, memory_source: "MemoryContextSource") -> None:
"""
Set memory context source for agent memory integration.
Args:
memory_source: Memory context source
"""
self._memory_source = memory_source
async def assemble_context(
self,
project_id: str,
@@ -126,6 +141,10 @@ class ContextEngine:
task_description: str | None = None,
knowledge_query: str | None = None,
knowledge_limit: int = 10,
memory_query: str | None = None,
memory_limit: int = 20,
session_id: str | None = None,
agent_type_id: str | None = None,
conversation_history: list[dict[str, str]] | None = None,
tool_results: list[dict[str, Any]] | None = None,
custom_contexts: list[BaseContext] | None = None,
@@ -151,6 +170,10 @@ class ContextEngine:
task_description: Current task description
knowledge_query: Query for knowledge base search
knowledge_limit: Max number of knowledge results
memory_query: Query for agent memory search
memory_limit: Max number of memory results
session_id: Session ID for working memory access
agent_type_id: Agent type ID for procedural memory
conversation_history: List of {"role": str, "content": str}
tool_results: List of tool results to include
custom_contexts: Additional custom contexts
@@ -197,15 +220,27 @@ class ContextEngine:
)
contexts.extend(knowledge_contexts)
- # 4. Conversation history
+ # 4. Memory context from Agent Memory System
if memory_query and self._memory_source:
memory_contexts = await self._fetch_memory(
project_id=project_id,
agent_id=agent_id,
query=memory_query,
limit=memory_limit,
session_id=session_id,
agent_type_id=agent_type_id,
)
contexts.extend(memory_contexts)
# 5. Conversation history
if conversation_history:
contexts.extend(self._convert_conversation(conversation_history))
- # 5. Tool results
+ # 6. Tool results
if tool_results:
contexts.extend(self._convert_tool_results(tool_results))
- # 6. Custom contexts
+ # 7. Custom contexts
if custom_contexts:
contexts.extend(custom_contexts)
@@ -308,6 +343,65 @@ class ContextEngine:
logger.warning(f"Failed to fetch knowledge: {e}")
return []
async def _fetch_memory(
self,
project_id: str,
agent_id: str,
query: str,
limit: int = 20,
session_id: str | None = None,
agent_type_id: str | None = None,
) -> list[MemoryContext]:
"""
Fetch relevant memories from Agent Memory System.
Args:
project_id: Project identifier
agent_id: Agent identifier
query: Search query
limit: Maximum results
session_id: Session ID for working memory
agent_type_id: Agent type ID for procedural memory
Returns:
List of MemoryContext instances
"""
if not self._memory_source:
return []
try:
# Import here to avoid circular imports
from app.services.memory.integration.context_source import MemoryFetchConfig
# Configure fetch limits
config = MemoryFetchConfig(
working_limit=min(limit // 4, 5),
episodic_limit=min(limit // 2, 10),
semantic_limit=min(limit // 2, 10),
procedural_limit=min(limit // 4, 5),
include_working=session_id is not None,
)
result = await self._memory_source.fetch_context(
query=query,
project_id=UUID(project_id),
agent_instance_id=UUID(agent_id) if agent_id else None,
agent_type_id=UUID(agent_type_id) if agent_type_id else None,
session_id=session_id,
config=config,
)
logger.debug(
f"Fetched {len(result.contexts)} memory contexts for query: {query}, "
f"by_type: {result.by_type}"
)
return result.contexts[:limit]
except Exception as e:
logger.warning(f"Failed to fetch memory: {e}")
return []
def _convert_conversation(
self,
history: list[dict[str, str]],
@@ -466,6 +560,7 @@ def create_context_engine(
mcp_manager: "MCPClientManager | None" = None,
redis: "Redis | None" = None,
settings: ContextSettings | None = None,
memory_source: "MemoryContextSource | None" = None,
) -> ContextEngine:
"""
Create a context engine instance.
@@ -474,6 +569,7 @@ def create_context_engine(
mcp_manager: MCP client manager
redis: Redis connection
settings: Context settings
memory_source: Optional memory context source
Returns:
Configured ContextEngine instance
@@ -482,4 +578,5 @@ def create_context_engine(
mcp_manager=mcp_manager,
redis=redis,
settings=settings,
memory_source=memory_source,
)

View File

@@ -15,6 +15,10 @@ from .conversation import (
MessageRole,
)
from .knowledge import KnowledgeContext
from .memory import (
MemoryContext,
MemorySubtype,
)
from .system import SystemContext
from .task import (
TaskComplexity,
@@ -33,6 +37,8 @@ __all__ = [
"ContextType",
"ConversationContext",
"KnowledgeContext",
"MemoryContext",
"MemorySubtype",
"MessageRole",
"SystemContext",
"TaskComplexity",

View File

@@ -26,6 +26,7 @@ class ContextType(str, Enum):
KNOWLEDGE = "knowledge"
CONVERSATION = "conversation"
TOOL = "tool"
MEMORY = "memory" # Agent memory (working, episodic, semantic, procedural)
@classmethod
def from_string(cls, value: str) -> "ContextType":

View File

@@ -0,0 +1,282 @@
"""
Memory Context Type.
Represents agent memory as context for LLM requests.
Includes working, episodic, semantic, and procedural memories.
"""
from dataclasses import dataclass, field
from datetime import UTC, datetime
from enum import Enum
from typing import Any
from .base import BaseContext, ContextPriority, ContextType
class MemorySubtype(str, Enum):
"""Types of agent memory."""
WORKING = "working" # Session-scoped temporary data
EPISODIC = "episodic" # Task history and outcomes
SEMANTIC = "semantic" # Facts and knowledge
PROCEDURAL = "procedural" # Learned procedures
@dataclass(eq=False)
class MemoryContext(BaseContext):
"""
Context from agent memory system.
Memory context represents data retrieved from the agent
memory system, including:
- Working memory: Current session state
- Episodic memory: Past task experiences
- Semantic memory: Learned facts and knowledge
- Procedural memory: Known procedures and workflows
Each memory item includes relevance scoring from search.
"""
# Memory-specific fields
memory_subtype: MemorySubtype = field(default=MemorySubtype.EPISODIC)
memory_id: str | None = field(default=None)
relevance_score: float = field(default=0.0)
importance: float = field(default=0.5)
search_query: str = field(default="")
# Type-specific fields (populated based on memory_subtype)
key: str | None = field(default=None) # For working memory
task_type: str | None = field(default=None) # For episodic
outcome: str | None = field(default=None) # For episodic
subject: str | None = field(default=None) # For semantic
predicate: str | None = field(default=None) # For semantic
object_value: str | None = field(default=None) # For semantic
trigger: str | None = field(default=None) # For procedural
success_rate: float | None = field(default=None) # For procedural
def get_type(self) -> ContextType:
"""Return MEMORY context type."""
return ContextType.MEMORY
def to_dict(self) -> dict[str, Any]:
"""Convert to dictionary with memory-specific fields."""
base = super().to_dict()
base.update(
{
"memory_subtype": self.memory_subtype.value,
"memory_id": self.memory_id,
"relevance_score": self.relevance_score,
"importance": self.importance,
"search_query": self.search_query,
"key": self.key,
"task_type": self.task_type,
"outcome": self.outcome,
"subject": self.subject,
"predicate": self.predicate,
"object_value": self.object_value,
"trigger": self.trigger,
"success_rate": self.success_rate,
}
)
return base
@classmethod
def from_dict(cls, data: dict[str, Any]) -> "MemoryContext":
"""Create MemoryContext from dictionary."""
return cls(
id=data.get("id", ""),
content=data["content"],
source=data["source"],
timestamp=datetime.fromisoformat(data["timestamp"])
if isinstance(data.get("timestamp"), str)
else data.get("timestamp", datetime.now(UTC)),
priority=data.get("priority", ContextPriority.NORMAL.value),
metadata=data.get("metadata", {}),
memory_subtype=MemorySubtype(data.get("memory_subtype", "episodic")),
memory_id=data.get("memory_id"),
relevance_score=data.get("relevance_score", 0.0),
importance=data.get("importance", 0.5),
search_query=data.get("search_query", ""),
key=data.get("key"),
task_type=data.get("task_type"),
outcome=data.get("outcome"),
subject=data.get("subject"),
predicate=data.get("predicate"),
object_value=data.get("object_value"),
trigger=data.get("trigger"),
success_rate=data.get("success_rate"),
)
@classmethod
def from_working_memory(
cls,
key: str,
value: Any,
source: str = "working_memory",
query: str = "",
) -> "MemoryContext":
"""
Create MemoryContext from working memory entry.
Args:
key: Working memory key
value: Value stored at key
source: Source identifier
query: Search query used
Returns:
MemoryContext instance
"""
return cls(
content=str(value),
source=source,
memory_subtype=MemorySubtype.WORKING,
key=key,
relevance_score=1.0, # Working memory is always relevant
importance=0.8, # Higher importance for current session state
search_query=query,
priority=ContextPriority.HIGH.value,
)
@classmethod
def from_episodic_memory(
cls,
episode: Any,
query: str = "",
) -> "MemoryContext":
"""
Create MemoryContext from episodic memory episode.
Args:
episode: Episode object from episodic memory
query: Search query used
Returns:
MemoryContext instance
"""
outcome_val = None
if hasattr(episode, "outcome") and episode.outcome:
outcome_val = (
episode.outcome.value
if hasattr(episode.outcome, "value")
else str(episode.outcome)
)
return cls(
content=episode.task_description,
source=f"episodic:{episode.id}",
memory_subtype=MemorySubtype.EPISODIC,
memory_id=str(episode.id),
relevance_score=getattr(episode, "importance_score", 0.5),
importance=getattr(episode, "importance_score", 0.5),
search_query=query,
task_type=getattr(episode, "task_type", None),
outcome=outcome_val,
metadata={
"session_id": getattr(episode, "session_id", None),
"occurred_at": episode.occurred_at.isoformat()
if hasattr(episode, "occurred_at") and episode.occurred_at
else None,
"lessons_learned": getattr(episode, "lessons_learned", []),
},
)
@classmethod
def from_semantic_memory(
cls,
fact: Any,
query: str = "",
) -> "MemoryContext":
"""
Create MemoryContext from semantic memory fact.
Args:
fact: Fact object from semantic memory
query: Search query used
Returns:
MemoryContext instance
"""
triple = f"{fact.subject} {fact.predicate} {fact.object}"
return cls(
content=triple,
source=f"semantic:{fact.id}",
memory_subtype=MemorySubtype.SEMANTIC,
memory_id=str(fact.id),
relevance_score=getattr(fact, "confidence", 0.5),
importance=getattr(fact, "confidence", 0.5),
search_query=query,
subject=fact.subject,
predicate=fact.predicate,
object_value=fact.object,
priority=ContextPriority.NORMAL.value,
)
@classmethod
def from_procedural_memory(
cls,
procedure: Any,
query: str = "",
) -> "MemoryContext":
"""
Create MemoryContext from procedural memory procedure.
Args:
procedure: Procedure object from procedural memory
query: Search query used
Returns:
MemoryContext instance
"""
# Format steps as content
steps = getattr(procedure, "steps", [])
steps_content = "\n".join(
f" {i + 1}. {step.get('action', step) if isinstance(step, dict) else step}"
for i, step in enumerate(steps)
)
content = f"Procedure: {procedure.name}\nTrigger: {procedure.trigger_pattern}\nSteps:\n{steps_content}"
return cls(
content=content,
source=f"procedural:{procedure.id}",
memory_subtype=MemorySubtype.PROCEDURAL,
memory_id=str(procedure.id),
relevance_score=getattr(procedure, "success_rate", 0.5),
importance=0.7, # Procedures are moderately important
search_query=query,
trigger=procedure.trigger_pattern,
success_rate=getattr(procedure, "success_rate", None),
metadata={
"steps_count": len(steps),
"execution_count": getattr(procedure, "success_count", 0)
+ getattr(procedure, "failure_count", 0),
},
)
def is_working_memory(self) -> bool:
"""Check if this is working memory."""
return self.memory_subtype == MemorySubtype.WORKING
def is_episodic_memory(self) -> bool:
"""Check if this is episodic memory."""
return self.memory_subtype == MemorySubtype.EPISODIC
def is_semantic_memory(self) -> bool:
"""Check if this is semantic memory."""
return self.memory_subtype == MemorySubtype.SEMANTIC
def is_procedural_memory(self) -> bool:
"""Check if this is procedural memory."""
return self.memory_subtype == MemorySubtype.PROCEDURAL
def get_formatted_source(self) -> str:
"""
Get a formatted source string for display.
Returns:
Formatted source string
"""
parts = [f"[{self.memory_subtype.value}]", self.source]
if self.memory_id:
parts.append(f"({self.memory_id[:8]}...)")
return " ".join(parts)

View File

@@ -0,0 +1,19 @@
# app/services/memory/integration/__init__.py
"""
Memory Integration Module.
Provides integration between the agent memory system and other Syndarix components:
- Context Engine: Memory as context source
- Agent Lifecycle: Spawn, pause, resume, terminate hooks
"""
from .context_source import MemoryContextSource, get_memory_context_source
from .lifecycle import AgentLifecycleManager, LifecycleHooks, get_lifecycle_manager
__all__ = [
"AgentLifecycleManager",
"LifecycleHooks",
"MemoryContextSource",
"get_lifecycle_manager",
"get_memory_context_source",
]

View File

@@ -0,0 +1,402 @@
# app/services/memory/integration/context_source.py
"""
Memory Context Source.
Provides agent memory as a context source for the Context Engine.
Retrieves relevant memories based on query and converts them to MemoryContext objects.
"""
import logging
from dataclasses import dataclass
from datetime import UTC, datetime, timedelta
from typing import Any
from uuid import UUID
from sqlalchemy.ext.asyncio import AsyncSession
from app.services.context.types.memory import MemoryContext
from app.services.memory.episodic import EpisodicMemory
from app.services.memory.procedural import ProceduralMemory
from app.services.memory.semantic import SemanticMemory
from app.services.memory.working import WorkingMemory
logger = logging.getLogger(__name__)
@dataclass
class MemoryFetchConfig:
"""Configuration for memory fetching."""
# Limits per memory type
working_limit: int = 10
episodic_limit: int = 10
semantic_limit: int = 15
procedural_limit: int = 5
# Time ranges
episodic_days_back: int = 30
min_relevance: float = 0.3
# Which memory types to include
include_working: bool = True
include_episodic: bool = True
include_semantic: bool = True
include_procedural: bool = True
@dataclass
class MemoryFetchResult:
"""Result of memory fetch operation."""
contexts: list[MemoryContext]
by_type: dict[str, int]
fetch_time_ms: float
query: str
class MemoryContextSource:
"""
Source for memory context in the Context Engine.
This service retrieves relevant memories based on a query and
converts them to MemoryContext objects for context assembly.
It coordinates between all memory types (working, episodic,
semantic, procedural) to provide a comprehensive memory context.
"""
def __init__(
self,
session: AsyncSession,
embedding_generator: Any | None = None,
) -> None:
"""
Initialize the memory context source.
Args:
session: Database session
embedding_generator: Optional embedding generator for semantic search
"""
self._session = session
self._embedding_generator = embedding_generator
# Lazy-initialized memory services
self._episodic: EpisodicMemory | None = None
self._semantic: SemanticMemory | None = None
self._procedural: ProceduralMemory | None = None
async def _get_episodic(self) -> EpisodicMemory:
"""Get or create episodic memory service."""
if self._episodic is None:
self._episodic = await EpisodicMemory.create(
self._session,
self._embedding_generator,
)
return self._episodic
async def _get_semantic(self) -> SemanticMemory:
"""Get or create semantic memory service."""
if self._semantic is None:
self._semantic = await SemanticMemory.create(
self._session,
self._embedding_generator,
)
return self._semantic
async def _get_procedural(self) -> ProceduralMemory:
"""Get or create procedural memory service."""
if self._procedural is None:
self._procedural = await ProceduralMemory.create(
self._session,
self._embedding_generator,
)
return self._procedural
async def fetch_context(
self,
query: str,
project_id: UUID,
agent_instance_id: UUID | None = None,
agent_type_id: UUID | None = None,
session_id: str | None = None,
config: MemoryFetchConfig | None = None,
) -> MemoryFetchResult:
"""
Fetch relevant memories as context.
This is the main entry point for the Context Engine integration.
It searches across all memory types and returns relevant memories
as MemoryContext objects.
Args:
query: Search query for finding relevant memories
project_id: Project scope
agent_instance_id: Optional agent instance scope
agent_type_id: Optional agent type scope (for procedural)
session_id: Optional session ID (for working memory)
config: Optional fetch configuration
Returns:
MemoryFetchResult with contexts and metadata
"""
config = config or MemoryFetchConfig()
start_time = datetime.now(UTC)
contexts: list[MemoryContext] = []
by_type: dict[str, int] = {
"working": 0,
"episodic": 0,
"semantic": 0,
"procedural": 0,
}
# Fetch from working memory (session-scoped)
if config.include_working and session_id:
try:
working_contexts = await self._fetch_working(
query=query,
session_id=session_id,
project_id=project_id,
agent_instance_id=agent_instance_id,
limit=config.working_limit,
)
contexts.extend(working_contexts)
by_type["working"] = len(working_contexts)
except Exception as e:
logger.warning(f"Failed to fetch working memory: {e}")
# Fetch from episodic memory
if config.include_episodic:
try:
episodic_contexts = await self._fetch_episodic(
query=query,
project_id=project_id,
agent_instance_id=agent_instance_id,
limit=config.episodic_limit,
days_back=config.episodic_days_back,
)
contexts.extend(episodic_contexts)
by_type["episodic"] = len(episodic_contexts)
except Exception as e:
logger.warning(f"Failed to fetch episodic memory: {e}")
# Fetch from semantic memory
if config.include_semantic:
try:
semantic_contexts = await self._fetch_semantic(
query=query,
project_id=project_id,
limit=config.semantic_limit,
min_relevance=config.min_relevance,
)
contexts.extend(semantic_contexts)
by_type["semantic"] = len(semantic_contexts)
except Exception as e:
logger.warning(f"Failed to fetch semantic memory: {e}")
# Fetch from procedural memory
if config.include_procedural:
try:
procedural_contexts = await self._fetch_procedural(
query=query,
project_id=project_id,
agent_type_id=agent_type_id,
limit=config.procedural_limit,
)
contexts.extend(procedural_contexts)
by_type["procedural"] = len(procedural_contexts)
except Exception as e:
logger.warning(f"Failed to fetch procedural memory: {e}")
# Sort by relevance
contexts.sort(key=lambda c: c.relevance_score, reverse=True)
fetch_time = (datetime.now(UTC) - start_time).total_seconds() * 1000
logger.debug(
f"Fetched {len(contexts)} memory contexts for query '{query[:50]}...' "
f"in {fetch_time:.1f}ms"
)
return MemoryFetchResult(
contexts=contexts,
by_type=by_type,
fetch_time_ms=fetch_time,
query=query,
)
async def _fetch_working(
self,
query: str,
session_id: str,
project_id: UUID,
agent_instance_id: UUID | None,
limit: int,
) -> list[MemoryContext]:
"""Fetch from working memory."""
working = await WorkingMemory.for_session(
session_id=session_id,
project_id=str(project_id),
agent_instance_id=str(agent_instance_id) if agent_instance_id else None,
)
contexts: list[MemoryContext] = []
all_keys = await working.list_keys()
# Filter keys by query (simple substring match)
query_lower = query.lower()
matched_keys = [k for k in all_keys if query_lower in k.lower()]
# If a non-empty query matched nothing, fall back to all keys
# (working memory is always considered relevant)
if not matched_keys and query:
matched_keys = all_keys
for key in matched_keys[:limit]:
value = await working.get(key)
if value is not None:
contexts.append(
MemoryContext.from_working_memory(
key=key,
value=value,
source=f"working:{session_id}",
query=query,
)
)
return contexts
async def _fetch_episodic(
self,
query: str,
project_id: UUID,
agent_instance_id: UUID | None,
limit: int,
days_back: int,
) -> list[MemoryContext]:
"""Fetch from episodic memory."""
episodic = await self._get_episodic()
# Search for similar episodes
episodes = await episodic.search_similar(
project_id=project_id,
query=query,
limit=limit,
agent_instance_id=agent_instance_id,
)
# Also get recent episodes if we didn't find enough
if len(episodes) < limit // 2:
since = datetime.now(UTC) - timedelta(days=days_back)
recent = await episodic.get_recent(
project_id=project_id,
limit=limit,
since=since,
)
# Deduplicate by ID
existing_ids = {e.id for e in episodes}
for ep in recent:
if ep.id not in existing_ids:
episodes.append(ep)
if len(episodes) >= limit:
break
return [
MemoryContext.from_episodic_memory(ep, query=query)
for ep in episodes[:limit]
]
async def _fetch_semantic(
self,
query: str,
project_id: UUID,
limit: int,
min_relevance: float,
) -> list[MemoryContext]:
"""Fetch from semantic memory."""
semantic = await self._get_semantic()
facts = await semantic.search_facts(
query=query,
project_id=project_id,
limit=limit,
min_confidence=min_relevance,
)
return [
MemoryContext.from_semantic_memory(fact, query=query)
for fact in facts
]
async def _fetch_procedural(
self,
query: str,
project_id: UUID,
agent_type_id: UUID | None,
limit: int,
) -> list[MemoryContext]:
"""Fetch from procedural memory."""
procedural = await self._get_procedural()
procedures = await procedural.find_matching(
context=query,
project_id=project_id,
agent_type_id=agent_type_id,
limit=limit,
)
return [
MemoryContext.from_procedural_memory(proc, query=query)
for proc in procedures
]
async def fetch_all_working(
self,
session_id: str,
project_id: UUID,
agent_instance_id: UUID | None = None,
) -> list[MemoryContext]:
"""
Fetch all working memory for a session.
Useful for including entire session state in context.
Args:
session_id: Session ID
project_id: Project scope
agent_instance_id: Optional agent instance scope
Returns:
List of MemoryContext for all working memory items
"""
working = await WorkingMemory.for_session(
session_id=session_id,
project_id=str(project_id),
agent_instance_id=str(agent_instance_id) if agent_instance_id else None,
)
contexts: list[MemoryContext] = []
all_keys = await working.list_keys()
for key in all_keys:
value = await working.get(key)
if value is not None:
contexts.append(
MemoryContext.from_working_memory(
key=key,
value=value,
source=f"working:{session_id}",
)
)
return contexts
# Factory function
async def get_memory_context_source(
session: AsyncSession,
embedding_generator: Any | None = None,
) -> MemoryContextSource:
"""Create a memory context source instance."""
return MemoryContextSource(
session=session,
embedding_generator=embedding_generator,
)

View File

@@ -0,0 +1,629 @@
# app/services/memory/integration/lifecycle.py
"""
Agent Lifecycle Hooks for Memory System.
Provides memory management hooks for agent lifecycle events:
- spawn: Initialize working memory for new agent instance
- pause: Checkpoint working memory state
- resume: Restore working memory from checkpoint
- terminate: Consolidate session to episodic memory
"""
import logging
from collections.abc import Callable, Coroutine
from dataclasses import dataclass, field
from datetime import UTC, datetime
from typing import Any
from uuid import UUID
from sqlalchemy.ext.asyncio import AsyncSession
from app.services.memory.episodic import EpisodicMemory
from app.services.memory.types import EpisodeCreate, Outcome
from app.services.memory.working import WorkingMemory
logger = logging.getLogger(__name__)
@dataclass
class LifecycleEvent:
"""Event data for lifecycle hooks."""
event_type: str # spawn, pause, resume, terminate
project_id: UUID
agent_instance_id: UUID
agent_type_id: UUID | None = None
session_id: str | None = None
timestamp: datetime = field(default_factory=lambda: datetime.now(UTC))
metadata: dict[str, Any] = field(default_factory=dict)
@dataclass
class LifecycleResult:
"""Result of a lifecycle operation."""
success: bool
event_type: str
message: str | None = None
data: dict[str, Any] = field(default_factory=dict)
duration_ms: float = 0.0
# Type alias for lifecycle hooks
LifecycleHook = Callable[[LifecycleEvent], Coroutine[Any, Any, None]]
class LifecycleHooks:
"""
Collection of lifecycle hooks.
Allows registration of custom hooks for lifecycle events.
Hooks are called after the core memory operations.
"""
def __init__(self) -> None:
"""Initialize lifecycle hooks."""
self._spawn_hooks: list[LifecycleHook] = []
self._pause_hooks: list[LifecycleHook] = []
self._resume_hooks: list[LifecycleHook] = []
self._terminate_hooks: list[LifecycleHook] = []
def on_spawn(self, hook: LifecycleHook) -> LifecycleHook:
"""Register a spawn hook."""
self._spawn_hooks.append(hook)
return hook
def on_pause(self, hook: LifecycleHook) -> LifecycleHook:
"""Register a pause hook."""
self._pause_hooks.append(hook)
return hook
def on_resume(self, hook: LifecycleHook) -> LifecycleHook:
"""Register a resume hook."""
self._resume_hooks.append(hook)
return hook
def on_terminate(self, hook: LifecycleHook) -> LifecycleHook:
"""Register a terminate hook."""
self._terminate_hooks.append(hook)
return hook
async def run_spawn_hooks(self, event: LifecycleEvent) -> None:
"""Run all spawn hooks."""
for hook in self._spawn_hooks:
try:
await hook(event)
except Exception as e:
logger.warning(f"Spawn hook failed: {e}")
async def run_pause_hooks(self, event: LifecycleEvent) -> None:
"""Run all pause hooks."""
for hook in self._pause_hooks:
try:
await hook(event)
except Exception as e:
logger.warning(f"Pause hook failed: {e}")
async def run_resume_hooks(self, event: LifecycleEvent) -> None:
"""Run all resume hooks."""
for hook in self._resume_hooks:
try:
await hook(event)
except Exception as e:
logger.warning(f"Resume hook failed: {e}")
async def run_terminate_hooks(self, event: LifecycleEvent) -> None:
"""Run all terminate hooks."""
for hook in self._terminate_hooks:
try:
await hook(event)
except Exception as e:
logger.warning(f"Terminate hook failed: {e}")
class AgentLifecycleManager:
"""
Manager for agent lifecycle and memory integration.
Handles memory operations during agent lifecycle events:
- spawn: Creates new working memory for the session
- pause: Saves working memory state to checkpoint
- resume: Restores working memory from checkpoint
- terminate: Consolidates working memory to episodic memory
"""
# Key prefix for checkpoint storage
CHECKPOINT_PREFIX = "__checkpoint__"
def __init__(
self,
session: AsyncSession,
embedding_generator: Any | None = None,
hooks: LifecycleHooks | None = None,
) -> None:
"""
Initialize the lifecycle manager.
Args:
session: Database session
embedding_generator: Optional embedding generator
hooks: Optional lifecycle hooks
"""
self._session = session
self._embedding_generator = embedding_generator
self._hooks = hooks or LifecycleHooks()
# Lazy-initialized services
self._episodic: EpisodicMemory | None = None
async def _get_episodic(self) -> EpisodicMemory:
"""Get or create episodic memory service."""
if self._episodic is None:
self._episodic = await EpisodicMemory.create(
self._session,
self._embedding_generator,
)
return self._episodic
@property
def hooks(self) -> LifecycleHooks:
"""Get the lifecycle hooks."""
return self._hooks
async def spawn(
self,
project_id: UUID,
agent_instance_id: UUID,
session_id: str,
agent_type_id: UUID | None = None,
initial_state: dict[str, Any] | None = None,
metadata: dict[str, Any] | None = None,
) -> LifecycleResult:
"""
Handle agent spawn - initialize working memory.
Creates a new working memory instance for the agent session
and optionally populates it with initial state.
Args:
project_id: Project scope
agent_instance_id: Agent instance ID
session_id: Session ID for working memory
agent_type_id: Optional agent type ID
initial_state: Optional initial state to populate
metadata: Optional metadata for the event
Returns:
LifecycleResult with spawn outcome
"""
start_time = datetime.now(UTC)
try:
# Create working memory for the session
working = await WorkingMemory.for_session(
session_id=session_id,
project_id=str(project_id),
agent_instance_id=str(agent_instance_id),
)
# Populate initial state if provided
items_set = 0
if initial_state:
for key, value in initial_state.items():
await working.set(key, value)
items_set += 1
# Create and run event hooks
event = LifecycleEvent(
event_type="spawn",
project_id=project_id,
agent_instance_id=agent_instance_id,
agent_type_id=agent_type_id,
session_id=session_id,
metadata=metadata or {},
)
await self._hooks.run_spawn_hooks(event)
duration_ms = (datetime.now(UTC) - start_time).total_seconds() * 1000
logger.info(
f"Agent {agent_instance_id} spawned with session {session_id}, "
f"initial state: {items_set} items"
)
return LifecycleResult(
success=True,
event_type="spawn",
message="Agent spawned successfully",
data={
"session_id": session_id,
"initial_items": items_set,
},
duration_ms=duration_ms,
)
except Exception as e:
logger.error(f"Spawn failed for agent {agent_instance_id}: {e}")
return LifecycleResult(
success=False,
event_type="spawn",
message=f"Spawn failed: {e}",
)
async def pause(
self,
project_id: UUID,
agent_instance_id: UUID,
session_id: str,
checkpoint_id: str | None = None,
metadata: dict[str, Any] | None = None,
) -> LifecycleResult:
"""
Handle agent pause - checkpoint working memory.
Saves the current working memory state to a checkpoint
that can be restored later with resume().
Args:
project_id: Project scope
agent_instance_id: Agent instance ID
session_id: Session ID
checkpoint_id: Optional checkpoint identifier
metadata: Optional metadata for the event
Returns:
LifecycleResult with checkpoint data
"""
start_time = datetime.now(UTC)
checkpoint_id = checkpoint_id or f"checkpoint_{int(start_time.timestamp())}"
try:
working = await WorkingMemory.for_session(
session_id=session_id,
project_id=str(project_id),
agent_instance_id=str(agent_instance_id),
)
# Get all current state
all_keys = await working.list_keys()
# Filter out checkpoint keys
state_keys = [k for k in all_keys if not k.startswith(self.CHECKPOINT_PREFIX)]
state: dict[str, Any] = {}
for key in state_keys:
value = await working.get(key)
if value is not None:
state[key] = value
# Store checkpoint
checkpoint_key = f"{self.CHECKPOINT_PREFIX}{checkpoint_id}"
await working.set(
checkpoint_key,
{
"state": state,
"timestamp": start_time.isoformat(),
"keys_count": len(state),
},
ttl_seconds=86400 * 7, # Keep checkpoint for 7 days
)
# Run hooks
event = LifecycleEvent(
event_type="pause",
project_id=project_id,
agent_instance_id=agent_instance_id,
session_id=session_id,
metadata={**(metadata or {}), "checkpoint_id": checkpoint_id},
)
await self._hooks.run_pause_hooks(event)
duration_ms = (datetime.now(UTC) - start_time).total_seconds() * 1000
logger.info(
f"Agent {agent_instance_id} paused, checkpoint {checkpoint_id} "
f"saved with {len(state)} items"
)
return LifecycleResult(
success=True,
event_type="pause",
message="Agent paused successfully",
data={
"checkpoint_id": checkpoint_id,
"items_saved": len(state),
"timestamp": start_time.isoformat(),
},
duration_ms=duration_ms,
)
except Exception as e:
logger.error(f"Pause failed for agent {agent_instance_id}: {e}")
return LifecycleResult(
success=False,
event_type="pause",
message=f"Pause failed: {e}",
)
async def resume(
self,
project_id: UUID,
agent_instance_id: UUID,
session_id: str,
checkpoint_id: str,
clear_current: bool = True,
metadata: dict[str, Any] | None = None,
) -> LifecycleResult:
"""
Handle agent resume - restore from checkpoint.
Restores working memory state from a previously saved checkpoint.
Args:
project_id: Project scope
agent_instance_id: Agent instance ID
session_id: Session ID
checkpoint_id: Checkpoint to restore from
clear_current: Whether to clear current state before restoring
metadata: Optional metadata for the event
Returns:
LifecycleResult with restore outcome
"""
start_time = datetime.now(UTC)
try:
working = await WorkingMemory.for_session(
session_id=session_id,
project_id=str(project_id),
agent_instance_id=str(agent_instance_id),
)
# Get checkpoint
checkpoint_key = f"{self.CHECKPOINT_PREFIX}{checkpoint_id}"
checkpoint = await working.get(checkpoint_key)
if checkpoint is None:
return LifecycleResult(
success=False,
event_type="resume",
message=f"Checkpoint '{checkpoint_id}' not found",
)
# Clear current state if requested
if clear_current:
all_keys = await working.list_keys()
for key in all_keys:
if not key.startswith(self.CHECKPOINT_PREFIX):
await working.delete(key)
# Restore state from checkpoint
state = checkpoint.get("state", {})
items_restored = 0
for key, value in state.items():
await working.set(key, value)
items_restored += 1
# Run hooks
event = LifecycleEvent(
event_type="resume",
project_id=project_id,
agent_instance_id=agent_instance_id,
session_id=session_id,
metadata={**(metadata or {}), "checkpoint_id": checkpoint_id},
)
await self._hooks.run_resume_hooks(event)
duration_ms = (datetime.now(UTC) - start_time).total_seconds() * 1000
logger.info(
f"Agent {agent_instance_id} resumed from checkpoint {checkpoint_id}, "
f"restored {items_restored} items"
)
return LifecycleResult(
success=True,
event_type="resume",
message="Agent resumed successfully",
data={
"checkpoint_id": checkpoint_id,
"items_restored": items_restored,
"checkpoint_timestamp": checkpoint.get("timestamp"),
},
duration_ms=duration_ms,
)
except Exception as e:
logger.error(f"Resume failed for agent {agent_instance_id}: {e}")
return LifecycleResult(
success=False,
event_type="resume",
message=f"Resume failed: {e}",
)
async def terminate(
self,
project_id: UUID,
agent_instance_id: UUID,
session_id: str,
task_description: str | None = None,
outcome: Outcome = Outcome.SUCCESS,
lessons_learned: list[str] | None = None,
consolidate_to_episodic: bool = True,
cleanup_working: bool = True,
metadata: dict[str, Any] | None = None,
) -> LifecycleResult:
"""
Handle agent termination - consolidate to episodic memory.
Consolidates the session's working memory into an episodic memory
entry, then optionally cleans up the working memory.
Args:
project_id: Project scope
agent_instance_id: Agent instance ID
session_id: Session ID
task_description: Description of what was accomplished
outcome: Task outcome (SUCCESS, FAILURE, PARTIAL)
lessons_learned: Optional list of lessons learned
consolidate_to_episodic: Whether to create episodic entry
cleanup_working: Whether to clear working memory
metadata: Optional metadata for the event
Returns:
LifecycleResult with termination outcome
"""
start_time = datetime.now(UTC)
try:
working = await WorkingMemory.for_session(
session_id=session_id,
project_id=str(project_id),
agent_instance_id=str(agent_instance_id),
)
# Gather session state for consolidation
all_keys = await working.list_keys()
state_keys = [k for k in all_keys if not k.startswith(self.CHECKPOINT_PREFIX)]
session_state: dict[str, Any] = {}
for key in state_keys:
value = await working.get(key)
if value is not None:
session_state[key] = value
episode_id: str | None = None
# Consolidate to episodic memory
if consolidate_to_episodic:
episodic = await self._get_episodic()
description = task_description or f"Session {session_id} completed"
episode_data = EpisodeCreate(
project_id=project_id,
agent_instance_id=agent_instance_id,
session_id=session_id,
task_type="session_completion",
task_description=description[:500],
outcome=outcome,
outcome_details=f"Session terminated with {len(session_state)} state items",
actions=[
{
"type": "session_terminate",
"state_keys": list(session_state.keys()),
"outcome": outcome.value,
}
],
context_summary=str(session_state)[:1000] if session_state else "",
lessons_learned=lessons_learned or [],
duration_seconds=0.0, # Unknown at this point
tokens_used=0,
importance_score=0.6, # Moderate importance for session ends
)
episode = await episodic.record_episode(episode_data)
episode_id = str(episode.id)
# Clean up working memory
items_cleared = 0
if cleanup_working:
for key in all_keys:
await working.delete(key)
items_cleared += 1
# Run hooks
event = LifecycleEvent(
event_type="terminate",
project_id=project_id,
agent_instance_id=agent_instance_id,
session_id=session_id,
metadata={**(metadata or {}), "episode_id": episode_id},
)
await self._hooks.run_terminate_hooks(event)
duration_ms = (datetime.now(UTC) - start_time).total_seconds() * 1000
logger.info(
f"Agent {agent_instance_id} terminated, session {session_id} "
f"consolidated to episode {episode_id}"
)
return LifecycleResult(
success=True,
event_type="terminate",
message="Agent terminated successfully",
data={
"episode_id": episode_id,
"state_items_consolidated": len(session_state),
"items_cleared": items_cleared,
"outcome": outcome.value,
},
duration_ms=duration_ms,
)
except Exception as e:
logger.error(f"Terminate failed for agent {agent_instance_id}: {e}")
return LifecycleResult(
success=False,
event_type="terminate",
message=f"Terminate failed: {e}",
)
async def list_checkpoints(
self,
project_id: UUID,
agent_instance_id: UUID,
session_id: str,
) -> list[dict[str, Any]]:
"""
List available checkpoints for a session.
Args:
project_id: Project scope
agent_instance_id: Agent instance ID
session_id: Session ID
Returns:
List of checkpoint metadata dicts
"""
working = await WorkingMemory.for_session(
session_id=session_id,
project_id=str(project_id),
agent_instance_id=str(agent_instance_id),
)
all_keys = await working.list_keys()
checkpoints: list[dict[str, Any]] = []
for key in all_keys:
if key.startswith(self.CHECKPOINT_PREFIX):
checkpoint_id = key[len(self.CHECKPOINT_PREFIX):]
checkpoint = await working.get(key)
if checkpoint:
checkpoints.append({
"checkpoint_id": checkpoint_id,
"timestamp": checkpoint.get("timestamp"),
"keys_count": checkpoint.get("keys_count", 0),
})
# Sort by timestamp (newest first)
checkpoints.sort(
key=lambda c: c.get("timestamp", ""),
reverse=True,
)
return checkpoints
# Factory function
async def get_lifecycle_manager(
session: AsyncSession,
embedding_generator: Any | None = None,
hooks: LifecycleHooks | None = None,
) -> AgentLifecycleManager:
"""Create a lifecycle manager instance."""
return AgentLifecycleManager(
session=session,
embedding_generator=embedding_generator,
hooks=hooks,
)

View File

@@ -0,0 +1,262 @@
# tests/unit/services/context/types/test_memory.py
"""Tests for MemoryContext type."""
from datetime import UTC, datetime
from unittest.mock import MagicMock
from uuid import uuid4
import pytest
from app.services.context.types import ContextType
from app.services.context.types.memory import MemoryContext, MemorySubtype
class TestMemorySubtype:
"""Tests for MemorySubtype enum."""
def test_all_types_defined(self) -> None:
"""All memory subtypes should be defined."""
assert MemorySubtype.WORKING == "working"
assert MemorySubtype.EPISODIC == "episodic"
assert MemorySubtype.SEMANTIC == "semantic"
assert MemorySubtype.PROCEDURAL == "procedural"
def test_enum_values(self) -> None:
"""Enum values should match strings."""
assert MemorySubtype.WORKING.value == "working"
assert MemorySubtype("episodic") == MemorySubtype.EPISODIC
class TestMemoryContext:
"""Tests for MemoryContext class."""
def test_get_type_returns_memory(self) -> None:
"""get_type should return MEMORY."""
ctx = MemoryContext(content="test", source="test_source")
assert ctx.get_type() == ContextType.MEMORY
def test_default_values(self) -> None:
"""Default values should be set correctly."""
ctx = MemoryContext(content="test", source="test_source")
assert ctx.memory_subtype == MemorySubtype.EPISODIC
assert ctx.memory_id is None
assert ctx.relevance_score == 0.0
assert ctx.importance == 0.5
def test_to_dict_includes_memory_fields(self) -> None:
"""to_dict should include memory-specific fields."""
ctx = MemoryContext(
content="test content",
source="test_source",
memory_subtype=MemorySubtype.SEMANTIC,
memory_id="mem-123",
relevance_score=0.8,
subject="User",
predicate="prefers",
object_value="dark mode",
)
data = ctx.to_dict()
assert data["memory_subtype"] == "semantic"
assert data["memory_id"] == "mem-123"
assert data["relevance_score"] == 0.8
assert data["subject"] == "User"
assert data["predicate"] == "prefers"
assert data["object_value"] == "dark mode"
def test_from_dict(self) -> None:
"""from_dict should create correct MemoryContext."""
data = {
"content": "test content",
"source": "test_source",
"timestamp": "2024-01-01T00:00:00+00:00",
"memory_subtype": "semantic",
"memory_id": "mem-123",
"relevance_score": 0.8,
"subject": "Test",
}
ctx = MemoryContext.from_dict(data)
assert ctx.content == "test content"
assert ctx.memory_subtype == MemorySubtype.SEMANTIC
assert ctx.memory_id == "mem-123"
assert ctx.subject == "Test"
class TestMemoryContextFromWorkingMemory:
"""Tests for MemoryContext.from_working_memory."""
def test_creates_working_memory_context(self) -> None:
"""Should create working memory context from key/value."""
ctx = MemoryContext.from_working_memory(
key="user_preferences",
value={"theme": "dark"},
source="working:sess-123",
query="preferences",
)
assert ctx.memory_subtype == MemorySubtype.WORKING
assert ctx.key == "user_preferences"
assert "{'theme': 'dark'}" in ctx.content
assert ctx.relevance_score == 1.0 # Working memory is always relevant
assert ctx.importance == 0.8 # Higher importance
def test_string_value(self) -> None:
"""Should handle string values."""
ctx = MemoryContext.from_working_memory(
key="current_task",
value="Build authentication",
)
assert ctx.content == "Build authentication"
class TestMemoryContextFromEpisodicMemory:
"""Tests for MemoryContext.from_episodic_memory."""
def test_creates_episodic_memory_context(self) -> None:
"""Should create episodic memory context from episode."""
episode = MagicMock()
episode.id = uuid4()
episode.task_description = "Implemented login feature"
episode.task_type = "feature_implementation"
episode.outcome = MagicMock(value="success")
episode.importance_score = 0.9
episode.session_id = "sess-123"
episode.occurred_at = datetime.now(UTC)
episode.lessons_learned = ["Use proper validation"]
ctx = MemoryContext.from_episodic_memory(episode, query="login")
assert ctx.memory_subtype == MemorySubtype.EPISODIC
assert ctx.memory_id == str(episode.id)
assert ctx.content == "Implemented login feature"
assert ctx.task_type == "feature_implementation"
assert ctx.outcome == "success"
assert ctx.importance == 0.9
def test_handles_missing_outcome(self) -> None:
"""Should handle episodes with no outcome."""
episode = MagicMock()
episode.id = uuid4()
episode.task_description = "WIP task"
episode.outcome = None
episode.importance_score = 0.5
episode.occurred_at = None
ctx = MemoryContext.from_episodic_memory(episode)
assert ctx.outcome is None
class TestMemoryContextFromSemanticMemory:
"""Tests for MemoryContext.from_semantic_memory."""
def test_creates_semantic_memory_context(self) -> None:
"""Should create semantic memory context from fact."""
fact = MagicMock()
fact.id = uuid4()
fact.subject = "User"
fact.predicate = "prefers"
fact.object = "dark mode"
fact.confidence = 0.95
ctx = MemoryContext.from_semantic_memory(fact, query="user preferences")
assert ctx.memory_subtype == MemorySubtype.SEMANTIC
assert ctx.memory_id == str(fact.id)
assert ctx.content == "User prefers dark mode"
assert ctx.subject == "User"
assert ctx.predicate == "prefers"
assert ctx.object_value == "dark mode"
assert ctx.relevance_score == 0.95
class TestMemoryContextFromProceduralMemory:
"""Tests for MemoryContext.from_procedural_memory."""
def test_creates_procedural_memory_context(self) -> None:
"""Should create procedural memory context from procedure."""
procedure = MagicMock()
procedure.id = uuid4()
procedure.name = "Deploy to Production"
procedure.trigger_pattern = "When deploying to production"
procedure.steps = [
{"action": "run_tests"},
{"action": "build_docker"},
{"action": "deploy"},
]
procedure.success_rate = 0.85
procedure.success_count = 10
procedure.failure_count = 2
ctx = MemoryContext.from_procedural_memory(procedure, query="deploy")
assert ctx.memory_subtype == MemorySubtype.PROCEDURAL
assert ctx.memory_id == str(procedure.id)
assert "Deploy to Production" in ctx.content
assert "When deploying to production" in ctx.content
assert ctx.trigger == "When deploying to production"
assert ctx.success_rate == 0.85
assert ctx.metadata["steps_count"] == 3
assert ctx.metadata["execution_count"] == 12
class TestMemoryContextHelpers:
"""Tests for MemoryContext helper methods."""
def test_is_working_memory(self) -> None:
"""is_working_memory should return True for working memory."""
ctx = MemoryContext(
content="test",
source="test",
memory_subtype=MemorySubtype.WORKING,
)
assert ctx.is_working_memory() is True
assert ctx.is_episodic_memory() is False
def test_is_episodic_memory(self) -> None:
"""is_episodic_memory should return True for episodic memory."""
ctx = MemoryContext(
content="test",
source="test",
memory_subtype=MemorySubtype.EPISODIC,
)
assert ctx.is_episodic_memory() is True
assert ctx.is_semantic_memory() is False
def test_is_semantic_memory(self) -> None:
"""is_semantic_memory should return True for semantic memory."""
ctx = MemoryContext(
content="test",
source="test",
memory_subtype=MemorySubtype.SEMANTIC,
)
assert ctx.is_semantic_memory() is True
assert ctx.is_procedural_memory() is False
def test_is_procedural_memory(self) -> None:
"""is_procedural_memory should return True for procedural memory."""
ctx = MemoryContext(
content="test",
source="test",
memory_subtype=MemorySubtype.PROCEDURAL,
)
assert ctx.is_procedural_memory() is True
assert ctx.is_working_memory() is False
def test_get_formatted_source(self) -> None:
"""get_formatted_source should return formatted string."""
ctx = MemoryContext(
content="test",
source="episodic:12345678-1234-1234-1234-123456789012",
memory_subtype=MemorySubtype.EPISODIC,
memory_id="12345678-1234-1234-1234-123456789012",
)
formatted = ctx.get_formatted_source()
assert "[episodic]" in formatted
assert "12345678..." in formatted

View File

@@ -0,0 +1,2 @@
# tests/unit/services/memory/integration/__init__.py
"""Tests for memory integration module."""

View File

@@ -0,0 +1,322 @@
# tests/unit/services/memory/integration/test_context_source.py
"""Tests for MemoryContextSource service."""
from datetime import UTC, datetime
from unittest.mock import AsyncMock, MagicMock, patch
from uuid import uuid4
import pytest
from app.services.context.types.memory import MemorySubtype
from app.services.memory.integration.context_source import (
MemoryContextSource,
MemoryFetchConfig,
MemoryFetchResult,
get_memory_context_source,
)
pytestmark = pytest.mark.asyncio(loop_scope="function")
@pytest.fixture
def mock_session() -> MagicMock:
"""Create mock database session."""
return MagicMock()
@pytest.fixture
def context_source(mock_session: MagicMock) -> MemoryContextSource:
"""Create MemoryContextSource instance."""
return MemoryContextSource(session=mock_session)
class TestMemoryFetchConfig:
"""Tests for MemoryFetchConfig."""
def test_default_values(self) -> None:
"""Default config values should be set correctly."""
config = MemoryFetchConfig()
assert config.working_limit == 10
assert config.episodic_limit == 10
assert config.semantic_limit == 15
assert config.procedural_limit == 5
assert config.episodic_days_back == 30
assert config.min_relevance == 0.3
assert config.include_working is True
assert config.include_episodic is True
assert config.include_semantic is True
assert config.include_procedural is True
def test_custom_values(self) -> None:
"""Custom config values should be respected."""
config = MemoryFetchConfig(
working_limit=5,
include_working=False,
)
assert config.working_limit == 5
assert config.include_working is False
class TestMemoryFetchResult:
"""Tests for MemoryFetchResult."""
def test_stores_results(self) -> None:
"""Result should store contexts and metadata."""
result = MemoryFetchResult(
contexts=[],
by_type={"working": 0, "episodic": 5, "semantic": 3, "procedural": 0},
fetch_time_ms=15.5,
query="test query",
)
assert result.contexts == []
assert result.by_type["episodic"] == 5
assert result.fetch_time_ms == 15.5
assert result.query == "test query"
class TestMemoryContextSource:
"""Tests for MemoryContextSource service."""
async def test_fetch_context_empty_when_no_sources(
self,
context_source: MemoryContextSource,
) -> None:
"""fetch_context should return empty when all sources fail."""
config = MemoryFetchConfig(
include_working=False,
include_episodic=False,
include_semantic=False,
include_procedural=False,
)
result = await context_source.fetch_context(
query="test",
project_id=uuid4(),
config=config,
)
assert len(result.contexts) == 0
assert result.by_type == {
"working": 0,
"episodic": 0,
"semantic": 0,
"procedural": 0,
}
@patch("app.services.memory.integration.context_source.WorkingMemory")
async def test_fetch_working_memory(
self,
mock_working_cls: MagicMock,
context_source: MemoryContextSource,
) -> None:
"""Should fetch working memory when session_id provided."""
# Setup mock - both keys should match the query "task"
mock_working = AsyncMock()
mock_working.list_keys = AsyncMock(return_value=["current_task", "task_state"])
mock_working.get = AsyncMock(side_effect=lambda k: {"key": k, "value": "test"})
mock_working_cls.for_session = AsyncMock(return_value=mock_working)
config = MemoryFetchConfig(
include_episodic=False,
include_semantic=False,
include_procedural=False,
)
result = await context_source.fetch_context(
query="task", # Both keys contain "task"
project_id=uuid4(),
session_id="sess-123",
config=config,
)
assert result.by_type["working"] == 2
assert all(
c.memory_subtype == MemorySubtype.WORKING for c in result.contexts
)
@patch("app.services.memory.integration.context_source.EpisodicMemory")
async def test_fetch_episodic_memory(
self,
mock_episodic_cls: MagicMock,
context_source: MemoryContextSource,
) -> None:
"""Should fetch episodic memory."""
# Setup mock episode
mock_episode = MagicMock()
mock_episode.id = uuid4()
mock_episode.task_description = "Completed login feature"
mock_episode.task_type = "feature"
mock_episode.outcome = MagicMock(value="success")
mock_episode.importance_score = 0.8
mock_episode.occurred_at = datetime.now(UTC)
mock_episode.lessons_learned = []
mock_episodic = AsyncMock()
mock_episodic.search_similar = AsyncMock(return_value=[mock_episode])
mock_episodic.get_recent = AsyncMock(return_value=[])
mock_episodic_cls.create = AsyncMock(return_value=mock_episodic)
config = MemoryFetchConfig(
include_working=False,
include_semantic=False,
include_procedural=False,
)
result = await context_source.fetch_context(
query="login",
project_id=uuid4(),
config=config,
)
assert result.by_type["episodic"] == 1
assert result.contexts[0].memory_subtype == MemorySubtype.EPISODIC
assert "Completed login feature" in result.contexts[0].content
@patch("app.services.memory.integration.context_source.SemanticMemory")
async def test_fetch_semantic_memory(
self,
mock_semantic_cls: MagicMock,
context_source: MemoryContextSource,
) -> None:
"""Should fetch semantic memory."""
# Setup mock fact
mock_fact = MagicMock()
mock_fact.id = uuid4()
mock_fact.subject = "User"
mock_fact.predicate = "prefers"
mock_fact.object = "dark mode"
mock_fact.confidence = 0.9
mock_semantic = AsyncMock()
mock_semantic.search_facts = AsyncMock(return_value=[mock_fact])
mock_semantic_cls.create = AsyncMock(return_value=mock_semantic)
config = MemoryFetchConfig(
include_working=False,
include_episodic=False,
include_procedural=False,
)
result = await context_source.fetch_context(
query="preferences",
project_id=uuid4(),
config=config,
)
assert result.by_type["semantic"] == 1
assert result.contexts[0].memory_subtype == MemorySubtype.SEMANTIC
assert "User prefers dark mode" in result.contexts[0].content
@patch("app.services.memory.integration.context_source.ProceduralMemory")
async def test_fetch_procedural_memory(
self,
mock_procedural_cls: MagicMock,
context_source: MemoryContextSource,
) -> None:
"""Should fetch procedural memory."""
# Setup mock procedure
mock_proc = MagicMock()
mock_proc.id = uuid4()
mock_proc.name = "Deploy"
mock_proc.trigger_pattern = "When deploying"
mock_proc.steps = [{"action": "build"}, {"action": "test"}]
mock_proc.success_rate = 0.9
mock_proc.success_count = 9
mock_proc.failure_count = 1
mock_procedural = AsyncMock()
mock_procedural.find_matching = AsyncMock(return_value=[mock_proc])
mock_procedural_cls.create = AsyncMock(return_value=mock_procedural)
config = MemoryFetchConfig(
include_working=False,
include_episodic=False,
include_semantic=False,
)
result = await context_source.fetch_context(
query="deploy",
project_id=uuid4(),
config=config,
)
assert result.by_type["procedural"] == 1
assert result.contexts[0].memory_subtype == MemorySubtype.PROCEDURAL
assert "Deploy" in result.contexts[0].content
async def test_results_sorted_by_relevance(
self,
context_source: MemoryContextSource,
) -> None:
"""Results should be sorted by relevance score."""
with patch.object(
context_source, "_fetch_episodic"
) as mock_ep, patch.object(
context_source, "_fetch_semantic"
) as mock_sem:
            from app.services.context.types.memory import MemoryContext

            # Create contexts with different relevance scores
ctx_low = MemoryContext(
content="low relevance",
source="test",
relevance_score=0.3,
)
ctx_high = MemoryContext(
content="high relevance",
source="test",
relevance_score=0.9,
)
mock_ep.return_value = [ctx_low]
mock_sem.return_value = [ctx_high]
config = MemoryFetchConfig(
include_working=False,
include_procedural=False,
)
result = await context_source.fetch_context(
query="test",
project_id=uuid4(),
config=config,
)
# Higher relevance should come first
assert result.contexts[0].relevance_score == 0.9
assert result.contexts[1].relevance_score == 0.3
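    # Ordering is by relevance_score, descending, across subtypes: a
    # high-scoring semantic fact outranks a low-scoring episode.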
@patch("app.services.memory.integration.context_source.WorkingMemory")
async def test_fetch_all_working(
self,
mock_working_cls: MagicMock,
context_source: MemoryContextSource,
) -> None:
"""fetch_all_working should return all working memory items."""
mock_working = AsyncMock()
mock_working.list_keys = AsyncMock(return_value=["key1", "key2", "key3"])
mock_working.get = AsyncMock(return_value="value")
mock_working_cls.for_session = AsyncMock(return_value=mock_working)
contexts = await context_source.fetch_all_working(
session_id="sess-123",
project_id=uuid4(),
)
assert len(contexts) == 3
assert all(c.memory_subtype == MemorySubtype.WORKING for c in contexts)
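# A minimal usage sketch for the source under test (session and project_id
# are hypothetical; the call shapes mirror the tests above):
#
#     source = await get_memory_context_source(session)
#     result = await source.fetch_context(
#         query="deploy", project_id=project_id, session_id="sess-123"
#     )
#     for ctx in result.contexts:  # pre-sorted by relevance_score, descending
#         print(ctx.memory_subtype, ctx.relevance_score)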
class TestGetMemoryContextSource:
"""Tests for factory function."""
async def test_creates_instance(self) -> None:
"""Factory should create MemoryContextSource instance."""
mock_session = MagicMock()
source = await get_memory_context_source(mock_session)
assert isinstance(source, MemoryContextSource)

View File

@@ -0,0 +1,471 @@
# tests/unit/services/memory/integration/test_lifecycle.py
"""Tests for Agent Lifecycle Hooks."""
from datetime import UTC, datetime
from unittest.mock import AsyncMock, MagicMock, patch
from uuid import uuid4
import pytest
from app.services.memory.integration.lifecycle import (
AgentLifecycleManager,
LifecycleEvent,
LifecycleHooks,
LifecycleResult,
get_lifecycle_manager,
)
from app.services.memory.types import Outcome
pytestmark = pytest.mark.asyncio(loop_scope="function")
@pytest.fixture
def mock_session() -> MagicMock:
"""Create mock database session."""
return MagicMock()
@pytest.fixture
def lifecycle_hooks() -> LifecycleHooks:
"""Create lifecycle hooks instance."""
return LifecycleHooks()
@pytest.fixture
def lifecycle_manager(mock_session: MagicMock) -> AgentLifecycleManager:
"""Create lifecycle manager instance."""
return AgentLifecycleManager(session=mock_session)
class TestLifecycleEvent:
"""Tests for LifecycleEvent dataclass."""
def test_creates_event(self) -> None:
"""Should create event with required fields."""
project_id = uuid4()
agent_id = uuid4()
event = LifecycleEvent(
event_type="spawn",
project_id=project_id,
agent_instance_id=agent_id,
)
assert event.event_type == "spawn"
assert event.project_id == project_id
assert event.agent_instance_id == agent_id
assert event.timestamp is not None
assert event.metadata == {}
def test_with_optional_fields(self) -> None:
"""Should include optional fields."""
event = LifecycleEvent(
event_type="terminate",
project_id=uuid4(),
agent_instance_id=uuid4(),
session_id="sess-123",
metadata={"reason": "completed"},
)
assert event.session_id == "sess-123"
assert event.metadata["reason"] == "completed"
class TestLifecycleResult:
"""Tests for LifecycleResult dataclass."""
def test_success_result(self) -> None:
"""Should create success result."""
result = LifecycleResult(
success=True,
event_type="spawn",
message="Agent spawned",
data={"session_id": "sess-123"},
duration_ms=10.5,
)
assert result.success is True
assert result.event_type == "spawn"
assert result.data["session_id"] == "sess-123"
def test_failure_result(self) -> None:
"""Should create failure result."""
result = LifecycleResult(
success=False,
event_type="resume",
message="Checkpoint not found",
)
assert result.success is False
assert result.message == "Checkpoint not found"
class TestLifecycleHooks:
"""Tests for LifecycleHooks class."""
def test_register_spawn_hook(self, lifecycle_hooks: LifecycleHooks) -> None:
"""Should register spawn hook."""
async def my_hook(event: LifecycleEvent) -> None:
pass
result = lifecycle_hooks.on_spawn(my_hook)
assert result is my_hook
assert my_hook in lifecycle_hooks._spawn_hooks
    def test_register_all_hooks(self, lifecycle_hooks: LifecycleHooks) -> None:
        """Should register hooks for all event types."""
        # Register one hook per event type; the decorator-style return values
        # are not needed here, so they are deliberately not bound.
        lifecycle_hooks.on_spawn(AsyncMock())
        lifecycle_hooks.on_pause(AsyncMock())
        lifecycle_hooks.on_resume(AsyncMock())
        lifecycle_hooks.on_terminate(AsyncMock())
        assert len(lifecycle_hooks._spawn_hooks) == 1
        assert len(lifecycle_hooks._pause_hooks) == 1
        assert len(lifecycle_hooks._resume_hooks) == 1
        assert len(lifecycle_hooks._terminate_hooks) == 1
async def test_run_spawn_hooks(self, lifecycle_hooks: LifecycleHooks) -> None:
"""Should run all spawn hooks."""
hook1 = AsyncMock()
hook2 = AsyncMock()
lifecycle_hooks.on_spawn(hook1)
lifecycle_hooks.on_spawn(hook2)
event = LifecycleEvent(
event_type="spawn",
project_id=uuid4(),
agent_instance_id=uuid4(),
)
await lifecycle_hooks.run_spawn_hooks(event)
hook1.assert_called_once_with(event)
hook2.assert_called_once_with(event)
async def test_hook_failure_doesnt_stop_others(
self, lifecycle_hooks: LifecycleHooks
) -> None:
"""Hook failure should not stop other hooks from running."""
hook1 = AsyncMock(side_effect=ValueError("Oops"))
hook2 = AsyncMock()
lifecycle_hooks.on_pause(hook1)
lifecycle_hooks.on_pause(hook2)
event = LifecycleEvent(
event_type="pause",
project_id=uuid4(),
agent_instance_id=uuid4(),
)
await lifecycle_hooks.run_pause_hooks(event)
# hook2 should still be called even though hook1 failed
hook2.assert_called_once()
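# Hooks register decorator-style: on_spawn() returns the function it was
# given (asserted in test_register_spawn_hook), and the run_*_hooks methods
# isolate failures so one broken hook cannot block the rest. A sketch:
#
#     @lifecycle_hooks.on_spawn
#     async def warm_caches(event: LifecycleEvent) -> None:
#         ...  # hypothetical extension point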
class TestAgentLifecycleManagerSpawn:
"""Tests for AgentLifecycleManager.spawn."""
@patch("app.services.memory.integration.lifecycle.WorkingMemory")
async def test_spawn_creates_working_memory(
self,
mock_working_cls: MagicMock,
lifecycle_manager: AgentLifecycleManager,
) -> None:
"""Spawn should create working memory for session."""
mock_working = AsyncMock()
mock_working.set = AsyncMock()
mock_working_cls.for_session = AsyncMock(return_value=mock_working)
result = await lifecycle_manager.spawn(
project_id=uuid4(),
agent_instance_id=uuid4(),
session_id="sess-123",
)
assert result.success is True
assert result.event_type == "spawn"
mock_working_cls.for_session.assert_called_once()
@patch("app.services.memory.integration.lifecycle.WorkingMemory")
async def test_spawn_with_initial_state(
self,
mock_working_cls: MagicMock,
lifecycle_manager: AgentLifecycleManager,
) -> None:
"""Spawn should populate initial state."""
mock_working = AsyncMock()
mock_working.set = AsyncMock()
mock_working_cls.for_session = AsyncMock(return_value=mock_working)
result = await lifecycle_manager.spawn(
project_id=uuid4(),
agent_instance_id=uuid4(),
session_id="sess-123",
initial_state={"key1": "value1", "key2": "value2"},
)
assert result.success is True
assert result.data["initial_items"] == 2
assert mock_working.set.call_count == 2
@patch("app.services.memory.integration.lifecycle.WorkingMemory")
async def test_spawn_runs_hooks(
self,
mock_working_cls: MagicMock,
lifecycle_manager: AgentLifecycleManager,
) -> None:
"""Spawn should run registered hooks."""
mock_working = AsyncMock()
mock_working.set = AsyncMock()
mock_working_cls.for_session = AsyncMock(return_value=mock_working)
hook = AsyncMock()
lifecycle_manager.hooks.on_spawn(hook)
await lifecycle_manager.spawn(
project_id=uuid4(),
agent_instance_id=uuid4(),
session_id="sess-123",
)
hook.assert_called_once()
class TestAgentLifecycleManagerPause:
"""Tests for AgentLifecycleManager.pause."""
@patch("app.services.memory.integration.lifecycle.WorkingMemory")
async def test_pause_creates_checkpoint(
self,
mock_working_cls: MagicMock,
lifecycle_manager: AgentLifecycleManager,
) -> None:
"""Pause should create checkpoint of working memory."""
mock_working = AsyncMock()
mock_working.list_keys = AsyncMock(return_value=["key1", "key2"])
mock_working.get = AsyncMock(return_value={"data": "test"})
mock_working.set = AsyncMock()
mock_working_cls.for_session = AsyncMock(return_value=mock_working)
result = await lifecycle_manager.pause(
project_id=uuid4(),
agent_instance_id=uuid4(),
session_id="sess-123",
checkpoint_id="ckpt-001",
)
assert result.success is True
assert result.event_type == "pause"
assert result.data["checkpoint_id"] == "ckpt-001"
assert result.data["items_saved"] == 2
# Should save checkpoint with state
mock_working.set.assert_called_once()
call_args = mock_working.set.call_args
# Check positional arg (first arg is key)
assert "__checkpoint__ckpt-001" in call_args[0][0]
@patch("app.services.memory.integration.lifecycle.WorkingMemory")
async def test_pause_generates_checkpoint_id(
self,
mock_working_cls: MagicMock,
lifecycle_manager: AgentLifecycleManager,
) -> None:
"""Pause should generate checkpoint ID if not provided."""
mock_working = AsyncMock()
mock_working.list_keys = AsyncMock(return_value=[])
mock_working.set = AsyncMock()
mock_working_cls.for_session = AsyncMock(return_value=mock_working)
result = await lifecycle_manager.pause(
project_id=uuid4(),
agent_instance_id=uuid4(),
session_id="sess-123",
)
assert result.success is True
assert "checkpoint_id" in result.data
assert result.data["checkpoint_id"].startswith("checkpoint_")
class TestAgentLifecycleManagerResume:
"""Tests for AgentLifecycleManager.resume."""
@patch("app.services.memory.integration.lifecycle.WorkingMemory")
async def test_resume_restores_checkpoint(
self,
mock_working_cls: MagicMock,
lifecycle_manager: AgentLifecycleManager,
) -> None:
"""Resume should restore working memory from checkpoint."""
checkpoint_data = {
"state": {"key1": "value1", "key2": "value2"},
"timestamp": datetime.now(UTC).isoformat(),
"keys_count": 2,
}
mock_working = AsyncMock()
mock_working.list_keys = AsyncMock(return_value=[])
mock_working.get = AsyncMock(return_value=checkpoint_data)
mock_working.set = AsyncMock()
mock_working.delete = AsyncMock()
mock_working_cls.for_session = AsyncMock(return_value=mock_working)
result = await lifecycle_manager.resume(
project_id=uuid4(),
agent_instance_id=uuid4(),
session_id="sess-123",
checkpoint_id="ckpt-001",
)
assert result.success is True
assert result.event_type == "resume"
assert result.data["items_restored"] == 2
@patch("app.services.memory.integration.lifecycle.WorkingMemory")
async def test_resume_checkpoint_not_found(
self,
mock_working_cls: MagicMock,
lifecycle_manager: AgentLifecycleManager,
) -> None:
"""Resume should fail if checkpoint not found."""
mock_working = AsyncMock()
mock_working.get = AsyncMock(return_value=None)
mock_working_cls.for_session = AsyncMock(return_value=mock_working)
result = await lifecycle_manager.resume(
project_id=uuid4(),
agent_instance_id=uuid4(),
session_id="sess-123",
checkpoint_id="nonexistent",
)
assert result.success is False
assert "not found" in result.message.lower()
class TestAgentLifecycleManagerTerminate:
"""Tests for AgentLifecycleManager.terminate."""
@patch("app.services.memory.integration.lifecycle.EpisodicMemory")
@patch("app.services.memory.integration.lifecycle.WorkingMemory")
async def test_terminate_consolidates_to_episodic(
self,
mock_working_cls: MagicMock,
mock_episodic_cls: MagicMock,
lifecycle_manager: AgentLifecycleManager,
) -> None:
"""Terminate should consolidate working memory to episodic."""
mock_working = AsyncMock()
mock_working.list_keys = AsyncMock(return_value=["key1", "key2"])
mock_working.get = AsyncMock(return_value="value")
mock_working.delete = AsyncMock()
mock_working_cls.for_session = AsyncMock(return_value=mock_working)
mock_episode = MagicMock()
mock_episode.id = uuid4()
mock_episodic = AsyncMock()
mock_episodic.record_episode = AsyncMock(return_value=mock_episode)
mock_episodic_cls.create = AsyncMock(return_value=mock_episodic)
result = await lifecycle_manager.terminate(
project_id=uuid4(),
agent_instance_id=uuid4(),
session_id="sess-123",
task_description="Completed task",
outcome=Outcome.SUCCESS,
)
assert result.success is True
assert result.event_type == "terminate"
assert result.data["episode_id"] == str(mock_episode.id)
assert result.data["state_items_consolidated"] == 2
mock_episodic.record_episode.assert_called_once()
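    # Termination folds the remaining working-memory state into a single
    # episodic record via record_episode, tagged with the supplied outcome;
    # the new episode id is surfaced through result.data.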
@patch("app.services.memory.integration.lifecycle.WorkingMemory")
async def test_terminate_cleans_up_working(
self,
mock_working_cls: MagicMock,
lifecycle_manager: AgentLifecycleManager,
) -> None:
"""Terminate should clean up working memory."""
mock_working = AsyncMock()
mock_working.list_keys = AsyncMock(return_value=["key1", "key2"])
mock_working.get = AsyncMock(return_value="value")
mock_working.delete = AsyncMock()
mock_working_cls.for_session = AsyncMock(return_value=mock_working)
result = await lifecycle_manager.terminate(
project_id=uuid4(),
agent_instance_id=uuid4(),
session_id="sess-123",
consolidate_to_episodic=False,
cleanup_working=True,
)
assert result.success is True
assert result.data["items_cleared"] == 2
assert mock_working.delete.call_count == 2
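# The two terminate flags exercised above are independent: consolidation is
# skipped here while cleanup_working deletes each remaining key individually
# (delete is called once per key).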
class TestAgentLifecycleManagerListCheckpoints:
"""Tests for AgentLifecycleManager.list_checkpoints."""
@patch("app.services.memory.integration.lifecycle.WorkingMemory")
async def test_list_checkpoints(
self,
mock_working_cls: MagicMock,
lifecycle_manager: AgentLifecycleManager,
) -> None:
"""Should list available checkpoints."""
mock_working = AsyncMock()
mock_working.list_keys = AsyncMock(
return_value=[
"__checkpoint__ckpt-001",
"__checkpoint__ckpt-002",
"regular_key",
]
)
mock_working.get = AsyncMock(
return_value={
"timestamp": "2024-01-01T00:00:00Z",
"keys_count": 5,
}
)
mock_working_cls.for_session = AsyncMock(return_value=mock_working)
checkpoints = await lifecycle_manager.list_checkpoints(
project_id=uuid4(),
agent_instance_id=uuid4(),
session_id="sess-123",
)
assert len(checkpoints) == 2
assert checkpoints[0]["checkpoint_id"] == "ckpt-001"
assert checkpoints[0]["keys_count"] == 5
class TestGetLifecycleManager:
"""Tests for factory function."""
async def test_creates_instance(self) -> None:
"""Factory should create AgentLifecycleManager instance."""
mock_session = MagicMock()
manager = await get_lifecycle_manager(mock_session)
assert isinstance(manager, AgentLifecycleManager)
async def test_with_custom_hooks(self) -> None:
"""Factory should accept custom hooks."""
mock_session = MagicMock()
custom_hooks = LifecycleHooks()
manager = await get_lifecycle_manager(mock_session, hooks=custom_hooks)
assert manager.hooks is custom_hooks
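# End-to-end lifecycle sketch (pid/aid/session are hypothetical; signatures
# follow the calls exercised in the tests above):
#
#     manager = await get_lifecycle_manager(session)
#     await manager.spawn(
#         project_id=pid, agent_instance_id=aid, session_id="sess-1",
#         initial_state={"current_task": "triage inbox"},
#     )
#     paused = await manager.pause(
#         project_id=pid, agent_instance_id=aid, session_id="sess-1"
#     )
#     await manager.resume(
#         project_id=pid, agent_instance_id=aid, session_id="sess-1",
#         checkpoint_id=paused.data["checkpoint_id"],
#     )
#     await manager.terminate(
#         project_id=pid, agent_instance_id=aid, session_id="sess-1",
#         task_description="Inbox triaged", outcome=Outcome.SUCCESS,
#     )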