forked from cardosofelipe/fast-next-template
feat(memory): integrate memory system with context engine (#97)
## Changes ### New Context Type - Add MEMORY to ContextType enum for agent memory context - Create MemoryContext class with subtypes (working, episodic, semantic, procedural) - Factory methods: from_working_memory, from_episodic_memory, from_semantic_memory, from_procedural_memory ### Memory Context Source - MemoryContextSource service fetches relevant memories for context assembly - Configurable fetch limits per memory type - Parallel fetching from all memory types ### Agent Lifecycle Hooks - AgentLifecycleManager handles spawn, pause, resume, terminate events - spawn: Initialize working memory with optional initial state - pause: Create checkpoint of working memory - resume: Restore from checkpoint - terminate: Consolidate working memory to episodic memory - LifecycleHooks for custom extension points ### Context Engine Integration - Add memory_query parameter to assemble_context() - Add session_id and agent_type_id for memory scoping - Memory budget allocation (15% by default) - set_memory_source() for runtime configuration ### Tests - 48 new tests for MemoryContext, MemoryContextSource, and lifecycle hooks - All 108 memory-related tests passing - mypy and ruff checks passing 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
@@ -114,6 +114,8 @@ from .types import (
|
||||
ContextType,
|
||||
ConversationContext,
|
||||
KnowledgeContext,
|
||||
MemoryContext,
|
||||
MemorySubtype,
|
||||
MessageRole,
|
||||
SystemContext,
|
||||
TaskComplexity,
|
||||
@@ -149,6 +151,8 @@ __all__ = [
|
||||
"FormattingError",
|
||||
"InvalidContextError",
|
||||
"KnowledgeContext",
|
||||
"MemoryContext",
|
||||
"MemorySubtype",
|
||||
"MessageRole",
|
||||
"ModelAdapter",
|
||||
"OpenAIAdapter",
|
||||
|
||||
@@ -30,6 +30,7 @@ class TokenBudget:
|
||||
knowledge: int = 0
|
||||
conversation: int = 0
|
||||
tools: int = 0
|
||||
memory: int = 0 # Agent memory (working, episodic, semantic, procedural)
|
||||
response_reserve: int = 0
|
||||
buffer: int = 0
|
||||
|
||||
@@ -60,6 +61,7 @@ class TokenBudget:
|
||||
"knowledge": self.knowledge,
|
||||
"conversation": self.conversation,
|
||||
"tool": self.tools,
|
||||
"memory": self.memory,
|
||||
}
|
||||
return allocation_map.get(context_type, 0)
|
||||
|
||||
@@ -211,6 +213,7 @@ class TokenBudget:
|
||||
"knowledge": self.knowledge,
|
||||
"conversation": self.conversation,
|
||||
"tools": self.tools,
|
||||
"memory": self.memory,
|
||||
"response_reserve": self.response_reserve,
|
||||
"buffer": self.buffer,
|
||||
},
|
||||
@@ -264,9 +267,10 @@ class BudgetAllocator:
|
||||
total=total_tokens,
|
||||
system=int(total_tokens * alloc.get("system", 0.05)),
|
||||
task=int(total_tokens * alloc.get("task", 0.10)),
|
||||
knowledge=int(total_tokens * alloc.get("knowledge", 0.40)),
|
||||
conversation=int(total_tokens * alloc.get("conversation", 0.20)),
|
||||
knowledge=int(total_tokens * alloc.get("knowledge", 0.30)),
|
||||
conversation=int(total_tokens * alloc.get("conversation", 0.15)),
|
||||
tools=int(total_tokens * alloc.get("tools", 0.05)),
|
||||
memory=int(total_tokens * alloc.get("memory", 0.15)),
|
||||
response_reserve=int(total_tokens * alloc.get("response", 0.15)),
|
||||
buffer=int(total_tokens * alloc.get("buffer", 0.05)),
|
||||
)
|
||||
@@ -317,6 +321,8 @@ class BudgetAllocator:
|
||||
budget.conversation = max(0, budget.conversation + actual_adjustment)
|
||||
elif context_type == "tool":
|
||||
budget.tools = max(0, budget.tools + actual_adjustment)
|
||||
elif context_type == "memory":
|
||||
budget.memory = max(0, budget.memory + actual_adjustment)
|
||||
|
||||
return budget
|
||||
|
||||
@@ -338,7 +344,7 @@ class BudgetAllocator:
|
||||
Rebalanced budget
|
||||
"""
|
||||
if prioritize is None:
|
||||
prioritize = [ContextType.KNOWLEDGE, ContextType.TASK, ContextType.SYSTEM]
|
||||
prioritize = [ContextType.KNOWLEDGE, ContextType.MEMORY, ContextType.TASK, ContextType.SYSTEM]
|
||||
|
||||
# Calculate unused tokens per type
|
||||
unused: dict[str, int] = {}
|
||||
|
||||
@@ -7,6 +7,7 @@ Provides a high-level API for assembling optimized context for LLM requests.
|
||||
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Any
|
||||
from uuid import UUID
|
||||
|
||||
from .assembly import ContextPipeline
|
||||
from .budget import BudgetAllocator, TokenBudget, TokenCalculator
|
||||
@@ -20,6 +21,7 @@ from .types import (
|
||||
BaseContext,
|
||||
ConversationContext,
|
||||
KnowledgeContext,
|
||||
MemoryContext,
|
||||
MessageRole,
|
||||
SystemContext,
|
||||
TaskContext,
|
||||
@@ -30,6 +32,7 @@ if TYPE_CHECKING:
|
||||
from redis.asyncio import Redis
|
||||
|
||||
from app.services.mcp.client_manager import MCPClientManager
|
||||
from app.services.memory.integration import MemoryContextSource
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -64,6 +67,7 @@ class ContextEngine:
|
||||
mcp_manager: "MCPClientManager | None" = None,
|
||||
redis: "Redis | None" = None,
|
||||
settings: ContextSettings | None = None,
|
||||
memory_source: "MemoryContextSource | None" = None,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize the context engine.
|
||||
@@ -72,9 +76,11 @@ class ContextEngine:
|
||||
mcp_manager: MCP client manager for LLM Gateway/Knowledge Base
|
||||
redis: Redis connection for caching
|
||||
settings: Context settings
|
||||
memory_source: Optional memory context source for agent memory
|
||||
"""
|
||||
self._mcp = mcp_manager
|
||||
self._settings = settings or get_context_settings()
|
||||
self._memory_source = memory_source
|
||||
|
||||
# Initialize components
|
||||
self._calculator = TokenCalculator(mcp_manager=mcp_manager)
|
||||
@@ -115,6 +121,15 @@ class ContextEngine:
|
||||
"""
|
||||
self._cache.set_redis(redis)
|
||||
|
||||
def set_memory_source(self, memory_source: "MemoryContextSource") -> None:
|
||||
"""
|
||||
Set memory context source for agent memory integration.
|
||||
|
||||
Args:
|
||||
memory_source: Memory context source
|
||||
"""
|
||||
self._memory_source = memory_source
|
||||
|
||||
async def assemble_context(
|
||||
self,
|
||||
project_id: str,
|
||||
@@ -126,6 +141,10 @@ class ContextEngine:
|
||||
task_description: str | None = None,
|
||||
knowledge_query: str | None = None,
|
||||
knowledge_limit: int = 10,
|
||||
memory_query: str | None = None,
|
||||
memory_limit: int = 20,
|
||||
session_id: str | None = None,
|
||||
agent_type_id: str | None = None,
|
||||
conversation_history: list[dict[str, str]] | None = None,
|
||||
tool_results: list[dict[str, Any]] | None = None,
|
||||
custom_contexts: list[BaseContext] | None = None,
|
||||
@@ -151,6 +170,10 @@ class ContextEngine:
|
||||
task_description: Current task description
|
||||
knowledge_query: Query for knowledge base search
|
||||
knowledge_limit: Max number of knowledge results
|
||||
memory_query: Query for agent memory search
|
||||
memory_limit: Max number of memory results
|
||||
session_id: Session ID for working memory access
|
||||
agent_type_id: Agent type ID for procedural memory
|
||||
conversation_history: List of {"role": str, "content": str}
|
||||
tool_results: List of tool results to include
|
||||
custom_contexts: Additional custom contexts
|
||||
@@ -197,15 +220,27 @@ class ContextEngine:
|
||||
)
|
||||
contexts.extend(knowledge_contexts)
|
||||
|
||||
# 4. Conversation history
|
||||
# 4. Memory context from Agent Memory System
|
||||
if memory_query and self._memory_source:
|
||||
memory_contexts = await self._fetch_memory(
|
||||
project_id=project_id,
|
||||
agent_id=agent_id,
|
||||
query=memory_query,
|
||||
limit=memory_limit,
|
||||
session_id=session_id,
|
||||
agent_type_id=agent_type_id,
|
||||
)
|
||||
contexts.extend(memory_contexts)
|
||||
|
||||
# 5. Conversation history
|
||||
if conversation_history:
|
||||
contexts.extend(self._convert_conversation(conversation_history))
|
||||
|
||||
# 5. Tool results
|
||||
# 6. Tool results
|
||||
if tool_results:
|
||||
contexts.extend(self._convert_tool_results(tool_results))
|
||||
|
||||
# 6. Custom contexts
|
||||
# 7. Custom contexts
|
||||
if custom_contexts:
|
||||
contexts.extend(custom_contexts)
|
||||
|
||||
@@ -308,6 +343,65 @@ class ContextEngine:
|
||||
logger.warning(f"Failed to fetch knowledge: {e}")
|
||||
return []
|
||||
|
||||
async def _fetch_memory(
|
||||
self,
|
||||
project_id: str,
|
||||
agent_id: str,
|
||||
query: str,
|
||||
limit: int = 20,
|
||||
session_id: str | None = None,
|
||||
agent_type_id: str | None = None,
|
||||
) -> list[MemoryContext]:
|
||||
"""
|
||||
Fetch relevant memories from Agent Memory System.
|
||||
|
||||
Args:
|
||||
project_id: Project identifier
|
||||
agent_id: Agent identifier
|
||||
query: Search query
|
||||
limit: Maximum results
|
||||
session_id: Session ID for working memory
|
||||
agent_type_id: Agent type ID for procedural memory
|
||||
|
||||
Returns:
|
||||
List of MemoryContext instances
|
||||
"""
|
||||
if not self._memory_source:
|
||||
return []
|
||||
|
||||
try:
|
||||
# Import here to avoid circular imports
|
||||
|
||||
# Configure fetch limits
|
||||
from app.services.memory.integration.context_source import MemoryFetchConfig
|
||||
|
||||
config = MemoryFetchConfig(
|
||||
working_limit=min(limit // 4, 5),
|
||||
episodic_limit=min(limit // 2, 10),
|
||||
semantic_limit=min(limit // 2, 10),
|
||||
procedural_limit=min(limit // 4, 5),
|
||||
include_working=session_id is not None,
|
||||
)
|
||||
|
||||
result = await self._memory_source.fetch_context(
|
||||
query=query,
|
||||
project_id=UUID(project_id),
|
||||
agent_instance_id=UUID(agent_id) if agent_id else None,
|
||||
agent_type_id=UUID(agent_type_id) if agent_type_id else None,
|
||||
session_id=session_id,
|
||||
config=config,
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
f"Fetched {len(result.contexts)} memory contexts for query: {query}, "
|
||||
f"by_type: {result.by_type}"
|
||||
)
|
||||
return result.contexts[:limit]
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to fetch memory: {e}")
|
||||
return []
|
||||
|
||||
def _convert_conversation(
|
||||
self,
|
||||
history: list[dict[str, str]],
|
||||
@@ -466,6 +560,7 @@ def create_context_engine(
|
||||
mcp_manager: "MCPClientManager | None" = None,
|
||||
redis: "Redis | None" = None,
|
||||
settings: ContextSettings | None = None,
|
||||
memory_source: "MemoryContextSource | None" = None,
|
||||
) -> ContextEngine:
|
||||
"""
|
||||
Create a context engine instance.
|
||||
@@ -474,6 +569,7 @@ def create_context_engine(
|
||||
mcp_manager: MCP client manager
|
||||
redis: Redis connection
|
||||
settings: Context settings
|
||||
memory_source: Optional memory context source
|
||||
|
||||
Returns:
|
||||
Configured ContextEngine instance
|
||||
@@ -482,4 +578,5 @@ def create_context_engine(
|
||||
mcp_manager=mcp_manager,
|
||||
redis=redis,
|
||||
settings=settings,
|
||||
memory_source=memory_source,
|
||||
)
|
||||
|
||||
@@ -15,6 +15,10 @@ from .conversation import (
|
||||
MessageRole,
|
||||
)
|
||||
from .knowledge import KnowledgeContext
|
||||
from .memory import (
|
||||
MemoryContext,
|
||||
MemorySubtype,
|
||||
)
|
||||
from .system import SystemContext
|
||||
from .task import (
|
||||
TaskComplexity,
|
||||
@@ -33,6 +37,8 @@ __all__ = [
|
||||
"ContextType",
|
||||
"ConversationContext",
|
||||
"KnowledgeContext",
|
||||
"MemoryContext",
|
||||
"MemorySubtype",
|
||||
"MessageRole",
|
||||
"SystemContext",
|
||||
"TaskComplexity",
|
||||
|
||||
@@ -26,6 +26,7 @@ class ContextType(str, Enum):
|
||||
KNOWLEDGE = "knowledge"
|
||||
CONVERSATION = "conversation"
|
||||
TOOL = "tool"
|
||||
MEMORY = "memory" # Agent memory (working, episodic, semantic, procedural)
|
||||
|
||||
@classmethod
|
||||
def from_string(cls, value: str) -> "ContextType":
|
||||
|
||||
282
backend/app/services/context/types/memory.py
Normal file
282
backend/app/services/context/types/memory.py
Normal file
@@ -0,0 +1,282 @@
|
||||
"""
|
||||
Memory Context Type.
|
||||
|
||||
Represents agent memory as context for LLM requests.
|
||||
Includes working, episodic, semantic, and procedural memories.
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import UTC, datetime
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
from .base import BaseContext, ContextPriority, ContextType
|
||||
|
||||
|
||||
class MemorySubtype(str, Enum):
|
||||
"""Types of agent memory."""
|
||||
|
||||
WORKING = "working" # Session-scoped temporary data
|
||||
EPISODIC = "episodic" # Task history and outcomes
|
||||
SEMANTIC = "semantic" # Facts and knowledge
|
||||
PROCEDURAL = "procedural" # Learned procedures
|
||||
|
||||
|
||||
@dataclass(eq=False)
|
||||
class MemoryContext(BaseContext):
|
||||
"""
|
||||
Context from agent memory system.
|
||||
|
||||
Memory context represents data retrieved from the agent
|
||||
memory system, including:
|
||||
- Working memory: Current session state
|
||||
- Episodic memory: Past task experiences
|
||||
- Semantic memory: Learned facts and knowledge
|
||||
- Procedural memory: Known procedures and workflows
|
||||
|
||||
Each memory item includes relevance scoring from search.
|
||||
"""
|
||||
|
||||
# Memory-specific fields
|
||||
memory_subtype: MemorySubtype = field(default=MemorySubtype.EPISODIC)
|
||||
memory_id: str | None = field(default=None)
|
||||
relevance_score: float = field(default=0.0)
|
||||
importance: float = field(default=0.5)
|
||||
search_query: str = field(default="")
|
||||
|
||||
# Type-specific fields (populated based on memory_subtype)
|
||||
key: str | None = field(default=None) # For working memory
|
||||
task_type: str | None = field(default=None) # For episodic
|
||||
outcome: str | None = field(default=None) # For episodic
|
||||
subject: str | None = field(default=None) # For semantic
|
||||
predicate: str | None = field(default=None) # For semantic
|
||||
object_value: str | None = field(default=None) # For semantic
|
||||
trigger: str | None = field(default=None) # For procedural
|
||||
success_rate: float | None = field(default=None) # For procedural
|
||||
|
||||
def get_type(self) -> ContextType:
|
||||
"""Return MEMORY context type."""
|
||||
return ContextType.MEMORY
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
"""Convert to dictionary with memory-specific fields."""
|
||||
base = super().to_dict()
|
||||
base.update(
|
||||
{
|
||||
"memory_subtype": self.memory_subtype.value,
|
||||
"memory_id": self.memory_id,
|
||||
"relevance_score": self.relevance_score,
|
||||
"importance": self.importance,
|
||||
"search_query": self.search_query,
|
||||
"key": self.key,
|
||||
"task_type": self.task_type,
|
||||
"outcome": self.outcome,
|
||||
"subject": self.subject,
|
||||
"predicate": self.predicate,
|
||||
"object_value": self.object_value,
|
||||
"trigger": self.trigger,
|
||||
"success_rate": self.success_rate,
|
||||
}
|
||||
)
|
||||
return base
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: dict[str, Any]) -> "MemoryContext":
|
||||
"""Create MemoryContext from dictionary."""
|
||||
return cls(
|
||||
id=data.get("id", ""),
|
||||
content=data["content"],
|
||||
source=data["source"],
|
||||
timestamp=datetime.fromisoformat(data["timestamp"])
|
||||
if isinstance(data.get("timestamp"), str)
|
||||
else data.get("timestamp", datetime.now(UTC)),
|
||||
priority=data.get("priority", ContextPriority.NORMAL.value),
|
||||
metadata=data.get("metadata", {}),
|
||||
memory_subtype=MemorySubtype(data.get("memory_subtype", "episodic")),
|
||||
memory_id=data.get("memory_id"),
|
||||
relevance_score=data.get("relevance_score", 0.0),
|
||||
importance=data.get("importance", 0.5),
|
||||
search_query=data.get("search_query", ""),
|
||||
key=data.get("key"),
|
||||
task_type=data.get("task_type"),
|
||||
outcome=data.get("outcome"),
|
||||
subject=data.get("subject"),
|
||||
predicate=data.get("predicate"),
|
||||
object_value=data.get("object_value"),
|
||||
trigger=data.get("trigger"),
|
||||
success_rate=data.get("success_rate"),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_working_memory(
|
||||
cls,
|
||||
key: str,
|
||||
value: Any,
|
||||
source: str = "working_memory",
|
||||
query: str = "",
|
||||
) -> "MemoryContext":
|
||||
"""
|
||||
Create MemoryContext from working memory entry.
|
||||
|
||||
Args:
|
||||
key: Working memory key
|
||||
value: Value stored at key
|
||||
source: Source identifier
|
||||
query: Search query used
|
||||
|
||||
Returns:
|
||||
MemoryContext instance
|
||||
"""
|
||||
return cls(
|
||||
content=str(value),
|
||||
source=source,
|
||||
memory_subtype=MemorySubtype.WORKING,
|
||||
key=key,
|
||||
relevance_score=1.0, # Working memory is always relevant
|
||||
importance=0.8, # Higher importance for current session state
|
||||
search_query=query,
|
||||
priority=ContextPriority.HIGH.value,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_episodic_memory(
|
||||
cls,
|
||||
episode: Any,
|
||||
query: str = "",
|
||||
) -> "MemoryContext":
|
||||
"""
|
||||
Create MemoryContext from episodic memory episode.
|
||||
|
||||
Args:
|
||||
episode: Episode object from episodic memory
|
||||
query: Search query used
|
||||
|
||||
Returns:
|
||||
MemoryContext instance
|
||||
"""
|
||||
outcome_val = None
|
||||
if hasattr(episode, "outcome") and episode.outcome:
|
||||
outcome_val = (
|
||||
episode.outcome.value
|
||||
if hasattr(episode.outcome, "value")
|
||||
else str(episode.outcome)
|
||||
)
|
||||
|
||||
return cls(
|
||||
content=episode.task_description,
|
||||
source=f"episodic:{episode.id}",
|
||||
memory_subtype=MemorySubtype.EPISODIC,
|
||||
memory_id=str(episode.id),
|
||||
relevance_score=getattr(episode, "importance_score", 0.5),
|
||||
importance=getattr(episode, "importance_score", 0.5),
|
||||
search_query=query,
|
||||
task_type=getattr(episode, "task_type", None),
|
||||
outcome=outcome_val,
|
||||
metadata={
|
||||
"session_id": getattr(episode, "session_id", None),
|
||||
"occurred_at": episode.occurred_at.isoformat()
|
||||
if hasattr(episode, "occurred_at") and episode.occurred_at
|
||||
else None,
|
||||
"lessons_learned": getattr(episode, "lessons_learned", []),
|
||||
},
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_semantic_memory(
|
||||
cls,
|
||||
fact: Any,
|
||||
query: str = "",
|
||||
) -> "MemoryContext":
|
||||
"""
|
||||
Create MemoryContext from semantic memory fact.
|
||||
|
||||
Args:
|
||||
fact: Fact object from semantic memory
|
||||
query: Search query used
|
||||
|
||||
Returns:
|
||||
MemoryContext instance
|
||||
"""
|
||||
triple = f"{fact.subject} {fact.predicate} {fact.object}"
|
||||
return cls(
|
||||
content=triple,
|
||||
source=f"semantic:{fact.id}",
|
||||
memory_subtype=MemorySubtype.SEMANTIC,
|
||||
memory_id=str(fact.id),
|
||||
relevance_score=getattr(fact, "confidence", 0.5),
|
||||
importance=getattr(fact, "confidence", 0.5),
|
||||
search_query=query,
|
||||
subject=fact.subject,
|
||||
predicate=fact.predicate,
|
||||
object_value=fact.object,
|
||||
priority=ContextPriority.NORMAL.value,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_procedural_memory(
|
||||
cls,
|
||||
procedure: Any,
|
||||
query: str = "",
|
||||
) -> "MemoryContext":
|
||||
"""
|
||||
Create MemoryContext from procedural memory procedure.
|
||||
|
||||
Args:
|
||||
procedure: Procedure object from procedural memory
|
||||
query: Search query used
|
||||
|
||||
Returns:
|
||||
MemoryContext instance
|
||||
"""
|
||||
# Format steps as content
|
||||
steps = getattr(procedure, "steps", [])
|
||||
steps_content = "\n".join(
|
||||
f" {i + 1}. {step.get('action', step) if isinstance(step, dict) else step}"
|
||||
for i, step in enumerate(steps)
|
||||
)
|
||||
content = f"Procedure: {procedure.name}\nTrigger: {procedure.trigger_pattern}\nSteps:\n{steps_content}"
|
||||
|
||||
return cls(
|
||||
content=content,
|
||||
source=f"procedural:{procedure.id}",
|
||||
memory_subtype=MemorySubtype.PROCEDURAL,
|
||||
memory_id=str(procedure.id),
|
||||
relevance_score=getattr(procedure, "success_rate", 0.5),
|
||||
importance=0.7, # Procedures are moderately important
|
||||
search_query=query,
|
||||
trigger=procedure.trigger_pattern,
|
||||
success_rate=getattr(procedure, "success_rate", None),
|
||||
metadata={
|
||||
"steps_count": len(steps),
|
||||
"execution_count": getattr(procedure, "success_count", 0)
|
||||
+ getattr(procedure, "failure_count", 0),
|
||||
},
|
||||
)
|
||||
|
||||
def is_working_memory(self) -> bool:
|
||||
"""Check if this is working memory."""
|
||||
return self.memory_subtype == MemorySubtype.WORKING
|
||||
|
||||
def is_episodic_memory(self) -> bool:
|
||||
"""Check if this is episodic memory."""
|
||||
return self.memory_subtype == MemorySubtype.EPISODIC
|
||||
|
||||
def is_semantic_memory(self) -> bool:
|
||||
"""Check if this is semantic memory."""
|
||||
return self.memory_subtype == MemorySubtype.SEMANTIC
|
||||
|
||||
def is_procedural_memory(self) -> bool:
|
||||
"""Check if this is procedural memory."""
|
||||
return self.memory_subtype == MemorySubtype.PROCEDURAL
|
||||
|
||||
def get_formatted_source(self) -> str:
|
||||
"""
|
||||
Get a formatted source string for display.
|
||||
|
||||
Returns:
|
||||
Formatted source string
|
||||
"""
|
||||
parts = [f"[{self.memory_subtype.value}]", self.source]
|
||||
if self.memory_id:
|
||||
parts.append(f"({self.memory_id[:8]}...)")
|
||||
return " ".join(parts)
|
||||
19
backend/app/services/memory/integration/__init__.py
Normal file
19
backend/app/services/memory/integration/__init__.py
Normal file
@@ -0,0 +1,19 @@
|
||||
# app/services/memory/integration/__init__.py
|
||||
"""
|
||||
Memory Integration Module.
|
||||
|
||||
Provides integration between the agent memory system and other Syndarix components:
|
||||
- Context Engine: Memory as context source
|
||||
- Agent Lifecycle: Spawn, pause, resume, terminate hooks
|
||||
"""
|
||||
|
||||
from .context_source import MemoryContextSource, get_memory_context_source
|
||||
from .lifecycle import AgentLifecycleManager, LifecycleHooks, get_lifecycle_manager
|
||||
|
||||
__all__ = [
|
||||
"AgentLifecycleManager",
|
||||
"LifecycleHooks",
|
||||
"MemoryContextSource",
|
||||
"get_lifecycle_manager",
|
||||
"get_memory_context_source",
|
||||
]
|
||||
402
backend/app/services/memory/integration/context_source.py
Normal file
402
backend/app/services/memory/integration/context_source.py
Normal file
@@ -0,0 +1,402 @@
|
||||
# app/services/memory/integration/context_source.py
|
||||
"""
|
||||
Memory Context Source.
|
||||
|
||||
Provides agent memory as a context source for the Context Engine.
|
||||
Retrieves relevant memories based on query and converts them to MemoryContext objects.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.services.context.types.memory import MemoryContext
|
||||
from app.services.memory.episodic import EpisodicMemory
|
||||
from app.services.memory.procedural import ProceduralMemory
|
||||
from app.services.memory.semantic import SemanticMemory
|
||||
from app.services.memory.working import WorkingMemory
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class MemoryFetchConfig:
|
||||
"""Configuration for memory fetching."""
|
||||
|
||||
# Limits per memory type
|
||||
working_limit: int = 10
|
||||
episodic_limit: int = 10
|
||||
semantic_limit: int = 15
|
||||
procedural_limit: int = 5
|
||||
|
||||
# Time ranges
|
||||
episodic_days_back: int = 30
|
||||
min_relevance: float = 0.3
|
||||
|
||||
# Which memory types to include
|
||||
include_working: bool = True
|
||||
include_episodic: bool = True
|
||||
include_semantic: bool = True
|
||||
include_procedural: bool = True
|
||||
|
||||
|
||||
@dataclass
|
||||
class MemoryFetchResult:
|
||||
"""Result of memory fetch operation."""
|
||||
|
||||
contexts: list[MemoryContext]
|
||||
by_type: dict[str, int]
|
||||
fetch_time_ms: float
|
||||
query: str
|
||||
|
||||
|
||||
class MemoryContextSource:
|
||||
"""
|
||||
Source for memory context in the Context Engine.
|
||||
|
||||
This service retrieves relevant memories based on a query and
|
||||
converts them to MemoryContext objects for context assembly.
|
||||
It coordinates between all memory types (working, episodic,
|
||||
semantic, procedural) to provide a comprehensive memory context.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
session: AsyncSession,
|
||||
embedding_generator: Any | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize the memory context source.
|
||||
|
||||
Args:
|
||||
session: Database session
|
||||
embedding_generator: Optional embedding generator for semantic search
|
||||
"""
|
||||
self._session = session
|
||||
self._embedding_generator = embedding_generator
|
||||
|
||||
# Lazy-initialized memory services
|
||||
self._episodic: EpisodicMemory | None = None
|
||||
self._semantic: SemanticMemory | None = None
|
||||
self._procedural: ProceduralMemory | None = None
|
||||
|
||||
async def _get_episodic(self) -> EpisodicMemory:
|
||||
"""Get or create episodic memory service."""
|
||||
if self._episodic is None:
|
||||
self._episodic = await EpisodicMemory.create(
|
||||
self._session,
|
||||
self._embedding_generator,
|
||||
)
|
||||
return self._episodic
|
||||
|
||||
async def _get_semantic(self) -> SemanticMemory:
|
||||
"""Get or create semantic memory service."""
|
||||
if self._semantic is None:
|
||||
self._semantic = await SemanticMemory.create(
|
||||
self._session,
|
||||
self._embedding_generator,
|
||||
)
|
||||
return self._semantic
|
||||
|
||||
async def _get_procedural(self) -> ProceduralMemory:
|
||||
"""Get or create procedural memory service."""
|
||||
if self._procedural is None:
|
||||
self._procedural = await ProceduralMemory.create(
|
||||
self._session,
|
||||
self._embedding_generator,
|
||||
)
|
||||
return self._procedural
|
||||
|
||||
async def fetch_context(
|
||||
self,
|
||||
query: str,
|
||||
project_id: UUID,
|
||||
agent_instance_id: UUID | None = None,
|
||||
agent_type_id: UUID | None = None,
|
||||
session_id: str | None = None,
|
||||
config: MemoryFetchConfig | None = None,
|
||||
) -> MemoryFetchResult:
|
||||
"""
|
||||
Fetch relevant memories as context.
|
||||
|
||||
This is the main entry point for the Context Engine integration.
|
||||
It searches across all memory types and returns relevant memories
|
||||
as MemoryContext objects.
|
||||
|
||||
Args:
|
||||
query: Search query for finding relevant memories
|
||||
project_id: Project scope
|
||||
agent_instance_id: Optional agent instance scope
|
||||
agent_type_id: Optional agent type scope (for procedural)
|
||||
session_id: Optional session ID (for working memory)
|
||||
config: Optional fetch configuration
|
||||
|
||||
Returns:
|
||||
MemoryFetchResult with contexts and metadata
|
||||
"""
|
||||
config = config or MemoryFetchConfig()
|
||||
start_time = datetime.now(UTC)
|
||||
|
||||
contexts: list[MemoryContext] = []
|
||||
by_type: dict[str, int] = {
|
||||
"working": 0,
|
||||
"episodic": 0,
|
||||
"semantic": 0,
|
||||
"procedural": 0,
|
||||
}
|
||||
|
||||
# Fetch from working memory (session-scoped)
|
||||
if config.include_working and session_id:
|
||||
try:
|
||||
working_contexts = await self._fetch_working(
|
||||
query=query,
|
||||
session_id=session_id,
|
||||
project_id=project_id,
|
||||
agent_instance_id=agent_instance_id,
|
||||
limit=config.working_limit,
|
||||
)
|
||||
contexts.extend(working_contexts)
|
||||
by_type["working"] = len(working_contexts)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to fetch working memory: {e}")
|
||||
|
||||
# Fetch from episodic memory
|
||||
if config.include_episodic:
|
||||
try:
|
||||
episodic_contexts = await self._fetch_episodic(
|
||||
query=query,
|
||||
project_id=project_id,
|
||||
agent_instance_id=agent_instance_id,
|
||||
limit=config.episodic_limit,
|
||||
days_back=config.episodic_days_back,
|
||||
)
|
||||
contexts.extend(episodic_contexts)
|
||||
by_type["episodic"] = len(episodic_contexts)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to fetch episodic memory: {e}")
|
||||
|
||||
# Fetch from semantic memory
|
||||
if config.include_semantic:
|
||||
try:
|
||||
semantic_contexts = await self._fetch_semantic(
|
||||
query=query,
|
||||
project_id=project_id,
|
||||
limit=config.semantic_limit,
|
||||
min_relevance=config.min_relevance,
|
||||
)
|
||||
contexts.extend(semantic_contexts)
|
||||
by_type["semantic"] = len(semantic_contexts)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to fetch semantic memory: {e}")
|
||||
|
||||
# Fetch from procedural memory
|
||||
if config.include_procedural:
|
||||
try:
|
||||
procedural_contexts = await self._fetch_procedural(
|
||||
query=query,
|
||||
project_id=project_id,
|
||||
agent_type_id=agent_type_id,
|
||||
limit=config.procedural_limit,
|
||||
)
|
||||
contexts.extend(procedural_contexts)
|
||||
by_type["procedural"] = len(procedural_contexts)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to fetch procedural memory: {e}")
|
||||
|
||||
# Sort by relevance
|
||||
contexts.sort(key=lambda c: c.relevance_score, reverse=True)
|
||||
|
||||
fetch_time = (datetime.now(UTC) - start_time).total_seconds() * 1000
|
||||
|
||||
logger.debug(
|
||||
f"Fetched {len(contexts)} memory contexts for query '{query[:50]}...' "
|
||||
f"in {fetch_time:.1f}ms"
|
||||
)
|
||||
|
||||
return MemoryFetchResult(
|
||||
contexts=contexts,
|
||||
by_type=by_type,
|
||||
fetch_time_ms=fetch_time,
|
||||
query=query,
|
||||
)
|
||||
|
||||
async def _fetch_working(
|
||||
self,
|
||||
query: str,
|
||||
session_id: str,
|
||||
project_id: UUID,
|
||||
agent_instance_id: UUID | None,
|
||||
limit: int,
|
||||
) -> list[MemoryContext]:
|
||||
"""Fetch from working memory."""
|
||||
working = await WorkingMemory.for_session(
|
||||
session_id=session_id,
|
||||
project_id=str(project_id),
|
||||
agent_instance_id=str(agent_instance_id) if agent_instance_id else None,
|
||||
)
|
||||
|
||||
contexts: list[MemoryContext] = []
|
||||
all_keys = await working.list_keys()
|
||||
|
||||
# Filter keys by query (simple substring match)
|
||||
query_lower = query.lower()
|
||||
matched_keys = [k for k in all_keys if query_lower in k.lower()]
|
||||
|
||||
# If no query match, include all keys (working memory is always relevant)
|
||||
if not matched_keys and query:
|
||||
matched_keys = all_keys
|
||||
|
||||
for key in matched_keys[:limit]:
|
||||
value = await working.get(key)
|
||||
if value is not None:
|
||||
contexts.append(
|
||||
MemoryContext.from_working_memory(
|
||||
key=key,
|
||||
value=value,
|
||||
source=f"working:{session_id}",
|
||||
query=query,
|
||||
)
|
||||
)
|
||||
|
||||
return contexts
|
||||
|
||||
async def _fetch_episodic(
|
||||
self,
|
||||
query: str,
|
||||
project_id: UUID,
|
||||
agent_instance_id: UUID | None,
|
||||
limit: int,
|
||||
days_back: int,
|
||||
) -> list[MemoryContext]:
|
||||
"""Fetch from episodic memory."""
|
||||
episodic = await self._get_episodic()
|
||||
|
||||
# Search for similar episodes
|
||||
episodes = await episodic.search_similar(
|
||||
project_id=project_id,
|
||||
query=query,
|
||||
limit=limit,
|
||||
agent_instance_id=agent_instance_id,
|
||||
)
|
||||
|
||||
# Also get recent episodes if we didn't find enough
|
||||
if len(episodes) < limit // 2:
|
||||
since = datetime.now(UTC) - timedelta(days=days_back)
|
||||
recent = await episodic.get_recent(
|
||||
project_id=project_id,
|
||||
limit=limit,
|
||||
since=since,
|
||||
)
|
||||
# Deduplicate by ID
|
||||
existing_ids = {e.id for e in episodes}
|
||||
for ep in recent:
|
||||
if ep.id not in existing_ids:
|
||||
episodes.append(ep)
|
||||
if len(episodes) >= limit:
|
||||
break
|
||||
|
||||
return [
|
||||
MemoryContext.from_episodic_memory(ep, query=query)
|
||||
for ep in episodes[:limit]
|
||||
]
|
||||
|
||||
async def _fetch_semantic(
|
||||
self,
|
||||
query: str,
|
||||
project_id: UUID,
|
||||
limit: int,
|
||||
min_relevance: float,
|
||||
) -> list[MemoryContext]:
|
||||
"""Fetch from semantic memory."""
|
||||
semantic = await self._get_semantic()
|
||||
|
||||
facts = await semantic.search_facts(
|
||||
query=query,
|
||||
project_id=project_id,
|
||||
limit=limit,
|
||||
min_confidence=min_relevance,
|
||||
)
|
||||
|
||||
return [
|
||||
MemoryContext.from_semantic_memory(fact, query=query)
|
||||
for fact in facts
|
||||
]
|
||||
|
||||
async def _fetch_procedural(
|
||||
self,
|
||||
query: str,
|
||||
project_id: UUID,
|
||||
agent_type_id: UUID | None,
|
||||
limit: int,
|
||||
) -> list[MemoryContext]:
|
||||
"""Fetch from procedural memory."""
|
||||
procedural = await self._get_procedural()
|
||||
|
||||
procedures = await procedural.find_matching(
|
||||
context=query,
|
||||
project_id=project_id,
|
||||
agent_type_id=agent_type_id,
|
||||
limit=limit,
|
||||
)
|
||||
|
||||
return [
|
||||
MemoryContext.from_procedural_memory(proc, query=query)
|
||||
for proc in procedures
|
||||
]
|
||||
|
||||
async def fetch_all_working(
|
||||
self,
|
||||
session_id: str,
|
||||
project_id: UUID,
|
||||
agent_instance_id: UUID | None = None,
|
||||
) -> list[MemoryContext]:
|
||||
"""
|
||||
Fetch all working memory for a session.
|
||||
|
||||
Useful for including entire session state in context.
|
||||
|
||||
Args:
|
||||
session_id: Session ID
|
||||
project_id: Project scope
|
||||
agent_instance_id: Optional agent instance scope
|
||||
|
||||
Returns:
|
||||
List of MemoryContext for all working memory items
|
||||
"""
|
||||
working = await WorkingMemory.for_session(
|
||||
session_id=session_id,
|
||||
project_id=str(project_id),
|
||||
agent_instance_id=str(agent_instance_id) if agent_instance_id else None,
|
||||
)
|
||||
|
||||
contexts: list[MemoryContext] = []
|
||||
all_keys = await working.list_keys()
|
||||
|
||||
for key in all_keys:
|
||||
value = await working.get(key)
|
||||
if value is not None:
|
||||
contexts.append(
|
||||
MemoryContext.from_working_memory(
|
||||
key=key,
|
||||
value=value,
|
||||
source=f"working:{session_id}",
|
||||
)
|
||||
)
|
||||
|
||||
return contexts
|
||||
|
||||
|
||||
# Factory function
|
||||
async def get_memory_context_source(
|
||||
session: AsyncSession,
|
||||
embedding_generator: Any | None = None,
|
||||
) -> MemoryContextSource:
|
||||
"""Create a memory context source instance."""
|
||||
return MemoryContextSource(
|
||||
session=session,
|
||||
embedding_generator=embedding_generator,
|
||||
)
|
||||
629
backend/app/services/memory/integration/lifecycle.py
Normal file
629
backend/app/services/memory/integration/lifecycle.py
Normal file
@@ -0,0 +1,629 @@
|
||||
# app/services/memory/integration/lifecycle.py
|
||||
"""
|
||||
Agent Lifecycle Hooks for Memory System.
|
||||
|
||||
Provides memory management hooks for agent lifecycle events:
|
||||
- spawn: Initialize working memory for new agent instance
|
||||
- pause: Checkpoint working memory state
|
||||
- resume: Restore working memory from checkpoint
|
||||
- terminate: Consolidate session to episodic memory
|
||||
"""
|
||||
|
||||
import logging
|
||||
from collections.abc import Callable, Coroutine
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import UTC, datetime
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.services.memory.episodic import EpisodicMemory
|
||||
from app.services.memory.types import EpisodeCreate, Outcome
|
||||
from app.services.memory.working import WorkingMemory
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class LifecycleEvent:
|
||||
"""Event data for lifecycle hooks."""
|
||||
|
||||
event_type: str # spawn, pause, resume, terminate
|
||||
project_id: UUID
|
||||
agent_instance_id: UUID
|
||||
agent_type_id: UUID | None = None
|
||||
session_id: str | None = None
|
||||
timestamp: datetime = field(default_factory=lambda: datetime.now(UTC))
|
||||
metadata: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
|
||||
@dataclass
|
||||
class LifecycleResult:
|
||||
"""Result of a lifecycle operation."""
|
||||
|
||||
success: bool
|
||||
event_type: str
|
||||
message: str | None = None
|
||||
data: dict[str, Any] = field(default_factory=dict)
|
||||
duration_ms: float = 0.0
|
||||
|
||||
|
||||
# Type alias for lifecycle hooks
|
||||
LifecycleHook = Callable[[LifecycleEvent], Coroutine[Any, Any, None]]
|
||||
|
||||
|
||||
class LifecycleHooks:
|
||||
"""
|
||||
Collection of lifecycle hooks.
|
||||
|
||||
Allows registration of custom hooks for lifecycle events.
|
||||
Hooks are called after the core memory operations.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""Initialize lifecycle hooks."""
|
||||
self._spawn_hooks: list[LifecycleHook] = []
|
||||
self._pause_hooks: list[LifecycleHook] = []
|
||||
self._resume_hooks: list[LifecycleHook] = []
|
||||
self._terminate_hooks: list[LifecycleHook] = []
|
||||
|
||||
def on_spawn(self, hook: LifecycleHook) -> LifecycleHook:
|
||||
"""Register a spawn hook."""
|
||||
self._spawn_hooks.append(hook)
|
||||
return hook
|
||||
|
||||
def on_pause(self, hook: LifecycleHook) -> LifecycleHook:
|
||||
"""Register a pause hook."""
|
||||
self._pause_hooks.append(hook)
|
||||
return hook
|
||||
|
||||
def on_resume(self, hook: LifecycleHook) -> LifecycleHook:
|
||||
"""Register a resume hook."""
|
||||
self._resume_hooks.append(hook)
|
||||
return hook
|
||||
|
||||
def on_terminate(self, hook: LifecycleHook) -> LifecycleHook:
|
||||
"""Register a terminate hook."""
|
||||
self._terminate_hooks.append(hook)
|
||||
return hook
|
||||
|
||||
async def run_spawn_hooks(self, event: LifecycleEvent) -> None:
|
||||
"""Run all spawn hooks."""
|
||||
for hook in self._spawn_hooks:
|
||||
try:
|
||||
await hook(event)
|
||||
except Exception as e:
|
||||
logger.warning(f"Spawn hook failed: {e}")
|
||||
|
||||
async def run_pause_hooks(self, event: LifecycleEvent) -> None:
|
||||
"""Run all pause hooks."""
|
||||
for hook in self._pause_hooks:
|
||||
try:
|
||||
await hook(event)
|
||||
except Exception as e:
|
||||
logger.warning(f"Pause hook failed: {e}")
|
||||
|
||||
async def run_resume_hooks(self, event: LifecycleEvent) -> None:
|
||||
"""Run all resume hooks."""
|
||||
for hook in self._resume_hooks:
|
||||
try:
|
||||
await hook(event)
|
||||
except Exception as e:
|
||||
logger.warning(f"Resume hook failed: {e}")
|
||||
|
||||
async def run_terminate_hooks(self, event: LifecycleEvent) -> None:
|
||||
"""Run all terminate hooks."""
|
||||
for hook in self._terminate_hooks:
|
||||
try:
|
||||
await hook(event)
|
||||
except Exception as e:
|
||||
logger.warning(f"Terminate hook failed: {e}")
|
||||
|
||||
|
||||
class AgentLifecycleManager:
|
||||
"""
|
||||
Manager for agent lifecycle and memory integration.
|
||||
|
||||
Handles memory operations during agent lifecycle events:
|
||||
- spawn: Creates new working memory for the session
|
||||
- pause: Saves working memory state to checkpoint
|
||||
- resume: Restores working memory from checkpoint
|
||||
- terminate: Consolidates working memory to episodic memory
|
||||
"""
|
||||
|
||||
# Key prefix for checkpoint storage
|
||||
CHECKPOINT_PREFIX = "__checkpoint__"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
session: AsyncSession,
|
||||
embedding_generator: Any | None = None,
|
||||
hooks: LifecycleHooks | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize the lifecycle manager.
|
||||
|
||||
Args:
|
||||
session: Database session
|
||||
embedding_generator: Optional embedding generator
|
||||
hooks: Optional lifecycle hooks
|
||||
"""
|
||||
self._session = session
|
||||
self._embedding_generator = embedding_generator
|
||||
self._hooks = hooks or LifecycleHooks()
|
||||
|
||||
# Lazy-initialized services
|
||||
self._episodic: EpisodicMemory | None = None
|
||||
|
||||
async def _get_episodic(self) -> EpisodicMemory:
|
||||
"""Get or create episodic memory service."""
|
||||
if self._episodic is None:
|
||||
self._episodic = await EpisodicMemory.create(
|
||||
self._session,
|
||||
self._embedding_generator,
|
||||
)
|
||||
return self._episodic
|
||||
|
||||
@property
|
||||
def hooks(self) -> LifecycleHooks:
|
||||
"""Get the lifecycle hooks."""
|
||||
return self._hooks
|
||||
|
||||
async def spawn(
|
||||
self,
|
||||
project_id: UUID,
|
||||
agent_instance_id: UUID,
|
||||
session_id: str,
|
||||
agent_type_id: UUID | None = None,
|
||||
initial_state: dict[str, Any] | None = None,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
) -> LifecycleResult:
|
||||
"""
|
||||
Handle agent spawn - initialize working memory.
|
||||
|
||||
Creates a new working memory instance for the agent session
|
||||
and optionally populates it with initial state.
|
||||
|
||||
Args:
|
||||
project_id: Project scope
|
||||
agent_instance_id: Agent instance ID
|
||||
session_id: Session ID for working memory
|
||||
agent_type_id: Optional agent type ID
|
||||
initial_state: Optional initial state to populate
|
||||
metadata: Optional metadata for the event
|
||||
|
||||
Returns:
|
||||
LifecycleResult with spawn outcome
|
||||
"""
|
||||
start_time = datetime.now(UTC)
|
||||
|
||||
try:
|
||||
# Create working memory for the session
|
||||
working = await WorkingMemory.for_session(
|
||||
session_id=session_id,
|
||||
project_id=str(project_id),
|
||||
agent_instance_id=str(agent_instance_id),
|
||||
)
|
||||
|
||||
# Populate initial state if provided
|
||||
items_set = 0
|
||||
if initial_state:
|
||||
for key, value in initial_state.items():
|
||||
await working.set(key, value)
|
||||
items_set += 1
|
||||
|
||||
# Create and run event hooks
|
||||
event = LifecycleEvent(
|
||||
event_type="spawn",
|
||||
project_id=project_id,
|
||||
agent_instance_id=agent_instance_id,
|
||||
agent_type_id=agent_type_id,
|
||||
session_id=session_id,
|
||||
metadata=metadata or {},
|
||||
)
|
||||
await self._hooks.run_spawn_hooks(event)
|
||||
|
||||
duration_ms = (datetime.now(UTC) - start_time).total_seconds() * 1000
|
||||
|
||||
logger.info(
|
||||
f"Agent {agent_instance_id} spawned with session {session_id}, "
|
||||
f"initial state: {items_set} items"
|
||||
)
|
||||
|
||||
return LifecycleResult(
|
||||
success=True,
|
||||
event_type="spawn",
|
||||
message="Agent spawned successfully",
|
||||
data={
|
||||
"session_id": session_id,
|
||||
"initial_items": items_set,
|
||||
},
|
||||
duration_ms=duration_ms,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Spawn failed for agent {agent_instance_id}: {e}")
|
||||
return LifecycleResult(
|
||||
success=False,
|
||||
event_type="spawn",
|
||||
message=f"Spawn failed: {e}",
|
||||
)
|
||||
|
||||
async def pause(
|
||||
self,
|
||||
project_id: UUID,
|
||||
agent_instance_id: UUID,
|
||||
session_id: str,
|
||||
checkpoint_id: str | None = None,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
) -> LifecycleResult:
|
||||
"""
|
||||
Handle agent pause - checkpoint working memory.
|
||||
|
||||
Saves the current working memory state to a checkpoint
|
||||
that can be restored later with resume().
|
||||
|
||||
Args:
|
||||
project_id: Project scope
|
||||
agent_instance_id: Agent instance ID
|
||||
session_id: Session ID
|
||||
checkpoint_id: Optional checkpoint identifier
|
||||
metadata: Optional metadata for the event
|
||||
|
||||
Returns:
|
||||
LifecycleResult with checkpoint data
|
||||
"""
|
||||
start_time = datetime.now(UTC)
|
||||
checkpoint_id = checkpoint_id or f"checkpoint_{int(start_time.timestamp())}"
|
||||
|
||||
try:
|
||||
working = await WorkingMemory.for_session(
|
||||
session_id=session_id,
|
||||
project_id=str(project_id),
|
||||
agent_instance_id=str(agent_instance_id),
|
||||
)
|
||||
|
||||
# Get all current state
|
||||
all_keys = await working.list_keys()
|
||||
# Filter out checkpoint keys
|
||||
state_keys = [k for k in all_keys if not k.startswith(self.CHECKPOINT_PREFIX)]
|
||||
|
||||
state: dict[str, Any] = {}
|
||||
for key in state_keys:
|
||||
value = await working.get(key)
|
||||
if value is not None:
|
||||
state[key] = value
|
||||
|
||||
# Store checkpoint
|
||||
checkpoint_key = f"{self.CHECKPOINT_PREFIX}{checkpoint_id}"
|
||||
await working.set(
|
||||
checkpoint_key,
|
||||
{
|
||||
"state": state,
|
||||
"timestamp": start_time.isoformat(),
|
||||
"keys_count": len(state),
|
||||
},
|
||||
ttl_seconds=86400 * 7, # Keep checkpoint for 7 days
|
||||
)
|
||||
|
||||
# Run hooks
|
||||
event = LifecycleEvent(
|
||||
event_type="pause",
|
||||
project_id=project_id,
|
||||
agent_instance_id=agent_instance_id,
|
||||
session_id=session_id,
|
||||
metadata={**(metadata or {}), "checkpoint_id": checkpoint_id},
|
||||
)
|
||||
await self._hooks.run_pause_hooks(event)
|
||||
|
||||
duration_ms = (datetime.now(UTC) - start_time).total_seconds() * 1000
|
||||
|
||||
logger.info(
|
||||
f"Agent {agent_instance_id} paused, checkpoint {checkpoint_id} "
|
||||
f"saved with {len(state)} items"
|
||||
)
|
||||
|
||||
return LifecycleResult(
|
||||
success=True,
|
||||
event_type="pause",
|
||||
message="Agent paused successfully",
|
||||
data={
|
||||
"checkpoint_id": checkpoint_id,
|
||||
"items_saved": len(state),
|
||||
"timestamp": start_time.isoformat(),
|
||||
},
|
||||
duration_ms=duration_ms,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Pause failed for agent {agent_instance_id}: {e}")
|
||||
return LifecycleResult(
|
||||
success=False,
|
||||
event_type="pause",
|
||||
message=f"Pause failed: {e}",
|
||||
)
|
||||
|
||||
async def resume(
|
||||
self,
|
||||
project_id: UUID,
|
||||
agent_instance_id: UUID,
|
||||
session_id: str,
|
||||
checkpoint_id: str,
|
||||
clear_current: bool = True,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
) -> LifecycleResult:
|
||||
"""
|
||||
Handle agent resume - restore from checkpoint.
|
||||
|
||||
Restores working memory state from a previously saved checkpoint.
|
||||
|
||||
Args:
|
||||
project_id: Project scope
|
||||
agent_instance_id: Agent instance ID
|
||||
session_id: Session ID
|
||||
checkpoint_id: Checkpoint to restore from
|
||||
clear_current: Whether to clear current state before restoring
|
||||
metadata: Optional metadata for the event
|
||||
|
||||
Returns:
|
||||
LifecycleResult with restore outcome
|
||||
"""
|
||||
start_time = datetime.now(UTC)
|
||||
|
||||
try:
|
||||
working = await WorkingMemory.for_session(
|
||||
session_id=session_id,
|
||||
project_id=str(project_id),
|
||||
agent_instance_id=str(agent_instance_id),
|
||||
)
|
||||
|
||||
# Get checkpoint
|
||||
checkpoint_key = f"{self.CHECKPOINT_PREFIX}{checkpoint_id}"
|
||||
checkpoint = await working.get(checkpoint_key)
|
||||
|
||||
if checkpoint is None:
|
||||
return LifecycleResult(
|
||||
success=False,
|
||||
event_type="resume",
|
||||
message=f"Checkpoint '{checkpoint_id}' not found",
|
||||
)
|
||||
|
||||
# Clear current state if requested
|
||||
if clear_current:
|
||||
all_keys = await working.list_keys()
|
||||
for key in all_keys:
|
||||
if not key.startswith(self.CHECKPOINT_PREFIX):
|
||||
await working.delete(key)
|
||||
|
||||
# Restore state from checkpoint
|
||||
state = checkpoint.get("state", {})
|
||||
items_restored = 0
|
||||
for key, value in state.items():
|
||||
await working.set(key, value)
|
||||
items_restored += 1
|
||||
|
||||
# Run hooks
|
||||
event = LifecycleEvent(
|
||||
event_type="resume",
|
||||
project_id=project_id,
|
||||
agent_instance_id=agent_instance_id,
|
||||
session_id=session_id,
|
||||
metadata={**(metadata or {}), "checkpoint_id": checkpoint_id},
|
||||
)
|
||||
await self._hooks.run_resume_hooks(event)
|
||||
|
||||
duration_ms = (datetime.now(UTC) - start_time).total_seconds() * 1000
|
||||
|
||||
logger.info(
|
||||
f"Agent {agent_instance_id} resumed from checkpoint {checkpoint_id}, "
|
||||
f"restored {items_restored} items"
|
||||
)
|
||||
|
||||
return LifecycleResult(
|
||||
success=True,
|
||||
event_type="resume",
|
||||
message="Agent resumed successfully",
|
||||
data={
|
||||
"checkpoint_id": checkpoint_id,
|
||||
"items_restored": items_restored,
|
||||
"checkpoint_timestamp": checkpoint.get("timestamp"),
|
||||
},
|
||||
duration_ms=duration_ms,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Resume failed for agent {agent_instance_id}: {e}")
|
||||
return LifecycleResult(
|
||||
success=False,
|
||||
event_type="resume",
|
||||
message=f"Resume failed: {e}",
|
||||
)
|
||||
|
||||
async def terminate(
|
||||
self,
|
||||
project_id: UUID,
|
||||
agent_instance_id: UUID,
|
||||
session_id: str,
|
||||
task_description: str | None = None,
|
||||
outcome: Outcome = Outcome.SUCCESS,
|
||||
lessons_learned: list[str] | None = None,
|
||||
consolidate_to_episodic: bool = True,
|
||||
cleanup_working: bool = True,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
) -> LifecycleResult:
|
||||
"""
|
||||
Handle agent termination - consolidate to episodic memory.
|
||||
|
||||
Consolidates the session's working memory into an episodic memory
|
||||
entry, then optionally cleans up the working memory.
|
||||
|
||||
Args:
|
||||
project_id: Project scope
|
||||
agent_instance_id: Agent instance ID
|
||||
session_id: Session ID
|
||||
task_description: Description of what was accomplished
|
||||
outcome: Task outcome (SUCCESS, FAILURE, PARTIAL)
|
||||
lessons_learned: Optional list of lessons learned
|
||||
consolidate_to_episodic: Whether to create episodic entry
|
||||
cleanup_working: Whether to clear working memory
|
||||
metadata: Optional metadata for the event
|
||||
|
||||
Returns:
|
||||
LifecycleResult with termination outcome
|
||||
"""
|
||||
start_time = datetime.now(UTC)
|
||||
|
||||
try:
|
||||
working = await WorkingMemory.for_session(
|
||||
session_id=session_id,
|
||||
project_id=str(project_id),
|
||||
agent_instance_id=str(agent_instance_id),
|
||||
)
|
||||
|
||||
# Gather session state for consolidation
|
||||
all_keys = await working.list_keys()
|
||||
state_keys = [k for k in all_keys if not k.startswith(self.CHECKPOINT_PREFIX)]
|
||||
|
||||
session_state: dict[str, Any] = {}
|
||||
for key in state_keys:
|
||||
value = await working.get(key)
|
||||
if value is not None:
|
||||
session_state[key] = value
|
||||
|
||||
episode_id: str | None = None
|
||||
|
||||
# Consolidate to episodic memory
|
||||
if consolidate_to_episodic:
|
||||
episodic = await self._get_episodic()
|
||||
|
||||
description = task_description or f"Session {session_id} completed"
|
||||
|
||||
episode_data = EpisodeCreate(
|
||||
project_id=project_id,
|
||||
agent_instance_id=agent_instance_id,
|
||||
session_id=session_id,
|
||||
task_type="session_completion",
|
||||
task_description=description[:500],
|
||||
outcome=outcome,
|
||||
outcome_details=f"Session terminated with {len(session_state)} state items",
|
||||
actions=[
|
||||
{
|
||||
"type": "session_terminate",
|
||||
"state_keys": list(session_state.keys()),
|
||||
"outcome": outcome.value,
|
||||
}
|
||||
],
|
||||
context_summary=str(session_state)[:1000] if session_state else "",
|
||||
lessons_learned=lessons_learned or [],
|
||||
duration_seconds=0.0, # Unknown at this point
|
||||
tokens_used=0,
|
||||
importance_score=0.6, # Moderate importance for session ends
|
||||
)
|
||||
|
||||
episode = await episodic.record_episode(episode_data)
|
||||
episode_id = str(episode.id)
|
||||
|
||||
# Clean up working memory
|
||||
items_cleared = 0
|
||||
if cleanup_working:
|
||||
for key in all_keys:
|
||||
await working.delete(key)
|
||||
items_cleared += 1
|
||||
|
||||
# Run hooks
|
||||
event = LifecycleEvent(
|
||||
event_type="terminate",
|
||||
project_id=project_id,
|
||||
agent_instance_id=agent_instance_id,
|
||||
session_id=session_id,
|
||||
metadata={**(metadata or {}), "episode_id": episode_id},
|
||||
)
|
||||
await self._hooks.run_terminate_hooks(event)
|
||||
|
||||
duration_ms = (datetime.now(UTC) - start_time).total_seconds() * 1000
|
||||
|
||||
logger.info(
|
||||
f"Agent {agent_instance_id} terminated, session {session_id} "
|
||||
f"consolidated to episode {episode_id}"
|
||||
)
|
||||
|
||||
return LifecycleResult(
|
||||
success=True,
|
||||
event_type="terminate",
|
||||
message="Agent terminated successfully",
|
||||
data={
|
||||
"episode_id": episode_id,
|
||||
"state_items_consolidated": len(session_state),
|
||||
"items_cleared": items_cleared,
|
||||
"outcome": outcome.value,
|
||||
},
|
||||
duration_ms=duration_ms,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Terminate failed for agent {agent_instance_id}: {e}")
|
||||
return LifecycleResult(
|
||||
success=False,
|
||||
event_type="terminate",
|
||||
message=f"Terminate failed: {e}",
|
||||
)
|
||||
|
||||
async def list_checkpoints(
|
||||
self,
|
||||
project_id: UUID,
|
||||
agent_instance_id: UUID,
|
||||
session_id: str,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""
|
||||
List available checkpoints for a session.
|
||||
|
||||
Args:
|
||||
project_id: Project scope
|
||||
agent_instance_id: Agent instance ID
|
||||
session_id: Session ID
|
||||
|
||||
Returns:
|
||||
List of checkpoint metadata dicts
|
||||
"""
|
||||
working = await WorkingMemory.for_session(
|
||||
session_id=session_id,
|
||||
project_id=str(project_id),
|
||||
agent_instance_id=str(agent_instance_id),
|
||||
)
|
||||
|
||||
all_keys = await working.list_keys()
|
||||
checkpoints: list[dict[str, Any]] = []
|
||||
|
||||
for key in all_keys:
|
||||
if key.startswith(self.CHECKPOINT_PREFIX):
|
||||
checkpoint_id = key[len(self.CHECKPOINT_PREFIX):]
|
||||
checkpoint = await working.get(key)
|
||||
if checkpoint:
|
||||
checkpoints.append({
|
||||
"checkpoint_id": checkpoint_id,
|
||||
"timestamp": checkpoint.get("timestamp"),
|
||||
"keys_count": checkpoint.get("keys_count", 0),
|
||||
})
|
||||
|
||||
# Sort by timestamp (newest first)
|
||||
checkpoints.sort(
|
||||
key=lambda c: c.get("timestamp", ""),
|
||||
reverse=True,
|
||||
)
|
||||
|
||||
return checkpoints
|
||||
|
||||
|
||||
# Factory function
|
||||
async def get_lifecycle_manager(
|
||||
session: AsyncSession,
|
||||
embedding_generator: Any | None = None,
|
||||
hooks: LifecycleHooks | None = None,
|
||||
) -> AgentLifecycleManager:
|
||||
"""Create a lifecycle manager instance."""
|
||||
return AgentLifecycleManager(
|
||||
session=session,
|
||||
embedding_generator=embedding_generator,
|
||||
hooks=hooks,
|
||||
)
|
||||
262
backend/tests/unit/services/context/types/test_memory.py
Normal file
262
backend/tests/unit/services/context/types/test_memory.py
Normal file
@@ -0,0 +1,262 @@
|
||||
# tests/unit/services/context/types/test_memory.py
|
||||
"""Tests for MemoryContext type."""
|
||||
|
||||
from datetime import UTC, datetime
|
||||
from unittest.mock import MagicMock
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
from app.services.context.types import ContextType
|
||||
from app.services.context.types.memory import MemoryContext, MemorySubtype
|
||||
|
||||
|
||||
class TestMemorySubtype:
|
||||
"""Tests for MemorySubtype enum."""
|
||||
|
||||
def test_all_types_defined(self) -> None:
|
||||
"""All memory subtypes should be defined."""
|
||||
assert MemorySubtype.WORKING == "working"
|
||||
assert MemorySubtype.EPISODIC == "episodic"
|
||||
assert MemorySubtype.SEMANTIC == "semantic"
|
||||
assert MemorySubtype.PROCEDURAL == "procedural"
|
||||
|
||||
def test_enum_values(self) -> None:
|
||||
"""Enum values should match strings."""
|
||||
assert MemorySubtype.WORKING.value == "working"
|
||||
assert MemorySubtype("episodic") == MemorySubtype.EPISODIC
|
||||
|
||||
|
||||
class TestMemoryContext:
|
||||
"""Tests for MemoryContext class."""
|
||||
|
||||
def test_get_type_returns_memory(self) -> None:
|
||||
"""get_type should return MEMORY."""
|
||||
ctx = MemoryContext(content="test", source="test_source")
|
||||
assert ctx.get_type() == ContextType.MEMORY
|
||||
|
||||
def test_default_values(self) -> None:
|
||||
"""Default values should be set correctly."""
|
||||
ctx = MemoryContext(content="test", source="test_source")
|
||||
assert ctx.memory_subtype == MemorySubtype.EPISODIC
|
||||
assert ctx.memory_id is None
|
||||
assert ctx.relevance_score == 0.0
|
||||
assert ctx.importance == 0.5
|
||||
|
||||
def test_to_dict_includes_memory_fields(self) -> None:
|
||||
"""to_dict should include memory-specific fields."""
|
||||
ctx = MemoryContext(
|
||||
content="test content",
|
||||
source="test_source",
|
||||
memory_subtype=MemorySubtype.SEMANTIC,
|
||||
memory_id="mem-123",
|
||||
relevance_score=0.8,
|
||||
subject="User",
|
||||
predicate="prefers",
|
||||
object_value="dark mode",
|
||||
)
|
||||
|
||||
data = ctx.to_dict()
|
||||
|
||||
assert data["memory_subtype"] == "semantic"
|
||||
assert data["memory_id"] == "mem-123"
|
||||
assert data["relevance_score"] == 0.8
|
||||
assert data["subject"] == "User"
|
||||
assert data["predicate"] == "prefers"
|
||||
assert data["object_value"] == "dark mode"
|
||||
|
||||
def test_from_dict(self) -> None:
|
||||
"""from_dict should create correct MemoryContext."""
|
||||
data = {
|
||||
"content": "test content",
|
||||
"source": "test_source",
|
||||
"timestamp": "2024-01-01T00:00:00+00:00",
|
||||
"memory_subtype": "semantic",
|
||||
"memory_id": "mem-123",
|
||||
"relevance_score": 0.8,
|
||||
"subject": "Test",
|
||||
}
|
||||
|
||||
ctx = MemoryContext.from_dict(data)
|
||||
|
||||
assert ctx.content == "test content"
|
||||
assert ctx.memory_subtype == MemorySubtype.SEMANTIC
|
||||
assert ctx.memory_id == "mem-123"
|
||||
assert ctx.subject == "Test"
|
||||
|
||||
|
||||
class TestMemoryContextFromWorkingMemory:
|
||||
"""Tests for MemoryContext.from_working_memory."""
|
||||
|
||||
def test_creates_working_memory_context(self) -> None:
|
||||
"""Should create working memory context from key/value."""
|
||||
ctx = MemoryContext.from_working_memory(
|
||||
key="user_preferences",
|
||||
value={"theme": "dark"},
|
||||
source="working:sess-123",
|
||||
query="preferences",
|
||||
)
|
||||
|
||||
assert ctx.memory_subtype == MemorySubtype.WORKING
|
||||
assert ctx.key == "user_preferences"
|
||||
assert "{'theme': 'dark'}" in ctx.content
|
||||
assert ctx.relevance_score == 1.0 # Working memory is always relevant
|
||||
assert ctx.importance == 0.8 # Higher importance
|
||||
|
||||
def test_string_value(self) -> None:
|
||||
"""Should handle string values."""
|
||||
ctx = MemoryContext.from_working_memory(
|
||||
key="current_task",
|
||||
value="Build authentication",
|
||||
)
|
||||
|
||||
assert ctx.content == "Build authentication"
|
||||
|
||||
|
||||
class TestMemoryContextFromEpisodicMemory:
|
||||
"""Tests for MemoryContext.from_episodic_memory."""
|
||||
|
||||
def test_creates_episodic_memory_context(self) -> None:
|
||||
"""Should create episodic memory context from episode."""
|
||||
episode = MagicMock()
|
||||
episode.id = uuid4()
|
||||
episode.task_description = "Implemented login feature"
|
||||
episode.task_type = "feature_implementation"
|
||||
episode.outcome = MagicMock(value="success")
|
||||
episode.importance_score = 0.9
|
||||
episode.session_id = "sess-123"
|
||||
episode.occurred_at = datetime.now(UTC)
|
||||
episode.lessons_learned = ["Use proper validation"]
|
||||
|
||||
ctx = MemoryContext.from_episodic_memory(episode, query="login")
|
||||
|
||||
assert ctx.memory_subtype == MemorySubtype.EPISODIC
|
||||
assert ctx.memory_id == str(episode.id)
|
||||
assert ctx.content == "Implemented login feature"
|
||||
assert ctx.task_type == "feature_implementation"
|
||||
assert ctx.outcome == "success"
|
||||
assert ctx.importance == 0.9
|
||||
|
||||
def test_handles_missing_outcome(self) -> None:
|
||||
"""Should handle episodes with no outcome."""
|
||||
episode = MagicMock()
|
||||
episode.id = uuid4()
|
||||
episode.task_description = "WIP task"
|
||||
episode.outcome = None
|
||||
episode.importance_score = 0.5
|
||||
episode.occurred_at = None
|
||||
|
||||
ctx = MemoryContext.from_episodic_memory(episode)
|
||||
|
||||
assert ctx.outcome is None
|
||||
|
||||
|
||||
class TestMemoryContextFromSemanticMemory:
|
||||
"""Tests for MemoryContext.from_semantic_memory."""
|
||||
|
||||
def test_creates_semantic_memory_context(self) -> None:
|
||||
"""Should create semantic memory context from fact."""
|
||||
fact = MagicMock()
|
||||
fact.id = uuid4()
|
||||
fact.subject = "User"
|
||||
fact.predicate = "prefers"
|
||||
fact.object = "dark mode"
|
||||
fact.confidence = 0.95
|
||||
|
||||
ctx = MemoryContext.from_semantic_memory(fact, query="user preferences")
|
||||
|
||||
assert ctx.memory_subtype == MemorySubtype.SEMANTIC
|
||||
assert ctx.memory_id == str(fact.id)
|
||||
assert ctx.content == "User prefers dark mode"
|
||||
assert ctx.subject == "User"
|
||||
assert ctx.predicate == "prefers"
|
||||
assert ctx.object_value == "dark mode"
|
||||
assert ctx.relevance_score == 0.95
|
||||
|
||||
|
||||
class TestMemoryContextFromProceduralMemory:
|
||||
"""Tests for MemoryContext.from_procedural_memory."""
|
||||
|
||||
def test_creates_procedural_memory_context(self) -> None:
|
||||
"""Should create procedural memory context from procedure."""
|
||||
procedure = MagicMock()
|
||||
procedure.id = uuid4()
|
||||
procedure.name = "Deploy to Production"
|
||||
procedure.trigger_pattern = "When deploying to production"
|
||||
procedure.steps = [
|
||||
{"action": "run_tests"},
|
||||
{"action": "build_docker"},
|
||||
{"action": "deploy"},
|
||||
]
|
||||
procedure.success_rate = 0.85
|
||||
procedure.success_count = 10
|
||||
procedure.failure_count = 2
|
||||
|
||||
ctx = MemoryContext.from_procedural_memory(procedure, query="deploy")
|
||||
|
||||
assert ctx.memory_subtype == MemorySubtype.PROCEDURAL
|
||||
assert ctx.memory_id == str(procedure.id)
|
||||
assert "Deploy to Production" in ctx.content
|
||||
assert "When deploying to production" in ctx.content
|
||||
assert ctx.trigger == "When deploying to production"
|
||||
assert ctx.success_rate == 0.85
|
||||
assert ctx.metadata["steps_count"] == 3
|
||||
assert ctx.metadata["execution_count"] == 12
|
||||
|
||||
|
||||
class TestMemoryContextHelpers:
|
||||
"""Tests for MemoryContext helper methods."""
|
||||
|
||||
def test_is_working_memory(self) -> None:
|
||||
"""is_working_memory should return True for working memory."""
|
||||
ctx = MemoryContext(
|
||||
content="test",
|
||||
source="test",
|
||||
memory_subtype=MemorySubtype.WORKING,
|
||||
)
|
||||
assert ctx.is_working_memory() is True
|
||||
assert ctx.is_episodic_memory() is False
|
||||
|
||||
def test_is_episodic_memory(self) -> None:
|
||||
"""is_episodic_memory should return True for episodic memory."""
|
||||
ctx = MemoryContext(
|
||||
content="test",
|
||||
source="test",
|
||||
memory_subtype=MemorySubtype.EPISODIC,
|
||||
)
|
||||
assert ctx.is_episodic_memory() is True
|
||||
assert ctx.is_semantic_memory() is False
|
||||
|
||||
def test_is_semantic_memory(self) -> None:
|
||||
"""is_semantic_memory should return True for semantic memory."""
|
||||
ctx = MemoryContext(
|
||||
content="test",
|
||||
source="test",
|
||||
memory_subtype=MemorySubtype.SEMANTIC,
|
||||
)
|
||||
assert ctx.is_semantic_memory() is True
|
||||
assert ctx.is_procedural_memory() is False
|
||||
|
||||
def test_is_procedural_memory(self) -> None:
|
||||
"""is_procedural_memory should return True for procedural memory."""
|
||||
ctx = MemoryContext(
|
||||
content="test",
|
||||
source="test",
|
||||
memory_subtype=MemorySubtype.PROCEDURAL,
|
||||
)
|
||||
assert ctx.is_procedural_memory() is True
|
||||
assert ctx.is_working_memory() is False
|
||||
|
||||
def test_get_formatted_source(self) -> None:
|
||||
"""get_formatted_source should return formatted string."""
|
||||
ctx = MemoryContext(
|
||||
content="test",
|
||||
source="episodic:12345678-1234-1234-1234-123456789012",
|
||||
memory_subtype=MemorySubtype.EPISODIC,
|
||||
memory_id="12345678-1234-1234-1234-123456789012",
|
||||
)
|
||||
|
||||
formatted = ctx.get_formatted_source()
|
||||
|
||||
assert "[episodic]" in formatted
|
||||
assert "12345678..." in formatted
|
||||
@@ -0,0 +1,2 @@
|
||||
# tests/unit/services/memory/integration/__init__.py
|
||||
"""Tests for memory integration module."""
|
||||
@@ -0,0 +1,322 @@
|
||||
# tests/unit/services/memory/integration/test_context_source.py
|
||||
"""Tests for MemoryContextSource service."""
|
||||
|
||||
from datetime import UTC, datetime
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
from app.services.context.types.memory import MemorySubtype
|
||||
from app.services.memory.integration.context_source import (
|
||||
MemoryContextSource,
|
||||
MemoryFetchConfig,
|
||||
MemoryFetchResult,
|
||||
get_memory_context_source,
|
||||
)
|
||||
|
||||
pytestmark = pytest.mark.asyncio(loop_scope="function")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_session() -> MagicMock:
|
||||
"""Create mock database session."""
|
||||
return MagicMock()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def context_source(mock_session: MagicMock) -> MemoryContextSource:
|
||||
"""Create MemoryContextSource instance."""
|
||||
return MemoryContextSource(session=mock_session)
|
||||
|
||||
|
||||
class TestMemoryFetchConfig:
|
||||
"""Tests for MemoryFetchConfig."""
|
||||
|
||||
def test_default_values(self) -> None:
|
||||
"""Default config values should be set correctly."""
|
||||
config = MemoryFetchConfig()
|
||||
|
||||
assert config.working_limit == 10
|
||||
assert config.episodic_limit == 10
|
||||
assert config.semantic_limit == 15
|
||||
assert config.procedural_limit == 5
|
||||
assert config.episodic_days_back == 30
|
||||
assert config.min_relevance == 0.3
|
||||
assert config.include_working is True
|
||||
assert config.include_episodic is True
|
||||
assert config.include_semantic is True
|
||||
assert config.include_procedural is True
|
||||
|
||||
def test_custom_values(self) -> None:
|
||||
"""Custom config values should be respected."""
|
||||
config = MemoryFetchConfig(
|
||||
working_limit=5,
|
||||
include_working=False,
|
||||
)
|
||||
|
||||
assert config.working_limit == 5
|
||||
assert config.include_working is False
|
||||
|
||||
|
||||
class TestMemoryFetchResult:
|
||||
"""Tests for MemoryFetchResult."""
|
||||
|
||||
def test_stores_results(self) -> None:
|
||||
"""Result should store contexts and metadata."""
|
||||
result = MemoryFetchResult(
|
||||
contexts=[],
|
||||
by_type={"working": 0, "episodic": 5, "semantic": 3, "procedural": 0},
|
||||
fetch_time_ms=15.5,
|
||||
query="test query",
|
||||
)
|
||||
|
||||
assert result.contexts == []
|
||||
assert result.by_type["episodic"] == 5
|
||||
assert result.fetch_time_ms == 15.5
|
||||
assert result.query == "test query"
|
||||
|
||||
|
||||
class TestMemoryContextSource:
|
||||
"""Tests for MemoryContextSource service."""
|
||||
|
||||
async def test_fetch_context_empty_when_no_sources(
|
||||
self,
|
||||
context_source: MemoryContextSource,
|
||||
) -> None:
|
||||
"""fetch_context should return empty when all sources fail."""
|
||||
config = MemoryFetchConfig(
|
||||
include_working=False,
|
||||
include_episodic=False,
|
||||
include_semantic=False,
|
||||
include_procedural=False,
|
||||
)
|
||||
|
||||
result = await context_source.fetch_context(
|
||||
query="test",
|
||||
project_id=uuid4(),
|
||||
config=config,
|
||||
)
|
||||
|
||||
assert len(result.contexts) == 0
|
||||
assert result.by_type == {
|
||||
"working": 0,
|
||||
"episodic": 0,
|
||||
"semantic": 0,
|
||||
"procedural": 0,
|
||||
}
|
||||
|
||||
@patch("app.services.memory.integration.context_source.WorkingMemory")
|
||||
async def test_fetch_working_memory(
|
||||
self,
|
||||
mock_working_cls: MagicMock,
|
||||
context_source: MemoryContextSource,
|
||||
) -> None:
|
||||
"""Should fetch working memory when session_id provided."""
|
||||
# Setup mock - both keys should match the query "task"
|
||||
mock_working = AsyncMock()
|
||||
mock_working.list_keys = AsyncMock(return_value=["current_task", "task_state"])
|
||||
mock_working.get = AsyncMock(side_effect=lambda k: {"key": k, "value": "test"})
|
||||
mock_working_cls.for_session = AsyncMock(return_value=mock_working)
|
||||
|
||||
config = MemoryFetchConfig(
|
||||
include_episodic=False,
|
||||
include_semantic=False,
|
||||
include_procedural=False,
|
||||
)
|
||||
|
||||
result = await context_source.fetch_context(
|
||||
query="task", # Both keys contain "task"
|
||||
project_id=uuid4(),
|
||||
session_id="sess-123",
|
||||
config=config,
|
||||
)
|
||||
|
||||
assert result.by_type["working"] == 2
|
||||
assert all(
|
||||
c.memory_subtype == MemorySubtype.WORKING for c in result.contexts
|
||||
)
|
||||
|
||||
@patch("app.services.memory.integration.context_source.EpisodicMemory")
|
||||
async def test_fetch_episodic_memory(
|
||||
self,
|
||||
mock_episodic_cls: MagicMock,
|
||||
context_source: MemoryContextSource,
|
||||
) -> None:
|
||||
"""Should fetch episodic memory."""
|
||||
# Setup mock episode
|
||||
mock_episode = MagicMock()
|
||||
mock_episode.id = uuid4()
|
||||
mock_episode.task_description = "Completed login feature"
|
||||
mock_episode.task_type = "feature"
|
||||
mock_episode.outcome = MagicMock(value="success")
|
||||
mock_episode.importance_score = 0.8
|
||||
mock_episode.occurred_at = datetime.now(UTC)
|
||||
mock_episode.lessons_learned = []
|
||||
|
||||
mock_episodic = AsyncMock()
|
||||
mock_episodic.search_similar = AsyncMock(return_value=[mock_episode])
|
||||
mock_episodic.get_recent = AsyncMock(return_value=[])
|
||||
mock_episodic_cls.create = AsyncMock(return_value=mock_episodic)
|
||||
|
||||
config = MemoryFetchConfig(
|
||||
include_working=False,
|
||||
include_semantic=False,
|
||||
include_procedural=False,
|
||||
)
|
||||
|
||||
result = await context_source.fetch_context(
|
||||
query="login",
|
||||
project_id=uuid4(),
|
||||
config=config,
|
||||
)
|
||||
|
||||
assert result.by_type["episodic"] == 1
|
||||
assert result.contexts[0].memory_subtype == MemorySubtype.EPISODIC
|
||||
assert "Completed login feature" in result.contexts[0].content
|
||||
|
||||
@patch("app.services.memory.integration.context_source.SemanticMemory")
|
||||
async def test_fetch_semantic_memory(
|
||||
self,
|
||||
mock_semantic_cls: MagicMock,
|
||||
context_source: MemoryContextSource,
|
||||
) -> None:
|
||||
"""Should fetch semantic memory."""
|
||||
# Setup mock fact
|
||||
mock_fact = MagicMock()
|
||||
mock_fact.id = uuid4()
|
||||
mock_fact.subject = "User"
|
||||
mock_fact.predicate = "prefers"
|
||||
mock_fact.object = "dark mode"
|
||||
mock_fact.confidence = 0.9
|
||||
|
||||
mock_semantic = AsyncMock()
|
||||
mock_semantic.search_facts = AsyncMock(return_value=[mock_fact])
|
||||
mock_semantic_cls.create = AsyncMock(return_value=mock_semantic)
|
||||
|
||||
config = MemoryFetchConfig(
|
||||
include_working=False,
|
||||
include_episodic=False,
|
||||
include_procedural=False,
|
||||
)
|
||||
|
||||
result = await context_source.fetch_context(
|
||||
query="preferences",
|
||||
project_id=uuid4(),
|
||||
config=config,
|
||||
)
|
||||
|
||||
assert result.by_type["semantic"] == 1
|
||||
assert result.contexts[0].memory_subtype == MemorySubtype.SEMANTIC
|
||||
assert "User prefers dark mode" in result.contexts[0].content
|
||||
|
||||
@patch("app.services.memory.integration.context_source.ProceduralMemory")
|
||||
async def test_fetch_procedural_memory(
|
||||
self,
|
||||
mock_procedural_cls: MagicMock,
|
||||
context_source: MemoryContextSource,
|
||||
) -> None:
|
||||
"""Should fetch procedural memory."""
|
||||
# Setup mock procedure
|
||||
mock_proc = MagicMock()
|
||||
mock_proc.id = uuid4()
|
||||
mock_proc.name = "Deploy"
|
||||
mock_proc.trigger_pattern = "When deploying"
|
||||
mock_proc.steps = [{"action": "build"}, {"action": "test"}]
|
||||
mock_proc.success_rate = 0.9
|
||||
mock_proc.success_count = 9
|
||||
mock_proc.failure_count = 1
|
||||
|
||||
mock_procedural = AsyncMock()
|
||||
mock_procedural.find_matching = AsyncMock(return_value=[mock_proc])
|
||||
mock_procedural_cls.create = AsyncMock(return_value=mock_procedural)
|
||||
|
||||
config = MemoryFetchConfig(
|
||||
include_working=False,
|
||||
include_episodic=False,
|
||||
include_semantic=False,
|
||||
)
|
||||
|
||||
result = await context_source.fetch_context(
|
||||
query="deploy",
|
||||
project_id=uuid4(),
|
||||
config=config,
|
||||
)
|
||||
|
||||
assert result.by_type["procedural"] == 1
|
||||
assert result.contexts[0].memory_subtype == MemorySubtype.PROCEDURAL
|
||||
assert "Deploy" in result.contexts[0].content
|
||||
|
||||
async def test_results_sorted_by_relevance(
|
||||
self,
|
||||
context_source: MemoryContextSource,
|
||||
) -> None:
|
||||
"""Results should be sorted by relevance score."""
|
||||
with patch.object(
|
||||
context_source, "_fetch_episodic"
|
||||
) as mock_ep, patch.object(
|
||||
context_source, "_fetch_semantic"
|
||||
) as mock_sem:
|
||||
# Create contexts with different relevance scores
|
||||
from app.services.context.types.memory import MemoryContext
|
||||
|
||||
ctx_low = MemoryContext(
|
||||
content="low relevance",
|
||||
source="test",
|
||||
relevance_score=0.3,
|
||||
)
|
||||
ctx_high = MemoryContext(
|
||||
content="high relevance",
|
||||
source="test",
|
||||
relevance_score=0.9,
|
||||
)
|
||||
|
||||
mock_ep.return_value = [ctx_low]
|
||||
mock_sem.return_value = [ctx_high]
|
||||
|
||||
config = MemoryFetchConfig(
|
||||
include_working=False,
|
||||
include_procedural=False,
|
||||
)
|
||||
|
||||
result = await context_source.fetch_context(
|
||||
query="test",
|
||||
project_id=uuid4(),
|
||||
config=config,
|
||||
)
|
||||
|
||||
# Higher relevance should come first
|
||||
assert result.contexts[0].relevance_score == 0.9
|
||||
assert result.contexts[1].relevance_score == 0.3
|
||||
|
||||
@patch("app.services.memory.integration.context_source.WorkingMemory")
|
||||
async def test_fetch_all_working(
|
||||
self,
|
||||
mock_working_cls: MagicMock,
|
||||
context_source: MemoryContextSource,
|
||||
) -> None:
|
||||
"""fetch_all_working should return all working memory items."""
|
||||
mock_working = AsyncMock()
|
||||
mock_working.list_keys = AsyncMock(return_value=["key1", "key2", "key3"])
|
||||
mock_working.get = AsyncMock(return_value="value")
|
||||
mock_working_cls.for_session = AsyncMock(return_value=mock_working)
|
||||
|
||||
contexts = await context_source.fetch_all_working(
|
||||
session_id="sess-123",
|
||||
project_id=uuid4(),
|
||||
)
|
||||
|
||||
assert len(contexts) == 3
|
||||
assert all(c.memory_subtype == MemorySubtype.WORKING for c in contexts)
|
||||
|
||||
|
||||
class TestGetMemoryContextSource:
|
||||
"""Tests for factory function."""
|
||||
|
||||
async def test_creates_instance(self) -> None:
|
||||
"""Factory should create MemoryContextSource instance."""
|
||||
mock_session = MagicMock()
|
||||
|
||||
source = await get_memory_context_source(mock_session)
|
||||
|
||||
assert isinstance(source, MemoryContextSource)
|
||||
471
backend/tests/unit/services/memory/integration/test_lifecycle.py
Normal file
471
backend/tests/unit/services/memory/integration/test_lifecycle.py
Normal file
@@ -0,0 +1,471 @@
|
||||
# tests/unit/services/memory/integration/test_lifecycle.py
|
||||
"""Tests for Agent Lifecycle Hooks."""
|
||||
|
||||
from datetime import UTC, datetime
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
from app.services.memory.integration.lifecycle import (
|
||||
AgentLifecycleManager,
|
||||
LifecycleEvent,
|
||||
LifecycleHooks,
|
||||
LifecycleResult,
|
||||
get_lifecycle_manager,
|
||||
)
|
||||
from app.services.memory.types import Outcome
|
||||
|
||||
pytestmark = pytest.mark.asyncio(loop_scope="function")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_session() -> MagicMock:
|
||||
"""Create mock database session."""
|
||||
return MagicMock()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def lifecycle_hooks() -> LifecycleHooks:
|
||||
"""Create lifecycle hooks instance."""
|
||||
return LifecycleHooks()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def lifecycle_manager(mock_session: MagicMock) -> AgentLifecycleManager:
|
||||
"""Create lifecycle manager instance."""
|
||||
return AgentLifecycleManager(session=mock_session)
|
||||
|
||||
|
||||
class TestLifecycleEvent:
|
||||
"""Tests for LifecycleEvent dataclass."""
|
||||
|
||||
def test_creates_event(self) -> None:
|
||||
"""Should create event with required fields."""
|
||||
project_id = uuid4()
|
||||
agent_id = uuid4()
|
||||
|
||||
event = LifecycleEvent(
|
||||
event_type="spawn",
|
||||
project_id=project_id,
|
||||
agent_instance_id=agent_id,
|
||||
)
|
||||
|
||||
assert event.event_type == "spawn"
|
||||
assert event.project_id == project_id
|
||||
assert event.agent_instance_id == agent_id
|
||||
assert event.timestamp is not None
|
||||
assert event.metadata == {}
|
||||
|
||||
def test_with_optional_fields(self) -> None:
|
||||
"""Should include optional fields."""
|
||||
event = LifecycleEvent(
|
||||
event_type="terminate",
|
||||
project_id=uuid4(),
|
||||
agent_instance_id=uuid4(),
|
||||
session_id="sess-123",
|
||||
metadata={"reason": "completed"},
|
||||
)
|
||||
|
||||
assert event.session_id == "sess-123"
|
||||
assert event.metadata["reason"] == "completed"
|
||||
|
||||
|
||||
class TestLifecycleResult:
|
||||
"""Tests for LifecycleResult dataclass."""
|
||||
|
||||
def test_success_result(self) -> None:
|
||||
"""Should create success result."""
|
||||
result = LifecycleResult(
|
||||
success=True,
|
||||
event_type="spawn",
|
||||
message="Agent spawned",
|
||||
data={"session_id": "sess-123"},
|
||||
duration_ms=10.5,
|
||||
)
|
||||
|
||||
assert result.success is True
|
||||
assert result.event_type == "spawn"
|
||||
assert result.data["session_id"] == "sess-123"
|
||||
|
||||
def test_failure_result(self) -> None:
|
||||
"""Should create failure result."""
|
||||
result = LifecycleResult(
|
||||
success=False,
|
||||
event_type="resume",
|
||||
message="Checkpoint not found",
|
||||
)
|
||||
|
||||
assert result.success is False
|
||||
assert result.message == "Checkpoint not found"
|
||||
|
||||
|
||||
class TestLifecycleHooks:
|
||||
"""Tests for LifecycleHooks class."""
|
||||
|
||||
def test_register_spawn_hook(self, lifecycle_hooks: LifecycleHooks) -> None:
|
||||
"""Should register spawn hook."""
|
||||
async def my_hook(event: LifecycleEvent) -> None:
|
||||
pass
|
||||
|
||||
result = lifecycle_hooks.on_spawn(my_hook)
|
||||
|
||||
assert result is my_hook
|
||||
assert my_hook in lifecycle_hooks._spawn_hooks
|
||||
|
||||
def test_register_all_hooks(self, lifecycle_hooks: LifecycleHooks) -> None:
|
||||
"""Should register hooks for all event types."""
|
||||
hooks = [
|
||||
lifecycle_hooks.on_spawn(AsyncMock()),
|
||||
lifecycle_hooks.on_pause(AsyncMock()),
|
||||
lifecycle_hooks.on_resume(AsyncMock()),
|
||||
lifecycle_hooks.on_terminate(AsyncMock()),
|
||||
]
|
||||
|
||||
assert len(lifecycle_hooks._spawn_hooks) == 1
|
||||
assert len(lifecycle_hooks._pause_hooks) == 1
|
||||
assert len(lifecycle_hooks._resume_hooks) == 1
|
||||
assert len(lifecycle_hooks._terminate_hooks) == 1
|
||||
|
||||
async def test_run_spawn_hooks(self, lifecycle_hooks: LifecycleHooks) -> None:
|
||||
"""Should run all spawn hooks."""
|
||||
hook1 = AsyncMock()
|
||||
hook2 = AsyncMock()
|
||||
lifecycle_hooks.on_spawn(hook1)
|
||||
lifecycle_hooks.on_spawn(hook2)
|
||||
|
||||
event = LifecycleEvent(
|
||||
event_type="spawn",
|
||||
project_id=uuid4(),
|
||||
agent_instance_id=uuid4(),
|
||||
)
|
||||
|
||||
await lifecycle_hooks.run_spawn_hooks(event)
|
||||
|
||||
hook1.assert_called_once_with(event)
|
||||
hook2.assert_called_once_with(event)
|
||||
|
||||
async def test_hook_failure_doesnt_stop_others(
|
||||
self, lifecycle_hooks: LifecycleHooks
|
||||
) -> None:
|
||||
"""Hook failure should not stop other hooks from running."""
|
||||
hook1 = AsyncMock(side_effect=ValueError("Oops"))
|
||||
hook2 = AsyncMock()
|
||||
lifecycle_hooks.on_pause(hook1)
|
||||
lifecycle_hooks.on_pause(hook2)
|
||||
|
||||
event = LifecycleEvent(
|
||||
event_type="pause",
|
||||
project_id=uuid4(),
|
||||
agent_instance_id=uuid4(),
|
||||
)
|
||||
|
||||
await lifecycle_hooks.run_pause_hooks(event)
|
||||
|
||||
# hook2 should still be called even though hook1 failed
|
||||
hook2.assert_called_once()
|
||||
|
||||
|
||||
class TestAgentLifecycleManagerSpawn:
|
||||
"""Tests for AgentLifecycleManager.spawn."""
|
||||
|
||||
@patch("app.services.memory.integration.lifecycle.WorkingMemory")
|
||||
async def test_spawn_creates_working_memory(
|
||||
self,
|
||||
mock_working_cls: MagicMock,
|
||||
lifecycle_manager: AgentLifecycleManager,
|
||||
) -> None:
|
||||
"""Spawn should create working memory for session."""
|
||||
mock_working = AsyncMock()
|
||||
mock_working.set = AsyncMock()
|
||||
mock_working_cls.for_session = AsyncMock(return_value=mock_working)
|
||||
|
||||
result = await lifecycle_manager.spawn(
|
||||
project_id=uuid4(),
|
||||
agent_instance_id=uuid4(),
|
||||
session_id="sess-123",
|
||||
)
|
||||
|
||||
assert result.success is True
|
||||
assert result.event_type == "spawn"
|
||||
mock_working_cls.for_session.assert_called_once()
|
||||
|
||||
@patch("app.services.memory.integration.lifecycle.WorkingMemory")
|
||||
async def test_spawn_with_initial_state(
|
||||
self,
|
||||
mock_working_cls: MagicMock,
|
||||
lifecycle_manager: AgentLifecycleManager,
|
||||
) -> None:
|
||||
"""Spawn should populate initial state."""
|
||||
mock_working = AsyncMock()
|
||||
mock_working.set = AsyncMock()
|
||||
mock_working_cls.for_session = AsyncMock(return_value=mock_working)
|
||||
|
||||
result = await lifecycle_manager.spawn(
|
||||
project_id=uuid4(),
|
||||
agent_instance_id=uuid4(),
|
||||
session_id="sess-123",
|
||||
initial_state={"key1": "value1", "key2": "value2"},
|
||||
)
|
||||
|
||||
assert result.success is True
|
||||
assert result.data["initial_items"] == 2
|
||||
assert mock_working.set.call_count == 2
|
||||
|
||||
@patch("app.services.memory.integration.lifecycle.WorkingMemory")
|
||||
async def test_spawn_runs_hooks(
|
||||
self,
|
||||
mock_working_cls: MagicMock,
|
||||
lifecycle_manager: AgentLifecycleManager,
|
||||
) -> None:
|
||||
"""Spawn should run registered hooks."""
|
||||
mock_working = AsyncMock()
|
||||
mock_working.set = AsyncMock()
|
||||
mock_working_cls.for_session = AsyncMock(return_value=mock_working)
|
||||
|
||||
hook = AsyncMock()
|
||||
lifecycle_manager.hooks.on_spawn(hook)
|
||||
|
||||
await lifecycle_manager.spawn(
|
||||
project_id=uuid4(),
|
||||
agent_instance_id=uuid4(),
|
||||
session_id="sess-123",
|
||||
)
|
||||
|
||||
hook.assert_called_once()
|
||||
|
||||
|
||||
class TestAgentLifecycleManagerPause:
|
||||
"""Tests for AgentLifecycleManager.pause."""
|
||||
|
||||
@patch("app.services.memory.integration.lifecycle.WorkingMemory")
|
||||
async def test_pause_creates_checkpoint(
|
||||
self,
|
||||
mock_working_cls: MagicMock,
|
||||
lifecycle_manager: AgentLifecycleManager,
|
||||
) -> None:
|
||||
"""Pause should create checkpoint of working memory."""
|
||||
mock_working = AsyncMock()
|
||||
mock_working.list_keys = AsyncMock(return_value=["key1", "key2"])
|
||||
mock_working.get = AsyncMock(return_value={"data": "test"})
|
||||
mock_working.set = AsyncMock()
|
||||
mock_working_cls.for_session = AsyncMock(return_value=mock_working)
|
||||
|
||||
result = await lifecycle_manager.pause(
|
||||
project_id=uuid4(),
|
||||
agent_instance_id=uuid4(),
|
||||
session_id="sess-123",
|
||||
checkpoint_id="ckpt-001",
|
||||
)
|
||||
|
||||
assert result.success is True
|
||||
assert result.event_type == "pause"
|
||||
assert result.data["checkpoint_id"] == "ckpt-001"
|
||||
assert result.data["items_saved"] == 2
|
||||
|
||||
# Should save checkpoint with state
|
||||
mock_working.set.assert_called_once()
|
||||
call_args = mock_working.set.call_args
|
||||
# Check positional arg (first arg is key)
|
||||
assert "__checkpoint__ckpt-001" in call_args[0][0]
|
||||
|
||||
@patch("app.services.memory.integration.lifecycle.WorkingMemory")
|
||||
async def test_pause_generates_checkpoint_id(
|
||||
self,
|
||||
mock_working_cls: MagicMock,
|
||||
lifecycle_manager: AgentLifecycleManager,
|
||||
) -> None:
|
||||
"""Pause should generate checkpoint ID if not provided."""
|
||||
mock_working = AsyncMock()
|
||||
mock_working.list_keys = AsyncMock(return_value=[])
|
||||
mock_working.set = AsyncMock()
|
||||
mock_working_cls.for_session = AsyncMock(return_value=mock_working)
|
||||
|
||||
result = await lifecycle_manager.pause(
|
||||
project_id=uuid4(),
|
||||
agent_instance_id=uuid4(),
|
||||
session_id="sess-123",
|
||||
)
|
||||
|
||||
assert result.success is True
|
||||
assert "checkpoint_id" in result.data
|
||||
assert result.data["checkpoint_id"].startswith("checkpoint_")
|
||||
|
||||
|
||||
class TestAgentLifecycleManagerResume:
|
||||
"""Tests for AgentLifecycleManager.resume."""
|
||||
|
||||
@patch("app.services.memory.integration.lifecycle.WorkingMemory")
|
||||
async def test_resume_restores_checkpoint(
|
||||
self,
|
||||
mock_working_cls: MagicMock,
|
||||
lifecycle_manager: AgentLifecycleManager,
|
||||
) -> None:
|
||||
"""Resume should restore working memory from checkpoint."""
|
||||
checkpoint_data = {
|
||||
"state": {"key1": "value1", "key2": "value2"},
|
||||
"timestamp": datetime.now(UTC).isoformat(),
|
||||
"keys_count": 2,
|
||||
}
|
||||
|
||||
mock_working = AsyncMock()
|
||||
mock_working.list_keys = AsyncMock(return_value=[])
|
||||
mock_working.get = AsyncMock(return_value=checkpoint_data)
|
||||
mock_working.set = AsyncMock()
|
||||
mock_working.delete = AsyncMock()
|
||||
mock_working_cls.for_session = AsyncMock(return_value=mock_working)
|
||||
|
||||
result = await lifecycle_manager.resume(
|
||||
project_id=uuid4(),
|
||||
agent_instance_id=uuid4(),
|
||||
session_id="sess-123",
|
||||
checkpoint_id="ckpt-001",
|
||||
)
|
||||
|
||||
assert result.success is True
|
||||
assert result.event_type == "resume"
|
||||
assert result.data["items_restored"] == 2
|
||||
|
||||
@patch("app.services.memory.integration.lifecycle.WorkingMemory")
|
||||
async def test_resume_checkpoint_not_found(
|
||||
self,
|
||||
mock_working_cls: MagicMock,
|
||||
lifecycle_manager: AgentLifecycleManager,
|
||||
) -> None:
|
||||
"""Resume should fail if checkpoint not found."""
|
||||
mock_working = AsyncMock()
|
||||
mock_working.get = AsyncMock(return_value=None)
|
||||
mock_working_cls.for_session = AsyncMock(return_value=mock_working)
|
||||
|
||||
result = await lifecycle_manager.resume(
|
||||
project_id=uuid4(),
|
||||
agent_instance_id=uuid4(),
|
||||
session_id="sess-123",
|
||||
checkpoint_id="nonexistent",
|
||||
)
|
||||
|
||||
assert result.success is False
|
||||
assert "not found" in result.message.lower()
|
||||
|
||||
|
||||
class TestAgentLifecycleManagerTerminate:
|
||||
"""Tests for AgentLifecycleManager.terminate."""
|
||||
|
||||
@patch("app.services.memory.integration.lifecycle.EpisodicMemory")
|
||||
@patch("app.services.memory.integration.lifecycle.WorkingMemory")
|
||||
async def test_terminate_consolidates_to_episodic(
|
||||
self,
|
||||
mock_working_cls: MagicMock,
|
||||
mock_episodic_cls: MagicMock,
|
||||
lifecycle_manager: AgentLifecycleManager,
|
||||
) -> None:
|
||||
"""Terminate should consolidate working memory to episodic."""
|
||||
mock_working = AsyncMock()
|
||||
mock_working.list_keys = AsyncMock(return_value=["key1", "key2"])
|
||||
mock_working.get = AsyncMock(return_value="value")
|
||||
mock_working.delete = AsyncMock()
|
||||
mock_working_cls.for_session = AsyncMock(return_value=mock_working)
|
||||
|
||||
mock_episode = MagicMock()
|
||||
mock_episode.id = uuid4()
|
||||
|
||||
mock_episodic = AsyncMock()
|
||||
mock_episodic.record_episode = AsyncMock(return_value=mock_episode)
|
||||
mock_episodic_cls.create = AsyncMock(return_value=mock_episodic)
|
||||
|
||||
result = await lifecycle_manager.terminate(
|
||||
project_id=uuid4(),
|
||||
agent_instance_id=uuid4(),
|
||||
session_id="sess-123",
|
||||
task_description="Completed task",
|
||||
outcome=Outcome.SUCCESS,
|
||||
)
|
||||
|
||||
assert result.success is True
|
||||
assert result.event_type == "terminate"
|
||||
assert result.data["episode_id"] == str(mock_episode.id)
|
||||
assert result.data["state_items_consolidated"] == 2
|
||||
mock_episodic.record_episode.assert_called_once()
|
||||
|
||||
@patch("app.services.memory.integration.lifecycle.WorkingMemory")
|
||||
async def test_terminate_cleans_up_working(
|
||||
self,
|
||||
mock_working_cls: MagicMock,
|
||||
lifecycle_manager: AgentLifecycleManager,
|
||||
) -> None:
|
||||
"""Terminate should clean up working memory."""
|
||||
mock_working = AsyncMock()
|
||||
mock_working.list_keys = AsyncMock(return_value=["key1", "key2"])
|
||||
mock_working.get = AsyncMock(return_value="value")
|
||||
mock_working.delete = AsyncMock()
|
||||
mock_working_cls.for_session = AsyncMock(return_value=mock_working)
|
||||
|
||||
result = await lifecycle_manager.terminate(
|
||||
project_id=uuid4(),
|
||||
agent_instance_id=uuid4(),
|
||||
session_id="sess-123",
|
||||
consolidate_to_episodic=False,
|
||||
cleanup_working=True,
|
||||
)
|
||||
|
||||
assert result.success is True
|
||||
assert result.data["items_cleared"] == 2
|
||||
assert mock_working.delete.call_count == 2
|
||||
|
||||
|
||||
class TestAgentLifecycleManagerListCheckpoints:
|
||||
"""Tests for AgentLifecycleManager.list_checkpoints."""
|
||||
|
||||
@patch("app.services.memory.integration.lifecycle.WorkingMemory")
|
||||
async def test_list_checkpoints(
|
||||
self,
|
||||
mock_working_cls: MagicMock,
|
||||
lifecycle_manager: AgentLifecycleManager,
|
||||
) -> None:
|
||||
"""Should list available checkpoints."""
|
||||
mock_working = AsyncMock()
|
||||
mock_working.list_keys = AsyncMock(
|
||||
return_value=[
|
||||
"__checkpoint__ckpt-001",
|
||||
"__checkpoint__ckpt-002",
|
||||
"regular_key",
|
||||
]
|
||||
)
|
||||
mock_working.get = AsyncMock(
|
||||
return_value={
|
||||
"timestamp": "2024-01-01T00:00:00Z",
|
||||
"keys_count": 5,
|
||||
}
|
||||
)
|
||||
mock_working_cls.for_session = AsyncMock(return_value=mock_working)
|
||||
|
||||
checkpoints = await lifecycle_manager.list_checkpoints(
|
||||
project_id=uuid4(),
|
||||
agent_instance_id=uuid4(),
|
||||
session_id="sess-123",
|
||||
)
|
||||
|
||||
assert len(checkpoints) == 2
|
||||
assert checkpoints[0]["checkpoint_id"] == "ckpt-001"
|
||||
assert checkpoints[0]["keys_count"] == 5
|
||||
|
||||
|
||||
class TestGetLifecycleManager:
|
||||
"""Tests for factory function."""
|
||||
|
||||
async def test_creates_instance(self) -> None:
|
||||
"""Factory should create AgentLifecycleManager instance."""
|
||||
mock_session = MagicMock()
|
||||
|
||||
manager = await get_lifecycle_manager(mock_session)
|
||||
|
||||
assert isinstance(manager, AgentLifecycleManager)
|
||||
|
||||
async def test_with_custom_hooks(self) -> None:
|
||||
"""Factory should accept custom hooks."""
|
||||
mock_session = MagicMock()
|
||||
custom_hooks = LifecycleHooks()
|
||||
|
||||
manager = await get_lifecycle_manager(mock_session, hooks=custom_hooks)
|
||||
|
||||
assert manager.hooks is custom_hooks
|
||||
Reference in New Issue
Block a user