diff --git a/backend/app/services/context/__init__.py b/backend/app/services/context/__init__.py index 9be69e5..093eae1 100644 --- a/backend/app/services/context/__init__.py +++ b/backend/app/services/context/__init__.py @@ -114,6 +114,8 @@ from .types import ( ContextType, ConversationContext, KnowledgeContext, + MemoryContext, + MemorySubtype, MessageRole, SystemContext, TaskComplexity, @@ -149,6 +151,8 @@ __all__ = [ "FormattingError", "InvalidContextError", "KnowledgeContext", + "MemoryContext", + "MemorySubtype", "MessageRole", "ModelAdapter", "OpenAIAdapter", diff --git a/backend/app/services/context/budget/allocator.py b/backend/app/services/context/budget/allocator.py index 6c9507a..ab4eaee 100644 --- a/backend/app/services/context/budget/allocator.py +++ b/backend/app/services/context/budget/allocator.py @@ -30,6 +30,7 @@ class TokenBudget: knowledge: int = 0 conversation: int = 0 tools: int = 0 + memory: int = 0 # Agent memory (working, episodic, semantic, procedural) response_reserve: int = 0 buffer: int = 0 @@ -60,6 +61,7 @@ class TokenBudget: "knowledge": self.knowledge, "conversation": self.conversation, "tool": self.tools, + "memory": self.memory, } return allocation_map.get(context_type, 0) @@ -211,6 +213,7 @@ class TokenBudget: "knowledge": self.knowledge, "conversation": self.conversation, "tools": self.tools, + "memory": self.memory, "response_reserve": self.response_reserve, "buffer": self.buffer, }, @@ -264,9 +267,10 @@ class BudgetAllocator: total=total_tokens, system=int(total_tokens * alloc.get("system", 0.05)), task=int(total_tokens * alloc.get("task", 0.10)), - knowledge=int(total_tokens * alloc.get("knowledge", 0.40)), - conversation=int(total_tokens * alloc.get("conversation", 0.20)), + knowledge=int(total_tokens * alloc.get("knowledge", 0.30)), + conversation=int(total_tokens * alloc.get("conversation", 0.15)), tools=int(total_tokens * alloc.get("tools", 0.05)), + memory=int(total_tokens * alloc.get("memory", 0.15)), response_reserve=int(total_tokens * alloc.get("response", 0.15)), buffer=int(total_tokens * alloc.get("buffer", 0.05)), ) @@ -317,6 +321,8 @@ class BudgetAllocator: budget.conversation = max(0, budget.conversation + actual_adjustment) elif context_type == "tool": budget.tools = max(0, budget.tools + actual_adjustment) + elif context_type == "memory": + budget.memory = max(0, budget.memory + actual_adjustment) return budget @@ -338,7 +344,7 @@ class BudgetAllocator: Rebalanced budget """ if prioritize is None: - prioritize = [ContextType.KNOWLEDGE, ContextType.TASK, ContextType.SYSTEM] + prioritize = [ContextType.KNOWLEDGE, ContextType.MEMORY, ContextType.TASK, ContextType.SYSTEM] # Calculate unused tokens per type unused: dict[str, int] = {} diff --git a/backend/app/services/context/engine.py b/backend/app/services/context/engine.py index 707d570..d900119 100644 --- a/backend/app/services/context/engine.py +++ b/backend/app/services/context/engine.py @@ -7,6 +7,7 @@ Provides a high-level API for assembling optimized context for LLM requests. 
import logging from typing import TYPE_CHECKING, Any +from uuid import UUID from .assembly import ContextPipeline from .budget import BudgetAllocator, TokenBudget, TokenCalculator @@ -20,6 +21,7 @@ from .types import ( BaseContext, ConversationContext, KnowledgeContext, + MemoryContext, MessageRole, SystemContext, TaskContext, @@ -30,6 +32,7 @@ if TYPE_CHECKING: from redis.asyncio import Redis from app.services.mcp.client_manager import MCPClientManager + from app.services.memory.integration import MemoryContextSource logger = logging.getLogger(__name__) @@ -64,6 +67,7 @@ class ContextEngine: mcp_manager: "MCPClientManager | None" = None, redis: "Redis | None" = None, settings: ContextSettings | None = None, + memory_source: "MemoryContextSource | None" = None, ) -> None: """ Initialize the context engine. @@ -72,9 +76,11 @@ class ContextEngine: mcp_manager: MCP client manager for LLM Gateway/Knowledge Base redis: Redis connection for caching settings: Context settings + memory_source: Optional memory context source for agent memory """ self._mcp = mcp_manager self._settings = settings or get_context_settings() + self._memory_source = memory_source # Initialize components self._calculator = TokenCalculator(mcp_manager=mcp_manager) @@ -115,6 +121,15 @@ class ContextEngine: """ self._cache.set_redis(redis) + def set_memory_source(self, memory_source: "MemoryContextSource") -> None: + """ + Set memory context source for agent memory integration. + + Args: + memory_source: Memory context source + """ + self._memory_source = memory_source + async def assemble_context( self, project_id: str, @@ -126,6 +141,10 @@ class ContextEngine: task_description: str | None = None, knowledge_query: str | None = None, knowledge_limit: int = 10, + memory_query: str | None = None, + memory_limit: int = 20, + session_id: str | None = None, + agent_type_id: str | None = None, conversation_history: list[dict[str, str]] | None = None, tool_results: list[dict[str, Any]] | None = None, custom_contexts: list[BaseContext] | None = None, @@ -151,6 +170,10 @@ class ContextEngine: task_description: Current task description knowledge_query: Query for knowledge base search knowledge_limit: Max number of knowledge results + memory_query: Query for agent memory search + memory_limit: Max number of memory results + session_id: Session ID for working memory access + agent_type_id: Agent type ID for procedural memory conversation_history: List of {"role": str, "content": str} tool_results: List of tool results to include custom_contexts: Additional custom contexts @@ -197,15 +220,27 @@ class ContextEngine: ) contexts.extend(knowledge_contexts) - # 4. Conversation history + # 4. Memory context from Agent Memory System + if memory_query and self._memory_source: + memory_contexts = await self._fetch_memory( + project_id=project_id, + agent_id=agent_id, + query=memory_query, + limit=memory_limit, + session_id=session_id, + agent_type_id=agent_type_id, + ) + contexts.extend(memory_contexts) + + # 5. Conversation history if conversation_history: contexts.extend(self._convert_conversation(conversation_history)) - # 5. Tool results + # 6. Tool results if tool_results: contexts.extend(self._convert_tool_results(tool_results)) - # 6. Custom contexts + # 7. 
Custom contexts if custom_contexts: contexts.extend(custom_contexts) @@ -308,6 +343,65 @@ class ContextEngine: logger.warning(f"Failed to fetch knowledge: {e}") return [] + async def _fetch_memory( + self, + project_id: str, + agent_id: str, + query: str, + limit: int = 20, + session_id: str | None = None, + agent_type_id: str | None = None, + ) -> list[MemoryContext]: + """ + Fetch relevant memories from Agent Memory System. + + Args: + project_id: Project identifier + agent_id: Agent identifier + query: Search query + limit: Maximum results + session_id: Session ID for working memory + agent_type_id: Agent type ID for procedural memory + + Returns: + List of MemoryContext instances + """ + if not self._memory_source: + return [] + + try: + # Import here to avoid circular imports + + # Configure fetch limits + from app.services.memory.integration.context_source import MemoryFetchConfig + + config = MemoryFetchConfig( + working_limit=min(limit // 4, 5), + episodic_limit=min(limit // 2, 10), + semantic_limit=min(limit // 2, 10), + procedural_limit=min(limit // 4, 5), + include_working=session_id is not None, + ) + + result = await self._memory_source.fetch_context( + query=query, + project_id=UUID(project_id), + agent_instance_id=UUID(agent_id) if agent_id else None, + agent_type_id=UUID(agent_type_id) if agent_type_id else None, + session_id=session_id, + config=config, + ) + + logger.debug( + f"Fetched {len(result.contexts)} memory contexts for query: {query}, " + f"by_type: {result.by_type}" + ) + return result.contexts[:limit] + + except Exception as e: + logger.warning(f"Failed to fetch memory: {e}") + return [] + def _convert_conversation( self, history: list[dict[str, str]], @@ -466,6 +560,7 @@ def create_context_engine( mcp_manager: "MCPClientManager | None" = None, redis: "Redis | None" = None, settings: ContextSettings | None = None, + memory_source: "MemoryContextSource | None" = None, ) -> ContextEngine: """ Create a context engine instance. 
@@ -474,6 +569,7 @@ def create_context_engine( mcp_manager: MCP client manager redis: Redis connection settings: Context settings + memory_source: Optional memory context source Returns: Configured ContextEngine instance @@ -482,4 +578,5 @@ def create_context_engine( mcp_manager=mcp_manager, redis=redis, settings=settings, + memory_source=memory_source, ) diff --git a/backend/app/services/context/types/__init__.py b/backend/app/services/context/types/__init__.py index 4304025..0205037 100644 --- a/backend/app/services/context/types/__init__.py +++ b/backend/app/services/context/types/__init__.py @@ -15,6 +15,10 @@ from .conversation import ( MessageRole, ) from .knowledge import KnowledgeContext +from .memory import ( + MemoryContext, + MemorySubtype, +) from .system import SystemContext from .task import ( TaskComplexity, @@ -33,6 +37,8 @@ __all__ = [ "ContextType", "ConversationContext", "KnowledgeContext", + "MemoryContext", + "MemorySubtype", "MessageRole", "SystemContext", "TaskComplexity", diff --git a/backend/app/services/context/types/base.py b/backend/app/services/context/types/base.py index 6eef658..8913cc9 100644 --- a/backend/app/services/context/types/base.py +++ b/backend/app/services/context/types/base.py @@ -26,6 +26,7 @@ class ContextType(str, Enum): KNOWLEDGE = "knowledge" CONVERSATION = "conversation" TOOL = "tool" + MEMORY = "memory" # Agent memory (working, episodic, semantic, procedural) @classmethod def from_string(cls, value: str) -> "ContextType": diff --git a/backend/app/services/context/types/memory.py b/backend/app/services/context/types/memory.py new file mode 100644 index 0000000..5dc2509 --- /dev/null +++ b/backend/app/services/context/types/memory.py @@ -0,0 +1,282 @@ +""" +Memory Context Type. + +Represents agent memory as context for LLM requests. +Includes working, episodic, semantic, and procedural memories. +""" + +from dataclasses import dataclass, field +from datetime import UTC, datetime +from enum import Enum +from typing import Any + +from .base import BaseContext, ContextPriority, ContextType + + +class MemorySubtype(str, Enum): + """Types of agent memory.""" + + WORKING = "working" # Session-scoped temporary data + EPISODIC = "episodic" # Task history and outcomes + SEMANTIC = "semantic" # Facts and knowledge + PROCEDURAL = "procedural" # Learned procedures + + +@dataclass(eq=False) +class MemoryContext(BaseContext): + """ + Context from agent memory system. + + Memory context represents data retrieved from the agent + memory system, including: + - Working memory: Current session state + - Episodic memory: Past task experiences + - Semantic memory: Learned facts and knowledge + - Procedural memory: Known procedures and workflows + + Each memory item includes relevance scoring from search. 
+ """ + + # Memory-specific fields + memory_subtype: MemorySubtype = field(default=MemorySubtype.EPISODIC) + memory_id: str | None = field(default=None) + relevance_score: float = field(default=0.0) + importance: float = field(default=0.5) + search_query: str = field(default="") + + # Type-specific fields (populated based on memory_subtype) + key: str | None = field(default=None) # For working memory + task_type: str | None = field(default=None) # For episodic + outcome: str | None = field(default=None) # For episodic + subject: str | None = field(default=None) # For semantic + predicate: str | None = field(default=None) # For semantic + object_value: str | None = field(default=None) # For semantic + trigger: str | None = field(default=None) # For procedural + success_rate: float | None = field(default=None) # For procedural + + def get_type(self) -> ContextType: + """Return MEMORY context type.""" + return ContextType.MEMORY + + def to_dict(self) -> dict[str, Any]: + """Convert to dictionary with memory-specific fields.""" + base = super().to_dict() + base.update( + { + "memory_subtype": self.memory_subtype.value, + "memory_id": self.memory_id, + "relevance_score": self.relevance_score, + "importance": self.importance, + "search_query": self.search_query, + "key": self.key, + "task_type": self.task_type, + "outcome": self.outcome, + "subject": self.subject, + "predicate": self.predicate, + "object_value": self.object_value, + "trigger": self.trigger, + "success_rate": self.success_rate, + } + ) + return base + + @classmethod + def from_dict(cls, data: dict[str, Any]) -> "MemoryContext": + """Create MemoryContext from dictionary.""" + return cls( + id=data.get("id", ""), + content=data["content"], + source=data["source"], + timestamp=datetime.fromisoformat(data["timestamp"]) + if isinstance(data.get("timestamp"), str) + else data.get("timestamp", datetime.now(UTC)), + priority=data.get("priority", ContextPriority.NORMAL.value), + metadata=data.get("metadata", {}), + memory_subtype=MemorySubtype(data.get("memory_subtype", "episodic")), + memory_id=data.get("memory_id"), + relevance_score=data.get("relevance_score", 0.0), + importance=data.get("importance", 0.5), + search_query=data.get("search_query", ""), + key=data.get("key"), + task_type=data.get("task_type"), + outcome=data.get("outcome"), + subject=data.get("subject"), + predicate=data.get("predicate"), + object_value=data.get("object_value"), + trigger=data.get("trigger"), + success_rate=data.get("success_rate"), + ) + + @classmethod + def from_working_memory( + cls, + key: str, + value: Any, + source: str = "working_memory", + query: str = "", + ) -> "MemoryContext": + """ + Create MemoryContext from working memory entry. + + Args: + key: Working memory key + value: Value stored at key + source: Source identifier + query: Search query used + + Returns: + MemoryContext instance + """ + return cls( + content=str(value), + source=source, + memory_subtype=MemorySubtype.WORKING, + key=key, + relevance_score=1.0, # Working memory is always relevant + importance=0.8, # Higher importance for current session state + search_query=query, + priority=ContextPriority.HIGH.value, + ) + + @classmethod + def from_episodic_memory( + cls, + episode: Any, + query: str = "", + ) -> "MemoryContext": + """ + Create MemoryContext from episodic memory episode. 
+ + Args: + episode: Episode object from episodic memory + query: Search query used + + Returns: + MemoryContext instance + """ + outcome_val = None + if hasattr(episode, "outcome") and episode.outcome: + outcome_val = ( + episode.outcome.value + if hasattr(episode.outcome, "value") + else str(episode.outcome) + ) + + return cls( + content=episode.task_description, + source=f"episodic:{episode.id}", + memory_subtype=MemorySubtype.EPISODIC, + memory_id=str(episode.id), + relevance_score=getattr(episode, "importance_score", 0.5), + importance=getattr(episode, "importance_score", 0.5), + search_query=query, + task_type=getattr(episode, "task_type", None), + outcome=outcome_val, + metadata={ + "session_id": getattr(episode, "session_id", None), + "occurred_at": episode.occurred_at.isoformat() + if hasattr(episode, "occurred_at") and episode.occurred_at + else None, + "lessons_learned": getattr(episode, "lessons_learned", []), + }, + ) + + @classmethod + def from_semantic_memory( + cls, + fact: Any, + query: str = "", + ) -> "MemoryContext": + """ + Create MemoryContext from semantic memory fact. + + Args: + fact: Fact object from semantic memory + query: Search query used + + Returns: + MemoryContext instance + """ + triple = f"{fact.subject} {fact.predicate} {fact.object}" + return cls( + content=triple, + source=f"semantic:{fact.id}", + memory_subtype=MemorySubtype.SEMANTIC, + memory_id=str(fact.id), + relevance_score=getattr(fact, "confidence", 0.5), + importance=getattr(fact, "confidence", 0.5), + search_query=query, + subject=fact.subject, + predicate=fact.predicate, + object_value=fact.object, + priority=ContextPriority.NORMAL.value, + ) + + @classmethod + def from_procedural_memory( + cls, + procedure: Any, + query: str = "", + ) -> "MemoryContext": + """ + Create MemoryContext from procedural memory procedure. + + Args: + procedure: Procedure object from procedural memory + query: Search query used + + Returns: + MemoryContext instance + """ + # Format steps as content + steps = getattr(procedure, "steps", []) + steps_content = "\n".join( + f" {i + 1}. {step.get('action', step) if isinstance(step, dict) else step}" + for i, step in enumerate(steps) + ) + content = f"Procedure: {procedure.name}\nTrigger: {procedure.trigger_pattern}\nSteps:\n{steps_content}" + + return cls( + content=content, + source=f"procedural:{procedure.id}", + memory_subtype=MemorySubtype.PROCEDURAL, + memory_id=str(procedure.id), + relevance_score=getattr(procedure, "success_rate", 0.5), + importance=0.7, # Procedures are moderately important + search_query=query, + trigger=procedure.trigger_pattern, + success_rate=getattr(procedure, "success_rate", None), + metadata={ + "steps_count": len(steps), + "execution_count": getattr(procedure, "success_count", 0) + + getattr(procedure, "failure_count", 0), + }, + ) + + def is_working_memory(self) -> bool: + """Check if this is working memory.""" + return self.memory_subtype == MemorySubtype.WORKING + + def is_episodic_memory(self) -> bool: + """Check if this is episodic memory.""" + return self.memory_subtype == MemorySubtype.EPISODIC + + def is_semantic_memory(self) -> bool: + """Check if this is semantic memory.""" + return self.memory_subtype == MemorySubtype.SEMANTIC + + def is_procedural_memory(self) -> bool: + """Check if this is procedural memory.""" + return self.memory_subtype == MemorySubtype.PROCEDURAL + + def get_formatted_source(self) -> str: + """ + Get a formatted source string for display. 
+ + Returns: + Formatted source string + """ + parts = [f"[{self.memory_subtype.value}]", self.source] + if self.memory_id: + parts.append(f"({self.memory_id[:8]}...)") + return " ".join(parts) diff --git a/backend/app/services/memory/integration/__init__.py b/backend/app/services/memory/integration/__init__.py new file mode 100644 index 0000000..a988274 --- /dev/null +++ b/backend/app/services/memory/integration/__init__.py @@ -0,0 +1,19 @@ +# app/services/memory/integration/__init__.py +""" +Memory Integration Module. + +Provides integration between the agent memory system and other Syndarix components: +- Context Engine: Memory as context source +- Agent Lifecycle: Spawn, pause, resume, terminate hooks +""" + +from .context_source import MemoryContextSource, get_memory_context_source +from .lifecycle import AgentLifecycleManager, LifecycleHooks, get_lifecycle_manager + +__all__ = [ + "AgentLifecycleManager", + "LifecycleHooks", + "MemoryContextSource", + "get_lifecycle_manager", + "get_memory_context_source", +] diff --git a/backend/app/services/memory/integration/context_source.py b/backend/app/services/memory/integration/context_source.py new file mode 100644 index 0000000..262fcb4 --- /dev/null +++ b/backend/app/services/memory/integration/context_source.py @@ -0,0 +1,402 @@ +# app/services/memory/integration/context_source.py +""" +Memory Context Source. + +Provides agent memory as a context source for the Context Engine. +Retrieves relevant memories based on query and converts them to MemoryContext objects. +""" + +import logging +from dataclasses import dataclass +from datetime import UTC, datetime, timedelta +from typing import Any +from uuid import UUID + +from sqlalchemy.ext.asyncio import AsyncSession + +from app.services.context.types.memory import MemoryContext +from app.services.memory.episodic import EpisodicMemory +from app.services.memory.procedural import ProceduralMemory +from app.services.memory.semantic import SemanticMemory +from app.services.memory.working import WorkingMemory + +logger = logging.getLogger(__name__) + + +@dataclass +class MemoryFetchConfig: + """Configuration for memory fetching.""" + + # Limits per memory type + working_limit: int = 10 + episodic_limit: int = 10 + semantic_limit: int = 15 + procedural_limit: int = 5 + + # Time ranges + episodic_days_back: int = 30 + min_relevance: float = 0.3 + + # Which memory types to include + include_working: bool = True + include_episodic: bool = True + include_semantic: bool = True + include_procedural: bool = True + + +@dataclass +class MemoryFetchResult: + """Result of memory fetch operation.""" + + contexts: list[MemoryContext] + by_type: dict[str, int] + fetch_time_ms: float + query: str + + +class MemoryContextSource: + """ + Source for memory context in the Context Engine. + + This service retrieves relevant memories based on a query and + converts them to MemoryContext objects for context assembly. + It coordinates between all memory types (working, episodic, + semantic, procedural) to provide a comprehensive memory context. + """ + + def __init__( + self, + session: AsyncSession, + embedding_generator: Any | None = None, + ) -> None: + """ + Initialize the memory context source. 
+ + Args: + session: Database session + embedding_generator: Optional embedding generator for semantic search + """ + self._session = session + self._embedding_generator = embedding_generator + + # Lazy-initialized memory services + self._episodic: EpisodicMemory | None = None + self._semantic: SemanticMemory | None = None + self._procedural: ProceduralMemory | None = None + + async def _get_episodic(self) -> EpisodicMemory: + """Get or create episodic memory service.""" + if self._episodic is None: + self._episodic = await EpisodicMemory.create( + self._session, + self._embedding_generator, + ) + return self._episodic + + async def _get_semantic(self) -> SemanticMemory: + """Get or create semantic memory service.""" + if self._semantic is None: + self._semantic = await SemanticMemory.create( + self._session, + self._embedding_generator, + ) + return self._semantic + + async def _get_procedural(self) -> ProceduralMemory: + """Get or create procedural memory service.""" + if self._procedural is None: + self._procedural = await ProceduralMemory.create( + self._session, + self._embedding_generator, + ) + return self._procedural + + async def fetch_context( + self, + query: str, + project_id: UUID, + agent_instance_id: UUID | None = None, + agent_type_id: UUID | None = None, + session_id: str | None = None, + config: MemoryFetchConfig | None = None, + ) -> MemoryFetchResult: + """ + Fetch relevant memories as context. + + This is the main entry point for the Context Engine integration. + It searches across all memory types and returns relevant memories + as MemoryContext objects. + + Args: + query: Search query for finding relevant memories + project_id: Project scope + agent_instance_id: Optional agent instance scope + agent_type_id: Optional agent type scope (for procedural) + session_id: Optional session ID (for working memory) + config: Optional fetch configuration + + Returns: + MemoryFetchResult with contexts and metadata + """ + config = config or MemoryFetchConfig() + start_time = datetime.now(UTC) + + contexts: list[MemoryContext] = [] + by_type: dict[str, int] = { + "working": 0, + "episodic": 0, + "semantic": 0, + "procedural": 0, + } + + # Fetch from working memory (session-scoped) + if config.include_working and session_id: + try: + working_contexts = await self._fetch_working( + query=query, + session_id=session_id, + project_id=project_id, + agent_instance_id=agent_instance_id, + limit=config.working_limit, + ) + contexts.extend(working_contexts) + by_type["working"] = len(working_contexts) + except Exception as e: + logger.warning(f"Failed to fetch working memory: {e}") + + # Fetch from episodic memory + if config.include_episodic: + try: + episodic_contexts = await self._fetch_episodic( + query=query, + project_id=project_id, + agent_instance_id=agent_instance_id, + limit=config.episodic_limit, + days_back=config.episodic_days_back, + ) + contexts.extend(episodic_contexts) + by_type["episodic"] = len(episodic_contexts) + except Exception as e: + logger.warning(f"Failed to fetch episodic memory: {e}") + + # Fetch from semantic memory + if config.include_semantic: + try: + semantic_contexts = await self._fetch_semantic( + query=query, + project_id=project_id, + limit=config.semantic_limit, + min_relevance=config.min_relevance, + ) + contexts.extend(semantic_contexts) + by_type["semantic"] = len(semantic_contexts) + except Exception as e: + logger.warning(f"Failed to fetch semantic memory: {e}") + + # Fetch from procedural memory + if config.include_procedural: + try: + 
procedural_contexts = await self._fetch_procedural( + query=query, + project_id=project_id, + agent_type_id=agent_type_id, + limit=config.procedural_limit, + ) + contexts.extend(procedural_contexts) + by_type["procedural"] = len(procedural_contexts) + except Exception as e: + logger.warning(f"Failed to fetch procedural memory: {e}") + + # Sort by relevance + contexts.sort(key=lambda c: c.relevance_score, reverse=True) + + fetch_time = (datetime.now(UTC) - start_time).total_seconds() * 1000 + + logger.debug( + f"Fetched {len(contexts)} memory contexts for query '{query[:50]}...' " + f"in {fetch_time:.1f}ms" + ) + + return MemoryFetchResult( + contexts=contexts, + by_type=by_type, + fetch_time_ms=fetch_time, + query=query, + ) + + async def _fetch_working( + self, + query: str, + session_id: str, + project_id: UUID, + agent_instance_id: UUID | None, + limit: int, + ) -> list[MemoryContext]: + """Fetch from working memory.""" + working = await WorkingMemory.for_session( + session_id=session_id, + project_id=str(project_id), + agent_instance_id=str(agent_instance_id) if agent_instance_id else None, + ) + + contexts: list[MemoryContext] = [] + all_keys = await working.list_keys() + + # Filter keys by query (simple substring match) + query_lower = query.lower() + matched_keys = [k for k in all_keys if query_lower in k.lower()] + + # If no query match, include all keys (working memory is always relevant) + if not matched_keys and query: + matched_keys = all_keys + + for key in matched_keys[:limit]: + value = await working.get(key) + if value is not None: + contexts.append( + MemoryContext.from_working_memory( + key=key, + value=value, + source=f"working:{session_id}", + query=query, + ) + ) + + return contexts + + async def _fetch_episodic( + self, + query: str, + project_id: UUID, + agent_instance_id: UUID | None, + limit: int, + days_back: int, + ) -> list[MemoryContext]: + """Fetch from episodic memory.""" + episodic = await self._get_episodic() + + # Search for similar episodes + episodes = await episodic.search_similar( + project_id=project_id, + query=query, + limit=limit, + agent_instance_id=agent_instance_id, + ) + + # Also get recent episodes if we didn't find enough + if len(episodes) < limit // 2: + since = datetime.now(UTC) - timedelta(days=days_back) + recent = await episodic.get_recent( + project_id=project_id, + limit=limit, + since=since, + ) + # Deduplicate by ID + existing_ids = {e.id for e in episodes} + for ep in recent: + if ep.id not in existing_ids: + episodes.append(ep) + if len(episodes) >= limit: + break + + return [ + MemoryContext.from_episodic_memory(ep, query=query) + for ep in episodes[:limit] + ] + + async def _fetch_semantic( + self, + query: str, + project_id: UUID, + limit: int, + min_relevance: float, + ) -> list[MemoryContext]: + """Fetch from semantic memory.""" + semantic = await self._get_semantic() + + facts = await semantic.search_facts( + query=query, + project_id=project_id, + limit=limit, + min_confidence=min_relevance, + ) + + return [ + MemoryContext.from_semantic_memory(fact, query=query) + for fact in facts + ] + + async def _fetch_procedural( + self, + query: str, + project_id: UUID, + agent_type_id: UUID | None, + limit: int, + ) -> list[MemoryContext]: + """Fetch from procedural memory.""" + procedural = await self._get_procedural() + + procedures = await procedural.find_matching( + context=query, + project_id=project_id, + agent_type_id=agent_type_id, + limit=limit, + ) + + return [ + MemoryContext.from_procedural_memory(proc, query=query) + 
for proc in procedures + ] + + async def fetch_all_working( + self, + session_id: str, + project_id: UUID, + agent_instance_id: UUID | None = None, + ) -> list[MemoryContext]: + """ + Fetch all working memory for a session. + + Useful for including entire session state in context. + + Args: + session_id: Session ID + project_id: Project scope + agent_instance_id: Optional agent instance scope + + Returns: + List of MemoryContext for all working memory items + """ + working = await WorkingMemory.for_session( + session_id=session_id, + project_id=str(project_id), + agent_instance_id=str(agent_instance_id) if agent_instance_id else None, + ) + + contexts: list[MemoryContext] = [] + all_keys = await working.list_keys() + + for key in all_keys: + value = await working.get(key) + if value is not None: + contexts.append( + MemoryContext.from_working_memory( + key=key, + value=value, + source=f"working:{session_id}", + ) + ) + + return contexts + + +# Factory function +async def get_memory_context_source( + session: AsyncSession, + embedding_generator: Any | None = None, +) -> MemoryContextSource: + """Create a memory context source instance.""" + return MemoryContextSource( + session=session, + embedding_generator=embedding_generator, + ) diff --git a/backend/app/services/memory/integration/lifecycle.py b/backend/app/services/memory/integration/lifecycle.py new file mode 100644 index 0000000..76a0502 --- /dev/null +++ b/backend/app/services/memory/integration/lifecycle.py @@ -0,0 +1,629 @@ +# app/services/memory/integration/lifecycle.py +""" +Agent Lifecycle Hooks for Memory System. + +Provides memory management hooks for agent lifecycle events: +- spawn: Initialize working memory for new agent instance +- pause: Checkpoint working memory state +- resume: Restore working memory from checkpoint +- terminate: Consolidate session to episodic memory +""" + +import logging +from collections.abc import Callable, Coroutine +from dataclasses import dataclass, field +from datetime import UTC, datetime +from typing import Any +from uuid import UUID + +from sqlalchemy.ext.asyncio import AsyncSession + +from app.services.memory.episodic import EpisodicMemory +from app.services.memory.types import EpisodeCreate, Outcome +from app.services.memory.working import WorkingMemory + +logger = logging.getLogger(__name__) + + +@dataclass +class LifecycleEvent: + """Event data for lifecycle hooks.""" + + event_type: str # spawn, pause, resume, terminate + project_id: UUID + agent_instance_id: UUID + agent_type_id: UUID | None = None + session_id: str | None = None + timestamp: datetime = field(default_factory=lambda: datetime.now(UTC)) + metadata: dict[str, Any] = field(default_factory=dict) + + +@dataclass +class LifecycleResult: + """Result of a lifecycle operation.""" + + success: bool + event_type: str + message: str | None = None + data: dict[str, Any] = field(default_factory=dict) + duration_ms: float = 0.0 + + +# Type alias for lifecycle hooks +LifecycleHook = Callable[[LifecycleEvent], Coroutine[Any, Any, None]] + + +class LifecycleHooks: + """ + Collection of lifecycle hooks. + + Allows registration of custom hooks for lifecycle events. + Hooks are called after the core memory operations. 
+ """ + + def __init__(self) -> None: + """Initialize lifecycle hooks.""" + self._spawn_hooks: list[LifecycleHook] = [] + self._pause_hooks: list[LifecycleHook] = [] + self._resume_hooks: list[LifecycleHook] = [] + self._terminate_hooks: list[LifecycleHook] = [] + + def on_spawn(self, hook: LifecycleHook) -> LifecycleHook: + """Register a spawn hook.""" + self._spawn_hooks.append(hook) + return hook + + def on_pause(self, hook: LifecycleHook) -> LifecycleHook: + """Register a pause hook.""" + self._pause_hooks.append(hook) + return hook + + def on_resume(self, hook: LifecycleHook) -> LifecycleHook: + """Register a resume hook.""" + self._resume_hooks.append(hook) + return hook + + def on_terminate(self, hook: LifecycleHook) -> LifecycleHook: + """Register a terminate hook.""" + self._terminate_hooks.append(hook) + return hook + + async def run_spawn_hooks(self, event: LifecycleEvent) -> None: + """Run all spawn hooks.""" + for hook in self._spawn_hooks: + try: + await hook(event) + except Exception as e: + logger.warning(f"Spawn hook failed: {e}") + + async def run_pause_hooks(self, event: LifecycleEvent) -> None: + """Run all pause hooks.""" + for hook in self._pause_hooks: + try: + await hook(event) + except Exception as e: + logger.warning(f"Pause hook failed: {e}") + + async def run_resume_hooks(self, event: LifecycleEvent) -> None: + """Run all resume hooks.""" + for hook in self._resume_hooks: + try: + await hook(event) + except Exception as e: + logger.warning(f"Resume hook failed: {e}") + + async def run_terminate_hooks(self, event: LifecycleEvent) -> None: + """Run all terminate hooks.""" + for hook in self._terminate_hooks: + try: + await hook(event) + except Exception as e: + logger.warning(f"Terminate hook failed: {e}") + + +class AgentLifecycleManager: + """ + Manager for agent lifecycle and memory integration. + + Handles memory operations during agent lifecycle events: + - spawn: Creates new working memory for the session + - pause: Saves working memory state to checkpoint + - resume: Restores working memory from checkpoint + - terminate: Consolidates working memory to episodic memory + """ + + # Key prefix for checkpoint storage + CHECKPOINT_PREFIX = "__checkpoint__" + + def __init__( + self, + session: AsyncSession, + embedding_generator: Any | None = None, + hooks: LifecycleHooks | None = None, + ) -> None: + """ + Initialize the lifecycle manager. + + Args: + session: Database session + embedding_generator: Optional embedding generator + hooks: Optional lifecycle hooks + """ + self._session = session + self._embedding_generator = embedding_generator + self._hooks = hooks or LifecycleHooks() + + # Lazy-initialized services + self._episodic: EpisodicMemory | None = None + + async def _get_episodic(self) -> EpisodicMemory: + """Get or create episodic memory service.""" + if self._episodic is None: + self._episodic = await EpisodicMemory.create( + self._session, + self._embedding_generator, + ) + return self._episodic + + @property + def hooks(self) -> LifecycleHooks: + """Get the lifecycle hooks.""" + return self._hooks + + async def spawn( + self, + project_id: UUID, + agent_instance_id: UUID, + session_id: str, + agent_type_id: UUID | None = None, + initial_state: dict[str, Any] | None = None, + metadata: dict[str, Any] | None = None, + ) -> LifecycleResult: + """ + Handle agent spawn - initialize working memory. + + Creates a new working memory instance for the agent session + and optionally populates it with initial state. 
+ + Args: + project_id: Project scope + agent_instance_id: Agent instance ID + session_id: Session ID for working memory + agent_type_id: Optional agent type ID + initial_state: Optional initial state to populate + metadata: Optional metadata for the event + + Returns: + LifecycleResult with spawn outcome + """ + start_time = datetime.now(UTC) + + try: + # Create working memory for the session + working = await WorkingMemory.for_session( + session_id=session_id, + project_id=str(project_id), + agent_instance_id=str(agent_instance_id), + ) + + # Populate initial state if provided + items_set = 0 + if initial_state: + for key, value in initial_state.items(): + await working.set(key, value) + items_set += 1 + + # Create and run event hooks + event = LifecycleEvent( + event_type="spawn", + project_id=project_id, + agent_instance_id=agent_instance_id, + agent_type_id=agent_type_id, + session_id=session_id, + metadata=metadata or {}, + ) + await self._hooks.run_spawn_hooks(event) + + duration_ms = (datetime.now(UTC) - start_time).total_seconds() * 1000 + + logger.info( + f"Agent {agent_instance_id} spawned with session {session_id}, " + f"initial state: {items_set} items" + ) + + return LifecycleResult( + success=True, + event_type="spawn", + message="Agent spawned successfully", + data={ + "session_id": session_id, + "initial_items": items_set, + }, + duration_ms=duration_ms, + ) + + except Exception as e: + logger.error(f"Spawn failed for agent {agent_instance_id}: {e}") + return LifecycleResult( + success=False, + event_type="spawn", + message=f"Spawn failed: {e}", + ) + + async def pause( + self, + project_id: UUID, + agent_instance_id: UUID, + session_id: str, + checkpoint_id: str | None = None, + metadata: dict[str, Any] | None = None, + ) -> LifecycleResult: + """ + Handle agent pause - checkpoint working memory. + + Saves the current working memory state to a checkpoint + that can be restored later with resume(). 
+ + Args: + project_id: Project scope + agent_instance_id: Agent instance ID + session_id: Session ID + checkpoint_id: Optional checkpoint identifier + metadata: Optional metadata for the event + + Returns: + LifecycleResult with checkpoint data + """ + start_time = datetime.now(UTC) + checkpoint_id = checkpoint_id or f"checkpoint_{int(start_time.timestamp())}" + + try: + working = await WorkingMemory.for_session( + session_id=session_id, + project_id=str(project_id), + agent_instance_id=str(agent_instance_id), + ) + + # Get all current state + all_keys = await working.list_keys() + # Filter out checkpoint keys + state_keys = [k for k in all_keys if not k.startswith(self.CHECKPOINT_PREFIX)] + + state: dict[str, Any] = {} + for key in state_keys: + value = await working.get(key) + if value is not None: + state[key] = value + + # Store checkpoint + checkpoint_key = f"{self.CHECKPOINT_PREFIX}{checkpoint_id}" + await working.set( + checkpoint_key, + { + "state": state, + "timestamp": start_time.isoformat(), + "keys_count": len(state), + }, + ttl_seconds=86400 * 7, # Keep checkpoint for 7 days + ) + + # Run hooks + event = LifecycleEvent( + event_type="pause", + project_id=project_id, + agent_instance_id=agent_instance_id, + session_id=session_id, + metadata={**(metadata or {}), "checkpoint_id": checkpoint_id}, + ) + await self._hooks.run_pause_hooks(event) + + duration_ms = (datetime.now(UTC) - start_time).total_seconds() * 1000 + + logger.info( + f"Agent {agent_instance_id} paused, checkpoint {checkpoint_id} " + f"saved with {len(state)} items" + ) + + return LifecycleResult( + success=True, + event_type="pause", + message="Agent paused successfully", + data={ + "checkpoint_id": checkpoint_id, + "items_saved": len(state), + "timestamp": start_time.isoformat(), + }, + duration_ms=duration_ms, + ) + + except Exception as e: + logger.error(f"Pause failed for agent {agent_instance_id}: {e}") + return LifecycleResult( + success=False, + event_type="pause", + message=f"Pause failed: {e}", + ) + + async def resume( + self, + project_id: UUID, + agent_instance_id: UUID, + session_id: str, + checkpoint_id: str, + clear_current: bool = True, + metadata: dict[str, Any] | None = None, + ) -> LifecycleResult: + """ + Handle agent resume - restore from checkpoint. + + Restores working memory state from a previously saved checkpoint. 
+ + Args: + project_id: Project scope + agent_instance_id: Agent instance ID + session_id: Session ID + checkpoint_id: Checkpoint to restore from + clear_current: Whether to clear current state before restoring + metadata: Optional metadata for the event + + Returns: + LifecycleResult with restore outcome + """ + start_time = datetime.now(UTC) + + try: + working = await WorkingMemory.for_session( + session_id=session_id, + project_id=str(project_id), + agent_instance_id=str(agent_instance_id), + ) + + # Get checkpoint + checkpoint_key = f"{self.CHECKPOINT_PREFIX}{checkpoint_id}" + checkpoint = await working.get(checkpoint_key) + + if checkpoint is None: + return LifecycleResult( + success=False, + event_type="resume", + message=f"Checkpoint '{checkpoint_id}' not found", + ) + + # Clear current state if requested + if clear_current: + all_keys = await working.list_keys() + for key in all_keys: + if not key.startswith(self.CHECKPOINT_PREFIX): + await working.delete(key) + + # Restore state from checkpoint + state = checkpoint.get("state", {}) + items_restored = 0 + for key, value in state.items(): + await working.set(key, value) + items_restored += 1 + + # Run hooks + event = LifecycleEvent( + event_type="resume", + project_id=project_id, + agent_instance_id=agent_instance_id, + session_id=session_id, + metadata={**(metadata or {}), "checkpoint_id": checkpoint_id}, + ) + await self._hooks.run_resume_hooks(event) + + duration_ms = (datetime.now(UTC) - start_time).total_seconds() * 1000 + + logger.info( + f"Agent {agent_instance_id} resumed from checkpoint {checkpoint_id}, " + f"restored {items_restored} items" + ) + + return LifecycleResult( + success=True, + event_type="resume", + message="Agent resumed successfully", + data={ + "checkpoint_id": checkpoint_id, + "items_restored": items_restored, + "checkpoint_timestamp": checkpoint.get("timestamp"), + }, + duration_ms=duration_ms, + ) + + except Exception as e: + logger.error(f"Resume failed for agent {agent_instance_id}: {e}") + return LifecycleResult( + success=False, + event_type="resume", + message=f"Resume failed: {e}", + ) + + async def terminate( + self, + project_id: UUID, + agent_instance_id: UUID, + session_id: str, + task_description: str | None = None, + outcome: Outcome = Outcome.SUCCESS, + lessons_learned: list[str] | None = None, + consolidate_to_episodic: bool = True, + cleanup_working: bool = True, + metadata: dict[str, Any] | None = None, + ) -> LifecycleResult: + """ + Handle agent termination - consolidate to episodic memory. + + Consolidates the session's working memory into an episodic memory + entry, then optionally cleans up the working memory. 
+ + Args: + project_id: Project scope + agent_instance_id: Agent instance ID + session_id: Session ID + task_description: Description of what was accomplished + outcome: Task outcome (SUCCESS, FAILURE, PARTIAL) + lessons_learned: Optional list of lessons learned + consolidate_to_episodic: Whether to create episodic entry + cleanup_working: Whether to clear working memory + metadata: Optional metadata for the event + + Returns: + LifecycleResult with termination outcome + """ + start_time = datetime.now(UTC) + + try: + working = await WorkingMemory.for_session( + session_id=session_id, + project_id=str(project_id), + agent_instance_id=str(agent_instance_id), + ) + + # Gather session state for consolidation + all_keys = await working.list_keys() + state_keys = [k for k in all_keys if not k.startswith(self.CHECKPOINT_PREFIX)] + + session_state: dict[str, Any] = {} + for key in state_keys: + value = await working.get(key) + if value is not None: + session_state[key] = value + + episode_id: str | None = None + + # Consolidate to episodic memory + if consolidate_to_episodic: + episodic = await self._get_episodic() + + description = task_description or f"Session {session_id} completed" + + episode_data = EpisodeCreate( + project_id=project_id, + agent_instance_id=agent_instance_id, + session_id=session_id, + task_type="session_completion", + task_description=description[:500], + outcome=outcome, + outcome_details=f"Session terminated with {len(session_state)} state items", + actions=[ + { + "type": "session_terminate", + "state_keys": list(session_state.keys()), + "outcome": outcome.value, + } + ], + context_summary=str(session_state)[:1000] if session_state else "", + lessons_learned=lessons_learned or [], + duration_seconds=0.0, # Unknown at this point + tokens_used=0, + importance_score=0.6, # Moderate importance for session ends + ) + + episode = await episodic.record_episode(episode_data) + episode_id = str(episode.id) + + # Clean up working memory + items_cleared = 0 + if cleanup_working: + for key in all_keys: + await working.delete(key) + items_cleared += 1 + + # Run hooks + event = LifecycleEvent( + event_type="terminate", + project_id=project_id, + agent_instance_id=agent_instance_id, + session_id=session_id, + metadata={**(metadata or {}), "episode_id": episode_id}, + ) + await self._hooks.run_terminate_hooks(event) + + duration_ms = (datetime.now(UTC) - start_time).total_seconds() * 1000 + + logger.info( + f"Agent {agent_instance_id} terminated, session {session_id} " + f"consolidated to episode {episode_id}" + ) + + return LifecycleResult( + success=True, + event_type="terminate", + message="Agent terminated successfully", + data={ + "episode_id": episode_id, + "state_items_consolidated": len(session_state), + "items_cleared": items_cleared, + "outcome": outcome.value, + }, + duration_ms=duration_ms, + ) + + except Exception as e: + logger.error(f"Terminate failed for agent {agent_instance_id}: {e}") + return LifecycleResult( + success=False, + event_type="terminate", + message=f"Terminate failed: {e}", + ) + + async def list_checkpoints( + self, + project_id: UUID, + agent_instance_id: UUID, + session_id: str, + ) -> list[dict[str, Any]]: + """ + List available checkpoints for a session. 
+ + Args: + project_id: Project scope + agent_instance_id: Agent instance ID + session_id: Session ID + + Returns: + List of checkpoint metadata dicts + """ + working = await WorkingMemory.for_session( + session_id=session_id, + project_id=str(project_id), + agent_instance_id=str(agent_instance_id), + ) + + all_keys = await working.list_keys() + checkpoints: list[dict[str, Any]] = [] + + for key in all_keys: + if key.startswith(self.CHECKPOINT_PREFIX): + checkpoint_id = key[len(self.CHECKPOINT_PREFIX):] + checkpoint = await working.get(key) + if checkpoint: + checkpoints.append({ + "checkpoint_id": checkpoint_id, + "timestamp": checkpoint.get("timestamp"), + "keys_count": checkpoint.get("keys_count", 0), + }) + + # Sort by timestamp (newest first) + checkpoints.sort( + key=lambda c: c.get("timestamp", ""), + reverse=True, + ) + + return checkpoints + + +# Factory function +async def get_lifecycle_manager( + session: AsyncSession, + embedding_generator: Any | None = None, + hooks: LifecycleHooks | None = None, +) -> AgentLifecycleManager: + """Create a lifecycle manager instance.""" + return AgentLifecycleManager( + session=session, + embedding_generator=embedding_generator, + hooks=hooks, + ) diff --git a/backend/tests/unit/services/context/types/test_memory.py b/backend/tests/unit/services/context/types/test_memory.py new file mode 100644 index 0000000..dcfc2c4 --- /dev/null +++ b/backend/tests/unit/services/context/types/test_memory.py @@ -0,0 +1,262 @@ +# tests/unit/services/context/types/test_memory.py +"""Tests for MemoryContext type.""" + +from datetime import UTC, datetime +from unittest.mock import MagicMock +from uuid import uuid4 + +import pytest + +from app.services.context.types import ContextType +from app.services.context.types.memory import MemoryContext, MemorySubtype + + +class TestMemorySubtype: + """Tests for MemorySubtype enum.""" + + def test_all_types_defined(self) -> None: + """All memory subtypes should be defined.""" + assert MemorySubtype.WORKING == "working" + assert MemorySubtype.EPISODIC == "episodic" + assert MemorySubtype.SEMANTIC == "semantic" + assert MemorySubtype.PROCEDURAL == "procedural" + + def test_enum_values(self) -> None: + """Enum values should match strings.""" + assert MemorySubtype.WORKING.value == "working" + assert MemorySubtype("episodic") == MemorySubtype.EPISODIC + + +class TestMemoryContext: + """Tests for MemoryContext class.""" + + def test_get_type_returns_memory(self) -> None: + """get_type should return MEMORY.""" + ctx = MemoryContext(content="test", source="test_source") + assert ctx.get_type() == ContextType.MEMORY + + def test_default_values(self) -> None: + """Default values should be set correctly.""" + ctx = MemoryContext(content="test", source="test_source") + assert ctx.memory_subtype == MemorySubtype.EPISODIC + assert ctx.memory_id is None + assert ctx.relevance_score == 0.0 + assert ctx.importance == 0.5 + + def test_to_dict_includes_memory_fields(self) -> None: + """to_dict should include memory-specific fields.""" + ctx = MemoryContext( + content="test content", + source="test_source", + memory_subtype=MemorySubtype.SEMANTIC, + memory_id="mem-123", + relevance_score=0.8, + subject="User", + predicate="prefers", + object_value="dark mode", + ) + + data = ctx.to_dict() + + assert data["memory_subtype"] == "semantic" + assert data["memory_id"] == "mem-123" + assert data["relevance_score"] == 0.8 + assert data["subject"] == "User" + assert data["predicate"] == "prefers" + assert data["object_value"] == "dark mode" + + def 
test_from_dict(self) -> None: + """from_dict should create correct MemoryContext.""" + data = { + "content": "test content", + "source": "test_source", + "timestamp": "2024-01-01T00:00:00+00:00", + "memory_subtype": "semantic", + "memory_id": "mem-123", + "relevance_score": 0.8, + "subject": "Test", + } + + ctx = MemoryContext.from_dict(data) + + assert ctx.content == "test content" + assert ctx.memory_subtype == MemorySubtype.SEMANTIC + assert ctx.memory_id == "mem-123" + assert ctx.subject == "Test" + + +class TestMemoryContextFromWorkingMemory: + """Tests for MemoryContext.from_working_memory.""" + + def test_creates_working_memory_context(self) -> None: + """Should create working memory context from key/value.""" + ctx = MemoryContext.from_working_memory( + key="user_preferences", + value={"theme": "dark"}, + source="working:sess-123", + query="preferences", + ) + + assert ctx.memory_subtype == MemorySubtype.WORKING + assert ctx.key == "user_preferences" + assert "{'theme': 'dark'}" in ctx.content + assert ctx.relevance_score == 1.0 # Working memory is always relevant + assert ctx.importance == 0.8 # Higher importance + + def test_string_value(self) -> None: + """Should handle string values.""" + ctx = MemoryContext.from_working_memory( + key="current_task", + value="Build authentication", + ) + + assert ctx.content == "Build authentication" + + +class TestMemoryContextFromEpisodicMemory: + """Tests for MemoryContext.from_episodic_memory.""" + + def test_creates_episodic_memory_context(self) -> None: + """Should create episodic memory context from episode.""" + episode = MagicMock() + episode.id = uuid4() + episode.task_description = "Implemented login feature" + episode.task_type = "feature_implementation" + episode.outcome = MagicMock(value="success") + episode.importance_score = 0.9 + episode.session_id = "sess-123" + episode.occurred_at = datetime.now(UTC) + episode.lessons_learned = ["Use proper validation"] + + ctx = MemoryContext.from_episodic_memory(episode, query="login") + + assert ctx.memory_subtype == MemorySubtype.EPISODIC + assert ctx.memory_id == str(episode.id) + assert ctx.content == "Implemented login feature" + assert ctx.task_type == "feature_implementation" + assert ctx.outcome == "success" + assert ctx.importance == 0.9 + + def test_handles_missing_outcome(self) -> None: + """Should handle episodes with no outcome.""" + episode = MagicMock() + episode.id = uuid4() + episode.task_description = "WIP task" + episode.outcome = None + episode.importance_score = 0.5 + episode.occurred_at = None + + ctx = MemoryContext.from_episodic_memory(episode) + + assert ctx.outcome is None + + +class TestMemoryContextFromSemanticMemory: + """Tests for MemoryContext.from_semantic_memory.""" + + def test_creates_semantic_memory_context(self) -> None: + """Should create semantic memory context from fact.""" + fact = MagicMock() + fact.id = uuid4() + fact.subject = "User" + fact.predicate = "prefers" + fact.object = "dark mode" + fact.confidence = 0.95 + + ctx = MemoryContext.from_semantic_memory(fact, query="user preferences") + + assert ctx.memory_subtype == MemorySubtype.SEMANTIC + assert ctx.memory_id == str(fact.id) + assert ctx.content == "User prefers dark mode" + assert ctx.subject == "User" + assert ctx.predicate == "prefers" + assert ctx.object_value == "dark mode" + assert ctx.relevance_score == 0.95 + + +class TestMemoryContextFromProceduralMemory: + """Tests for MemoryContext.from_procedural_memory.""" + + def test_creates_procedural_memory_context(self) -> None: + 
"""Should create procedural memory context from procedure.""" + procedure = MagicMock() + procedure.id = uuid4() + procedure.name = "Deploy to Production" + procedure.trigger_pattern = "When deploying to production" + procedure.steps = [ + {"action": "run_tests"}, + {"action": "build_docker"}, + {"action": "deploy"}, + ] + procedure.success_rate = 0.85 + procedure.success_count = 10 + procedure.failure_count = 2 + + ctx = MemoryContext.from_procedural_memory(procedure, query="deploy") + + assert ctx.memory_subtype == MemorySubtype.PROCEDURAL + assert ctx.memory_id == str(procedure.id) + assert "Deploy to Production" in ctx.content + assert "When deploying to production" in ctx.content + assert ctx.trigger == "When deploying to production" + assert ctx.success_rate == 0.85 + assert ctx.metadata["steps_count"] == 3 + assert ctx.metadata["execution_count"] == 12 + + +class TestMemoryContextHelpers: + """Tests for MemoryContext helper methods.""" + + def test_is_working_memory(self) -> None: + """is_working_memory should return True for working memory.""" + ctx = MemoryContext( + content="test", + source="test", + memory_subtype=MemorySubtype.WORKING, + ) + assert ctx.is_working_memory() is True + assert ctx.is_episodic_memory() is False + + def test_is_episodic_memory(self) -> None: + """is_episodic_memory should return True for episodic memory.""" + ctx = MemoryContext( + content="test", + source="test", + memory_subtype=MemorySubtype.EPISODIC, + ) + assert ctx.is_episodic_memory() is True + assert ctx.is_semantic_memory() is False + + def test_is_semantic_memory(self) -> None: + """is_semantic_memory should return True for semantic memory.""" + ctx = MemoryContext( + content="test", + source="test", + memory_subtype=MemorySubtype.SEMANTIC, + ) + assert ctx.is_semantic_memory() is True + assert ctx.is_procedural_memory() is False + + def test_is_procedural_memory(self) -> None: + """is_procedural_memory should return True for procedural memory.""" + ctx = MemoryContext( + content="test", + source="test", + memory_subtype=MemorySubtype.PROCEDURAL, + ) + assert ctx.is_procedural_memory() is True + assert ctx.is_working_memory() is False + + def test_get_formatted_source(self) -> None: + """get_formatted_source should return formatted string.""" + ctx = MemoryContext( + content="test", + source="episodic:12345678-1234-1234-1234-123456789012", + memory_subtype=MemorySubtype.EPISODIC, + memory_id="12345678-1234-1234-1234-123456789012", + ) + + formatted = ctx.get_formatted_source() + + assert "[episodic]" in formatted + assert "12345678..." 
in formatted diff --git a/backend/tests/unit/services/memory/integration/__init__.py b/backend/tests/unit/services/memory/integration/__init__.py new file mode 100644 index 0000000..84fdb03 --- /dev/null +++ b/backend/tests/unit/services/memory/integration/__init__.py @@ -0,0 +1,2 @@ +# tests/unit/services/memory/integration/__init__.py +"""Tests for memory integration module.""" diff --git a/backend/tests/unit/services/memory/integration/test_context_source.py b/backend/tests/unit/services/memory/integration/test_context_source.py new file mode 100644 index 0000000..449488e --- /dev/null +++ b/backend/tests/unit/services/memory/integration/test_context_source.py @@ -0,0 +1,322 @@ +# tests/unit/services/memory/integration/test_context_source.py +"""Tests for MemoryContextSource service.""" + +from datetime import UTC, datetime +from unittest.mock import AsyncMock, MagicMock, patch +from uuid import uuid4 + +import pytest + +from app.services.context.types.memory import MemorySubtype +from app.services.memory.integration.context_source import ( + MemoryContextSource, + MemoryFetchConfig, + MemoryFetchResult, + get_memory_context_source, +) + +pytestmark = pytest.mark.asyncio(loop_scope="function") + + +@pytest.fixture +def mock_session() -> MagicMock: + """Create mock database session.""" + return MagicMock() + + +@pytest.fixture +def context_source(mock_session: MagicMock) -> MemoryContextSource: + """Create MemoryContextSource instance.""" + return MemoryContextSource(session=mock_session) + + +class TestMemoryFetchConfig: + """Tests for MemoryFetchConfig.""" + + def test_default_values(self) -> None: + """Default config values should be set correctly.""" + config = MemoryFetchConfig() + + assert config.working_limit == 10 + assert config.episodic_limit == 10 + assert config.semantic_limit == 15 + assert config.procedural_limit == 5 + assert config.episodic_days_back == 30 + assert config.min_relevance == 0.3 + assert config.include_working is True + assert config.include_episodic is True + assert config.include_semantic is True + assert config.include_procedural is True + + def test_custom_values(self) -> None: + """Custom config values should be respected.""" + config = MemoryFetchConfig( + working_limit=5, + include_working=False, + ) + + assert config.working_limit == 5 + assert config.include_working is False + + +class TestMemoryFetchResult: + """Tests for MemoryFetchResult.""" + + def test_stores_results(self) -> None: + """Result should store contexts and metadata.""" + result = MemoryFetchResult( + contexts=[], + by_type={"working": 0, "episodic": 5, "semantic": 3, "procedural": 0}, + fetch_time_ms=15.5, + query="test query", + ) + + assert result.contexts == [] + assert result.by_type["episodic"] == 5 + assert result.fetch_time_ms == 15.5 + assert result.query == "test query" + + +class TestMemoryContextSource: + """Tests for MemoryContextSource service.""" + + async def test_fetch_context_empty_when_no_sources( + self, + context_source: MemoryContextSource, + ) -> None: + """fetch_context should return empty when all sources fail.""" + config = MemoryFetchConfig( + include_working=False, + include_episodic=False, + include_semantic=False, + include_procedural=False, + ) + + result = await context_source.fetch_context( + query="test", + project_id=uuid4(), + config=config, + ) + + assert len(result.contexts) == 0 + assert result.by_type == { + "working": 0, + "episodic": 0, + "semantic": 0, + "procedural": 0, + } + + 
@patch("app.services.memory.integration.context_source.WorkingMemory") + async def test_fetch_working_memory( + self, + mock_working_cls: MagicMock, + context_source: MemoryContextSource, + ) -> None: + """Should fetch working memory when session_id provided.""" + # Setup mock - both keys should match the query "task" + mock_working = AsyncMock() + mock_working.list_keys = AsyncMock(return_value=["current_task", "task_state"]) + mock_working.get = AsyncMock(side_effect=lambda k: {"key": k, "value": "test"}) + mock_working_cls.for_session = AsyncMock(return_value=mock_working) + + config = MemoryFetchConfig( + include_episodic=False, + include_semantic=False, + include_procedural=False, + ) + + result = await context_source.fetch_context( + query="task", # Both keys contain "task" + project_id=uuid4(), + session_id="sess-123", + config=config, + ) + + assert result.by_type["working"] == 2 + assert all( + c.memory_subtype == MemorySubtype.WORKING for c in result.contexts + ) + + @patch("app.services.memory.integration.context_source.EpisodicMemory") + async def test_fetch_episodic_memory( + self, + mock_episodic_cls: MagicMock, + context_source: MemoryContextSource, + ) -> None: + """Should fetch episodic memory.""" + # Setup mock episode + mock_episode = MagicMock() + mock_episode.id = uuid4() + mock_episode.task_description = "Completed login feature" + mock_episode.task_type = "feature" + mock_episode.outcome = MagicMock(value="success") + mock_episode.importance_score = 0.8 + mock_episode.occurred_at = datetime.now(UTC) + mock_episode.lessons_learned = [] + + mock_episodic = AsyncMock() + mock_episodic.search_similar = AsyncMock(return_value=[mock_episode]) + mock_episodic.get_recent = AsyncMock(return_value=[]) + mock_episodic_cls.create = AsyncMock(return_value=mock_episodic) + + config = MemoryFetchConfig( + include_working=False, + include_semantic=False, + include_procedural=False, + ) + + result = await context_source.fetch_context( + query="login", + project_id=uuid4(), + config=config, + ) + + assert result.by_type["episodic"] == 1 + assert result.contexts[0].memory_subtype == MemorySubtype.EPISODIC + assert "Completed login feature" in result.contexts[0].content + + @patch("app.services.memory.integration.context_source.SemanticMemory") + async def test_fetch_semantic_memory( + self, + mock_semantic_cls: MagicMock, + context_source: MemoryContextSource, + ) -> None: + """Should fetch semantic memory.""" + # Setup mock fact + mock_fact = MagicMock() + mock_fact.id = uuid4() + mock_fact.subject = "User" + mock_fact.predicate = "prefers" + mock_fact.object = "dark mode" + mock_fact.confidence = 0.9 + + mock_semantic = AsyncMock() + mock_semantic.search_facts = AsyncMock(return_value=[mock_fact]) + mock_semantic_cls.create = AsyncMock(return_value=mock_semantic) + + config = MemoryFetchConfig( + include_working=False, + include_episodic=False, + include_procedural=False, + ) + + result = await context_source.fetch_context( + query="preferences", + project_id=uuid4(), + config=config, + ) + + assert result.by_type["semantic"] == 1 + assert result.contexts[0].memory_subtype == MemorySubtype.SEMANTIC + assert "User prefers dark mode" in result.contexts[0].content + + @patch("app.services.memory.integration.context_source.ProceduralMemory") + async def test_fetch_procedural_memory( + self, + mock_procedural_cls: MagicMock, + context_source: MemoryContextSource, + ) -> None: + """Should fetch procedural memory.""" + # Setup mock procedure + mock_proc = MagicMock() + mock_proc.id = 
uuid4() + mock_proc.name = "Deploy" + mock_proc.trigger_pattern = "When deploying" + mock_proc.steps = [{"action": "build"}, {"action": "test"}] + mock_proc.success_rate = 0.9 + mock_proc.success_count = 9 + mock_proc.failure_count = 1 + + mock_procedural = AsyncMock() + mock_procedural.find_matching = AsyncMock(return_value=[mock_proc]) + mock_procedural_cls.create = AsyncMock(return_value=mock_procedural) + + config = MemoryFetchConfig( + include_working=False, + include_episodic=False, + include_semantic=False, + ) + + result = await context_source.fetch_context( + query="deploy", + project_id=uuid4(), + config=config, + ) + + assert result.by_type["procedural"] == 1 + assert result.contexts[0].memory_subtype == MemorySubtype.PROCEDURAL + assert "Deploy" in result.contexts[0].content + + async def test_results_sorted_by_relevance( + self, + context_source: MemoryContextSource, + ) -> None: + """Results should be sorted by relevance score.""" + with patch.object( + context_source, "_fetch_episodic" + ) as mock_ep, patch.object( + context_source, "_fetch_semantic" + ) as mock_sem: + # Create contexts with different relevance scores + from app.services.context.types.memory import MemoryContext + + ctx_low = MemoryContext( + content="low relevance", + source="test", + relevance_score=0.3, + ) + ctx_high = MemoryContext( + content="high relevance", + source="test", + relevance_score=0.9, + ) + + mock_ep.return_value = [ctx_low] + mock_sem.return_value = [ctx_high] + + config = MemoryFetchConfig( + include_working=False, + include_procedural=False, + ) + + result = await context_source.fetch_context( + query="test", + project_id=uuid4(), + config=config, + ) + + # Higher relevance should come first + assert result.contexts[0].relevance_score == 0.9 + assert result.contexts[1].relevance_score == 0.3 + + @patch("app.services.memory.integration.context_source.WorkingMemory") + async def test_fetch_all_working( + self, + mock_working_cls: MagicMock, + context_source: MemoryContextSource, + ) -> None: + """fetch_all_working should return all working memory items.""" + mock_working = AsyncMock() + mock_working.list_keys = AsyncMock(return_value=["key1", "key2", "key3"]) + mock_working.get = AsyncMock(return_value="value") + mock_working_cls.for_session = AsyncMock(return_value=mock_working) + + contexts = await context_source.fetch_all_working( + session_id="sess-123", + project_id=uuid4(), + ) + + assert len(contexts) == 3 + assert all(c.memory_subtype == MemorySubtype.WORKING for c in contexts) + + +class TestGetMemoryContextSource: + """Tests for factory function.""" + + async def test_creates_instance(self) -> None: + """Factory should create MemoryContextSource instance.""" + mock_session = MagicMock() + + source = await get_memory_context_source(mock_session) + + assert isinstance(source, MemoryContextSource) diff --git a/backend/tests/unit/services/memory/integration/test_lifecycle.py b/backend/tests/unit/services/memory/integration/test_lifecycle.py new file mode 100644 index 0000000..a39c108 --- /dev/null +++ b/backend/tests/unit/services/memory/integration/test_lifecycle.py @@ -0,0 +1,471 @@ +# tests/unit/services/memory/integration/test_lifecycle.py +"""Tests for Agent Lifecycle Hooks.""" + +from datetime import UTC, datetime +from unittest.mock import AsyncMock, MagicMock, patch +from uuid import uuid4 + +import pytest + +from app.services.memory.integration.lifecycle import ( + AgentLifecycleManager, + LifecycleEvent, + LifecycleHooks, + LifecycleResult, + get_lifecycle_manager, +) 
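+# NOTE: Outcome describes how an agent run ended; the terminate() tests below
+# pass it when a session's working memory is consolidated into episodic memory.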
+from app.services.memory.types import Outcome + +pytestmark = pytest.mark.asyncio(loop_scope="function") + + +@pytest.fixture +def mock_session() -> MagicMock: + """Create mock database session.""" + return MagicMock() + + +@pytest.fixture +def lifecycle_hooks() -> LifecycleHooks: + """Create lifecycle hooks instance.""" + return LifecycleHooks() + + +@pytest.fixture +def lifecycle_manager(mock_session: MagicMock) -> AgentLifecycleManager: + """Create lifecycle manager instance.""" + return AgentLifecycleManager(session=mock_session) + + +class TestLifecycleEvent: + """Tests for LifecycleEvent dataclass.""" + + def test_creates_event(self) -> None: + """Should create event with required fields.""" + project_id = uuid4() + agent_id = uuid4() + + event = LifecycleEvent( + event_type="spawn", + project_id=project_id, + agent_instance_id=agent_id, + ) + + assert event.event_type == "spawn" + assert event.project_id == project_id + assert event.agent_instance_id == agent_id + assert event.timestamp is not None + assert event.metadata == {} + + def test_with_optional_fields(self) -> None: + """Should include optional fields.""" + event = LifecycleEvent( + event_type="terminate", + project_id=uuid4(), + agent_instance_id=uuid4(), + session_id="sess-123", + metadata={"reason": "completed"}, + ) + + assert event.session_id == "sess-123" + assert event.metadata["reason"] == "completed" + + +class TestLifecycleResult: + """Tests for LifecycleResult dataclass.""" + + def test_success_result(self) -> None: + """Should create success result.""" + result = LifecycleResult( + success=True, + event_type="spawn", + message="Agent spawned", + data={"session_id": "sess-123"}, + duration_ms=10.5, + ) + + assert result.success is True + assert result.event_type == "spawn" + assert result.data["session_id"] == "sess-123" + + def test_failure_result(self) -> None: + """Should create failure result.""" + result = LifecycleResult( + success=False, + event_type="resume", + message="Checkpoint not found", + ) + + assert result.success is False + assert result.message == "Checkpoint not found" + + +class TestLifecycleHooks: + """Tests for LifecycleHooks class.""" + + def test_register_spawn_hook(self, lifecycle_hooks: LifecycleHooks) -> None: + """Should register spawn hook.""" + async def my_hook(event: LifecycleEvent) -> None: + pass + + result = lifecycle_hooks.on_spawn(my_hook) + + assert result is my_hook + assert my_hook in lifecycle_hooks._spawn_hooks + + def test_register_all_hooks(self, lifecycle_hooks: LifecycleHooks) -> None: + """Should register hooks for all event types.""" + hooks = [ + lifecycle_hooks.on_spawn(AsyncMock()), + lifecycle_hooks.on_pause(AsyncMock()), + lifecycle_hooks.on_resume(AsyncMock()), + lifecycle_hooks.on_terminate(AsyncMock()), + ] + + assert len(lifecycle_hooks._spawn_hooks) == 1 + assert len(lifecycle_hooks._pause_hooks) == 1 + assert len(lifecycle_hooks._resume_hooks) == 1 + assert len(lifecycle_hooks._terminate_hooks) == 1 + + async def test_run_spawn_hooks(self, lifecycle_hooks: LifecycleHooks) -> None: + """Should run all spawn hooks.""" + hook1 = AsyncMock() + hook2 = AsyncMock() + lifecycle_hooks.on_spawn(hook1) + lifecycle_hooks.on_spawn(hook2) + + event = LifecycleEvent( + event_type="spawn", + project_id=uuid4(), + agent_instance_id=uuid4(), + ) + + await lifecycle_hooks.run_spawn_hooks(event) + + hook1.assert_called_once_with(event) + hook2.assert_called_once_with(event) + + async def test_hook_failure_doesnt_stop_others( + self, lifecycle_hooks: LifecycleHooks + ) -> 
None: + """Hook failure should not stop other hooks from running.""" + hook1 = AsyncMock(side_effect=ValueError("Oops")) + hook2 = AsyncMock() + lifecycle_hooks.on_pause(hook1) + lifecycle_hooks.on_pause(hook2) + + event = LifecycleEvent( + event_type="pause", + project_id=uuid4(), + agent_instance_id=uuid4(), + ) + + await lifecycle_hooks.run_pause_hooks(event) + + # hook2 should still be called even though hook1 failed + hook2.assert_called_once() + + +class TestAgentLifecycleManagerSpawn: + """Tests for AgentLifecycleManager.spawn.""" + + @patch("app.services.memory.integration.lifecycle.WorkingMemory") + async def test_spawn_creates_working_memory( + self, + mock_working_cls: MagicMock, + lifecycle_manager: AgentLifecycleManager, + ) -> None: + """Spawn should create working memory for session.""" + mock_working = AsyncMock() + mock_working.set = AsyncMock() + mock_working_cls.for_session = AsyncMock(return_value=mock_working) + + result = await lifecycle_manager.spawn( + project_id=uuid4(), + agent_instance_id=uuid4(), + session_id="sess-123", + ) + + assert result.success is True + assert result.event_type == "spawn" + mock_working_cls.for_session.assert_called_once() + + @patch("app.services.memory.integration.lifecycle.WorkingMemory") + async def test_spawn_with_initial_state( + self, + mock_working_cls: MagicMock, + lifecycle_manager: AgentLifecycleManager, + ) -> None: + """Spawn should populate initial state.""" + mock_working = AsyncMock() + mock_working.set = AsyncMock() + mock_working_cls.for_session = AsyncMock(return_value=mock_working) + + result = await lifecycle_manager.spawn( + project_id=uuid4(), + agent_instance_id=uuid4(), + session_id="sess-123", + initial_state={"key1": "value1", "key2": "value2"}, + ) + + assert result.success is True + assert result.data["initial_items"] == 2 + assert mock_working.set.call_count == 2 + + @patch("app.services.memory.integration.lifecycle.WorkingMemory") + async def test_spawn_runs_hooks( + self, + mock_working_cls: MagicMock, + lifecycle_manager: AgentLifecycleManager, + ) -> None: + """Spawn should run registered hooks.""" + mock_working = AsyncMock() + mock_working.set = AsyncMock() + mock_working_cls.for_session = AsyncMock(return_value=mock_working) + + hook = AsyncMock() + lifecycle_manager.hooks.on_spawn(hook) + + await lifecycle_manager.spawn( + project_id=uuid4(), + agent_instance_id=uuid4(), + session_id="sess-123", + ) + + hook.assert_called_once() + + +class TestAgentLifecycleManagerPause: + """Tests for AgentLifecycleManager.pause.""" + + @patch("app.services.memory.integration.lifecycle.WorkingMemory") + async def test_pause_creates_checkpoint( + self, + mock_working_cls: MagicMock, + lifecycle_manager: AgentLifecycleManager, + ) -> None: + """Pause should create checkpoint of working memory.""" + mock_working = AsyncMock() + mock_working.list_keys = AsyncMock(return_value=["key1", "key2"]) + mock_working.get = AsyncMock(return_value={"data": "test"}) + mock_working.set = AsyncMock() + mock_working_cls.for_session = AsyncMock(return_value=mock_working) + + result = await lifecycle_manager.pause( + project_id=uuid4(), + agent_instance_id=uuid4(), + session_id="sess-123", + checkpoint_id="ckpt-001", + ) + + assert result.success is True + assert result.event_type == "pause" + assert result.data["checkpoint_id"] == "ckpt-001" + assert result.data["items_saved"] == 2 + + # Should save checkpoint with state + mock_working.set.assert_called_once() + call_args = mock_working.set.call_args + # Check positional arg (first arg 
is key) + assert "__checkpoint__ckpt-001" in call_args[0][0] + + @patch("app.services.memory.integration.lifecycle.WorkingMemory") + async def test_pause_generates_checkpoint_id( + self, + mock_working_cls: MagicMock, + lifecycle_manager: AgentLifecycleManager, + ) -> None: + """Pause should generate checkpoint ID if not provided.""" + mock_working = AsyncMock() + mock_working.list_keys = AsyncMock(return_value=[]) + mock_working.set = AsyncMock() + mock_working_cls.for_session = AsyncMock(return_value=mock_working) + + result = await lifecycle_manager.pause( + project_id=uuid4(), + agent_instance_id=uuid4(), + session_id="sess-123", + ) + + assert result.success is True + assert "checkpoint_id" in result.data + assert result.data["checkpoint_id"].startswith("checkpoint_") + + +class TestAgentLifecycleManagerResume: + """Tests for AgentLifecycleManager.resume.""" + + @patch("app.services.memory.integration.lifecycle.WorkingMemory") + async def test_resume_restores_checkpoint( + self, + mock_working_cls: MagicMock, + lifecycle_manager: AgentLifecycleManager, + ) -> None: + """Resume should restore working memory from checkpoint.""" + checkpoint_data = { + "state": {"key1": "value1", "key2": "value2"}, + "timestamp": datetime.now(UTC).isoformat(), + "keys_count": 2, + } + + mock_working = AsyncMock() + mock_working.list_keys = AsyncMock(return_value=[]) + mock_working.get = AsyncMock(return_value=checkpoint_data) + mock_working.set = AsyncMock() + mock_working.delete = AsyncMock() + mock_working_cls.for_session = AsyncMock(return_value=mock_working) + + result = await lifecycle_manager.resume( + project_id=uuid4(), + agent_instance_id=uuid4(), + session_id="sess-123", + checkpoint_id="ckpt-001", + ) + + assert result.success is True + assert result.event_type == "resume" + assert result.data["items_restored"] == 2 + + @patch("app.services.memory.integration.lifecycle.WorkingMemory") + async def test_resume_checkpoint_not_found( + self, + mock_working_cls: MagicMock, + lifecycle_manager: AgentLifecycleManager, + ) -> None: + """Resume should fail if checkpoint not found.""" + mock_working = AsyncMock() + mock_working.get = AsyncMock(return_value=None) + mock_working_cls.for_session = AsyncMock(return_value=mock_working) + + result = await lifecycle_manager.resume( + project_id=uuid4(), + agent_instance_id=uuid4(), + session_id="sess-123", + checkpoint_id="nonexistent", + ) + + assert result.success is False + assert "not found" in result.message.lower() + + +class TestAgentLifecycleManagerTerminate: + """Tests for AgentLifecycleManager.terminate.""" + + @patch("app.services.memory.integration.lifecycle.EpisodicMemory") + @patch("app.services.memory.integration.lifecycle.WorkingMemory") + async def test_terminate_consolidates_to_episodic( + self, + mock_working_cls: MagicMock, + mock_episodic_cls: MagicMock, + lifecycle_manager: AgentLifecycleManager, + ) -> None: + """Terminate should consolidate working memory to episodic.""" + mock_working = AsyncMock() + mock_working.list_keys = AsyncMock(return_value=["key1", "key2"]) + mock_working.get = AsyncMock(return_value="value") + mock_working.delete = AsyncMock() + mock_working_cls.for_session = AsyncMock(return_value=mock_working) + + mock_episode = MagicMock() + mock_episode.id = uuid4() + + mock_episodic = AsyncMock() + mock_episodic.record_episode = AsyncMock(return_value=mock_episode) + mock_episodic_cls.create = AsyncMock(return_value=mock_episodic) + + result = await lifecycle_manager.terminate( + project_id=uuid4(), + 
agent_instance_id=uuid4(), + session_id="sess-123", + task_description="Completed task", + outcome=Outcome.SUCCESS, + ) + + assert result.success is True + assert result.event_type == "terminate" + assert result.data["episode_id"] == str(mock_episode.id) + assert result.data["state_items_consolidated"] == 2 + mock_episodic.record_episode.assert_called_once() + + @patch("app.services.memory.integration.lifecycle.WorkingMemory") + async def test_terminate_cleans_up_working( + self, + mock_working_cls: MagicMock, + lifecycle_manager: AgentLifecycleManager, + ) -> None: + """Terminate should clean up working memory.""" + mock_working = AsyncMock() + mock_working.list_keys = AsyncMock(return_value=["key1", "key2"]) + mock_working.get = AsyncMock(return_value="value") + mock_working.delete = AsyncMock() + mock_working_cls.for_session = AsyncMock(return_value=mock_working) + + result = await lifecycle_manager.terminate( + project_id=uuid4(), + agent_instance_id=uuid4(), + session_id="sess-123", + consolidate_to_episodic=False, + cleanup_working=True, + ) + + assert result.success is True + assert result.data["items_cleared"] == 2 + assert mock_working.delete.call_count == 2 + + +class TestAgentLifecycleManagerListCheckpoints: + """Tests for AgentLifecycleManager.list_checkpoints.""" + + @patch("app.services.memory.integration.lifecycle.WorkingMemory") + async def test_list_checkpoints( + self, + mock_working_cls: MagicMock, + lifecycle_manager: AgentLifecycleManager, + ) -> None: + """Should list available checkpoints.""" + mock_working = AsyncMock() + mock_working.list_keys = AsyncMock( + return_value=[ + "__checkpoint__ckpt-001", + "__checkpoint__ckpt-002", + "regular_key", + ] + ) + mock_working.get = AsyncMock( + return_value={ + "timestamp": "2024-01-01T00:00:00Z", + "keys_count": 5, + } + ) + mock_working_cls.for_session = AsyncMock(return_value=mock_working) + + checkpoints = await lifecycle_manager.list_checkpoints( + project_id=uuid4(), + agent_instance_id=uuid4(), + session_id="sess-123", + ) + + assert len(checkpoints) == 2 + assert checkpoints[0]["checkpoint_id"] == "ckpt-001" + assert checkpoints[0]["keys_count"] == 5 + + +class TestGetLifecycleManager: + """Tests for factory function.""" + + async def test_creates_instance(self) -> None: + """Factory should create AgentLifecycleManager instance.""" + mock_session = MagicMock() + + manager = await get_lifecycle_manager(mock_session) + + assert isinstance(manager, AgentLifecycleManager) + + async def test_with_custom_hooks(self) -> None: + """Factory should accept custom hooks.""" + mock_session = MagicMock() + custom_hooks = LifecycleHooks() + + manager = await get_lifecycle_manager(mock_session, hooks=custom_hooks) + + assert manager.hooks is custom_hooks
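+
+
+# Illustrative end-to-end flow exercised by the tests above (a sketch for
+# orientation only, not an executed test; IDs and literal strings are
+# placeholders):
+#
+#   manager = await get_lifecycle_manager(session)
+#   await manager.spawn(project_id=pid, agent_instance_id=aid,
+#                       session_id="sess-123", initial_state={"key1": "value1"})
+#   paused = await manager.pause(project_id=pid, agent_instance_id=aid,
+#                                session_id="sess-123")
+#   await manager.resume(project_id=pid, agent_instance_id=aid,
+#                        session_id="sess-123",
+#                        checkpoint_id=paused.data["checkpoint_id"])
+#   await manager.terminate(project_id=pid, agent_instance_id=aid,
+#                           session_id="sess-123",
+#                           task_description="Completed task",
+#                           outcome=Outcome.SUCCESS)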