feat(memory): integrate memory system with context engine (#97)
## Changes ### New Context Type - Add MEMORY to ContextType enum for agent memory context - Create MemoryContext class with subtypes (working, episodic, semantic, procedural) - Factory methods: from_working_memory, from_episodic_memory, from_semantic_memory, from_procedural_memory ### Memory Context Source - MemoryContextSource service fetches relevant memories for context assembly - Configurable fetch limits per memory type - Parallel fetching from all memory types ### Agent Lifecycle Hooks - AgentLifecycleManager handles spawn, pause, resume, terminate events - spawn: Initialize working memory with optional initial state - pause: Create checkpoint of working memory - resume: Restore from checkpoint - terminate: Consolidate working memory to episodic memory - LifecycleHooks for custom extension points ### Context Engine Integration - Add memory_query parameter to assemble_context() - Add session_id and agent_type_id for memory scoping - Memory budget allocation (15% by default) - set_memory_source() for runtime configuration ### Tests - 48 new tests for MemoryContext, MemoryContextSource, and lifecycle hooks - All 108 memory-related tests passing - mypy and ruff checks passing 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
402
backend/app/services/memory/integration/context_source.py
Normal file
402
backend/app/services/memory/integration/context_source.py
Normal file
@@ -0,0 +1,402 @@
|
||||
# app/services/memory/integration/context_source.py
|
||||
"""
|
||||
Memory Context Source.
|
||||
|
||||
Provides agent memory as a context source for the Context Engine.
|
||||
Retrieves relevant memories based on query and converts them to MemoryContext objects.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.services.context.types.memory import MemoryContext
|
||||
from app.services.memory.episodic import EpisodicMemory
|
||||
from app.services.memory.procedural import ProceduralMemory
|
||||
from app.services.memory.semantic import SemanticMemory
|
||||
from app.services.memory.working import WorkingMemory
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
class MemoryFetchConfig:
    """Configuration for memory fetching.

    Controls how many items are retrieved from each memory type, the
    recency/relevance cutoffs, and which memory types participate in
    a fetch.
    """

    # Maximum number of items fetched per memory type.
    working_limit: int = 10
    episodic_limit: int = 10
    semantic_limit: int = 15
    procedural_limit: int = 5

    # Recency window (days) for the episodic "recent episodes" fallback
    # fetch, and the minimum confidence passed to the semantic fact
    # search (fetch_context forwards min_relevance as min_confidence).
    episodic_days_back: int = 30
    min_relevance: float = 0.3

    # Toggles controlling which memory types are queried at all.
    include_working: bool = True
    include_episodic: bool = True
    include_semantic: bool = True
    include_procedural: bool = True
|
||||
|
||||
|
||||
@dataclass
class MemoryFetchResult:
    """Result of memory fetch operation."""

    # Retrieved memories, sorted by descending relevance_score.
    contexts: list[MemoryContext]
    # Number of contexts contributed by each memory type, keyed by
    # "working" / "episodic" / "semantic" / "procedural".
    by_type: dict[str, int]
    # Wall-clock duration of the whole fetch, in milliseconds.
    fetch_time_ms: float
    # The search query the fetch was performed with.
    query: str
||||
|
||||
|
||||
class MemoryContextSource:
    """
    Source for memory context in the Context Engine.

    This service retrieves relevant memories based on a query and
    converts them to MemoryContext objects for context assembly.
    It coordinates between all memory types (working, episodic,
    semantic, procedural) to provide a comprehensive memory context.

    Each memory type is queried best-effort: a failure in one type is
    logged and does not prevent the other types from contributing.
    """

    def __init__(
        self,
        session: AsyncSession,
        embedding_generator: Any | None = None,
    ) -> None:
        """
        Initialize the memory context source.

        Args:
            session: Database session
            embedding_generator: Optional embedding generator for semantic search
        """
        self._session = session
        self._embedding_generator = embedding_generator

        # Memory services are created lazily on first use so that
        # constructing the source stays cheap.
        self._episodic: EpisodicMemory | None = None
        self._semantic: SemanticMemory | None = None
        self._procedural: ProceduralMemory | None = None

    async def _get_episodic(self) -> EpisodicMemory:
        """Get or create the episodic memory service."""
        if self._episodic is None:
            self._episodic = await EpisodicMemory.create(
                self._session,
                self._embedding_generator,
            )
        return self._episodic

    async def _get_semantic(self) -> SemanticMemory:
        """Get or create the semantic memory service."""
        if self._semantic is None:
            self._semantic = await SemanticMemory.create(
                self._session,
                self._embedding_generator,
            )
        return self._semantic

    async def _get_procedural(self) -> ProceduralMemory:
        """Get or create the procedural memory service."""
        if self._procedural is None:
            self._procedural = await ProceduralMemory.create(
                self._session,
                self._embedding_generator,
            )
        return self._procedural

    async def fetch_context(
        self,
        query: str,
        project_id: UUID,
        agent_instance_id: UUID | None = None,
        agent_type_id: UUID | None = None,
        session_id: str | None = None,
        config: MemoryFetchConfig | None = None,
    ) -> MemoryFetchResult:
        """
        Fetch relevant memories as context.

        This is the main entry point for the Context Engine integration.
        It searches across all memory types and returns relevant memories
        as MemoryContext objects, sorted by descending relevance score.

        Args:
            query: Search query for finding relevant memories
            project_id: Project scope
            agent_instance_id: Optional agent instance scope
            agent_type_id: Optional agent type scope (for procedural)
            session_id: Optional session ID (for working memory)
            config: Optional fetch configuration

        Returns:
            MemoryFetchResult with contexts and metadata
        """
        config = config or MemoryFetchConfig()
        start_time = datetime.now(UTC)

        contexts: list[MemoryContext] = []
        by_type: dict[str, int] = {
            "working": 0,
            "episodic": 0,
            "semantic": 0,
            "procedural": 0,
        }

        # Working memory is session-scoped, so it is only consulted
        # when a session_id was provided.
        if config.include_working and session_id:
            try:
                working_contexts = await self._fetch_working(
                    query=query,
                    session_id=session_id,
                    project_id=project_id,
                    agent_instance_id=agent_instance_id,
                    limit=config.working_limit,
                )
                contexts.extend(working_contexts)
                by_type["working"] = len(working_contexts)
            except Exception as e:
                # Best-effort: log and keep assembling from other types.
                logger.warning("Failed to fetch working memory: %s", e)

        if config.include_episodic:
            try:
                episodic_contexts = await self._fetch_episodic(
                    query=query,
                    project_id=project_id,
                    agent_instance_id=agent_instance_id,
                    limit=config.episodic_limit,
                    days_back=config.episodic_days_back,
                )
                contexts.extend(episodic_contexts)
                by_type["episodic"] = len(episodic_contexts)
            except Exception as e:
                logger.warning("Failed to fetch episodic memory: %s", e)

        if config.include_semantic:
            try:
                semantic_contexts = await self._fetch_semantic(
                    query=query,
                    project_id=project_id,
                    limit=config.semantic_limit,
                    min_relevance=config.min_relevance,
                )
                contexts.extend(semantic_contexts)
                by_type["semantic"] = len(semantic_contexts)
            except Exception as e:
                logger.warning("Failed to fetch semantic memory: %s", e)

        if config.include_procedural:
            try:
                procedural_contexts = await self._fetch_procedural(
                    query=query,
                    project_id=project_id,
                    agent_type_id=agent_type_id,
                    limit=config.procedural_limit,
                )
                contexts.extend(procedural_contexts)
                by_type["procedural"] = len(procedural_contexts)
            except Exception as e:
                logger.warning("Failed to fetch procedural memory: %s", e)

        # Highest-relevance memories first, regardless of source type.
        contexts.sort(key=lambda c: c.relevance_score, reverse=True)

        fetch_time = (datetime.now(UTC) - start_time).total_seconds() * 1000

        # Lazy %-args: the message is only formatted if DEBUG is enabled.
        logger.debug(
            "Fetched %d memory contexts for query '%s...' in %.1fms",
            len(contexts),
            query[:50],
            fetch_time,
        )

        return MemoryFetchResult(
            contexts=contexts,
            by_type=by_type,
            fetch_time_ms=fetch_time,
            query=query,
        )

    async def _collect_working(
        self,
        working: WorkingMemory,
        keys: list[str],
        session_id: str,
        query: str | None = None,
    ) -> list[MemoryContext]:
        """Load *keys* from working memory and wrap non-None values.

        Keys whose stored value is None are skipped. The ``query`` kwarg
        is forwarded to MemoryContext.from_working_memory only when one
        was given, preserving the factory's default behavior otherwise.
        """
        contexts: list[MemoryContext] = []
        for key in keys:
            value = await working.get(key)
            if value is None:
                continue
            kwargs: dict[str, Any] = {
                "key": key,
                "value": value,
                "source": f"working:{session_id}",
            }
            if query is not None:
                kwargs["query"] = query
            contexts.append(MemoryContext.from_working_memory(**kwargs))
        return contexts

    async def _fetch_working(
        self,
        query: str,
        session_id: str,
        project_id: UUID,
        agent_instance_id: UUID | None,
        limit: int,
    ) -> list[MemoryContext]:
        """Fetch up to *limit* working-memory items matching *query*."""
        working = await WorkingMemory.for_session(
            session_id=session_id,
            project_id=str(project_id),
            agent_instance_id=str(agent_instance_id) if agent_instance_id else None,
        )

        all_keys = await working.list_keys()

        # Cheap relevance filter: case-insensitive substring match on keys.
        query_lower = query.lower()
        matched_keys = [k for k in all_keys if query_lower in k.lower()]

        # If a non-empty query matched nothing, fall back to all keys:
        # working memory is the live session state and always relevant.
        if not matched_keys and query:
            matched_keys = all_keys

        return await self._collect_working(
            working,
            matched_keys[:limit],
            session_id=session_id,
            query=query,
        )

    async def _fetch_episodic(
        self,
        query: str,
        project_id: UUID,
        agent_instance_id: UUID | None,
        limit: int,
        days_back: int,
    ) -> list[MemoryContext]:
        """Fetch episodes similar to *query*, topping up with recent ones."""
        episodic = await self._get_episodic()

        # Primary: similarity search for relevant episodes.
        episodes = await episodic.search_similar(
            project_id=project_id,
            query=query,
            limit=limit,
            agent_instance_id=agent_instance_id,
        )

        # Fallback: if similarity found fewer than half the requested
        # episodes, top up with recent episodes from the last *days_back*
        # days, deduplicated against the similarity results by ID.
        if len(episodes) < limit // 2:
            since = datetime.now(UTC) - timedelta(days=days_back)
            recent = await episodic.get_recent(
                project_id=project_id,
                limit=limit,
                since=since,
            )
            existing_ids = {e.id for e in episodes}
            for ep in recent:
                if ep.id not in existing_ids:
                    episodes.append(ep)
                    if len(episodes) >= limit:
                        break

        return [
            MemoryContext.from_episodic_memory(ep, query=query)
            for ep in episodes[:limit]
        ]

    async def _fetch_semantic(
        self,
        query: str,
        project_id: UUID,
        limit: int,
        min_relevance: float,
    ) -> list[MemoryContext]:
        """Fetch semantic facts for *query*.

        The engine's min_relevance threshold is applied as the facts'
        minimum confidence.
        """
        semantic = await self._get_semantic()

        facts = await semantic.search_facts(
            query=query,
            project_id=project_id,
            limit=limit,
            min_confidence=min_relevance,
        )

        return [
            MemoryContext.from_semantic_memory(fact, query=query)
            for fact in facts
        ]

    async def _fetch_procedural(
        self,
        query: str,
        project_id: UUID,
        agent_type_id: UUID | None,
        limit: int,
    ) -> list[MemoryContext]:
        """Fetch procedures whose trigger context matches *query*."""
        procedural = await self._get_procedural()

        procedures = await procedural.find_matching(
            context=query,
            project_id=project_id,
            agent_type_id=agent_type_id,
            limit=limit,
        )

        return [
            MemoryContext.from_procedural_memory(proc, query=query)
            for proc in procedures
        ]

    async def fetch_all_working(
        self,
        session_id: str,
        project_id: UUID,
        agent_instance_id: UUID | None = None,
    ) -> list[MemoryContext]:
        """
        Fetch all working memory for a session.

        Useful for including entire session state in context.

        Args:
            session_id: Session ID
            project_id: Project scope
            agent_instance_id: Optional agent instance scope

        Returns:
            List of MemoryContext for all working memory items
        """
        working = await WorkingMemory.for_session(
            session_id=session_id,
            project_id=str(project_id),
            agent_instance_id=str(agent_instance_id) if agent_instance_id else None,
        )
        all_keys = await working.list_keys()
        # No query filter: the whole session state is included.
        return await self._collect_working(working, all_keys, session_id=session_id)
|
||||
|
||||
|
||||
# Factory for dependency-injection-style construction.
async def get_memory_context_source(
    session: AsyncSession,
    embedding_generator: Any | None = None,
) -> MemoryContextSource:
    """Build and return a :class:`MemoryContextSource` bound to *session*."""
    source = MemoryContextSource(
        session=session,
        embedding_generator=embedding_generator,
    )
    return source
|
||||
Reference in New Issue
Block a user