# app/services/memory/indexing/index.py """ Memory Indexing. Provides multiple indexing strategies for efficient memory retrieval: - Vector embeddings for semantic search - Temporal index for time-based queries - Entity index for entity-based lookups - Outcome index for success/failure filtering """ import logging from abc import ABC, abstractmethod from dataclasses import dataclass, field from datetime import UTC, datetime, timedelta from typing import Any, TypeVar from uuid import UUID from app.services.memory.types import Episode, Fact, MemoryType, Outcome, Procedure logger = logging.getLogger(__name__) T = TypeVar("T", Episode, Fact, Procedure) def _utcnow() -> datetime: """Get current UTC time as timezone-aware datetime.""" return datetime.now(UTC) @dataclass class IndexEntry: """A single entry in an index.""" memory_id: UUID memory_type: MemoryType indexed_at: datetime = field(default_factory=_utcnow) metadata: dict[str, Any] = field(default_factory=dict) @dataclass class VectorIndexEntry(IndexEntry): """An entry with vector embedding.""" embedding: list[float] = field(default_factory=list) dimension: int = 0 def __post_init__(self) -> None: """Set dimension from embedding.""" if self.embedding: self.dimension = len(self.embedding) @dataclass class TemporalIndexEntry(IndexEntry): """An entry indexed by time.""" timestamp: datetime = field(default_factory=_utcnow) @dataclass class EntityIndexEntry(IndexEntry): """An entry indexed by entity.""" entity_type: str = "" entity_value: str = "" @dataclass class OutcomeIndexEntry(IndexEntry): """An entry indexed by outcome.""" outcome: Outcome = Outcome.SUCCESS class MemoryIndex[T](ABC): """Abstract base class for memory indices.""" @abstractmethod async def add(self, item: T) -> IndexEntry: """Add an item to the index.""" ... @abstractmethod async def remove(self, memory_id: UUID) -> bool: """Remove an item from the index.""" ... @abstractmethod async def search( self, query: Any, limit: int = 10, **kwargs: Any, ) -> list[IndexEntry]: """Search the index.""" ... @abstractmethod async def clear(self) -> int: """Clear all entries from the index.""" ... @abstractmethod async def count(self) -> int: """Get the number of entries in the index.""" ... class VectorIndex(MemoryIndex[T]): """ Vector-based index using embeddings for semantic similarity search. Uses cosine similarity for matching. """ def __init__(self, dimension: int = 1536) -> None: """ Initialize the vector index. Args: dimension: Embedding dimension (default 1536 for OpenAI) """ self._dimension = dimension self._entries: dict[UUID, VectorIndexEntry] = {} logger.info(f"Initialized VectorIndex with dimension={dimension}") async def add(self, item: T) -> VectorIndexEntry: """ Add an item to the vector index. Args: item: Memory item with embedding Returns: The created index entry """ embedding = getattr(item, "embedding", None) or [] entry = VectorIndexEntry( memory_id=item.id, memory_type=self._get_memory_type(item), embedding=embedding, dimension=len(embedding), ) self._entries[item.id] = entry logger.debug(f"Added {item.id} to vector index") return entry async def remove(self, memory_id: UUID) -> bool: """Remove an item from the vector index.""" if memory_id in self._entries: del self._entries[memory_id] logger.debug(f"Removed {memory_id} from vector index") return True return False async def search( # type: ignore[override] self, query: Any, limit: int = 10, min_similarity: float = 0.0, **kwargs: Any, ) -> list[VectorIndexEntry]: """ Search for similar items using vector similarity. Args: query: Query embedding vector limit: Maximum results to return min_similarity: Minimum similarity threshold (0-1) **kwargs: Additional filter parameters Returns: List of matching entries sorted by similarity """ if not isinstance(query, list) or not query: return [] results: list[tuple[float, VectorIndexEntry]] = [] for entry in self._entries.values(): if not entry.embedding: continue similarity = self._cosine_similarity(query, entry.embedding) if similarity >= min_similarity: results.append((similarity, entry)) # Sort by similarity descending results.sort(key=lambda x: x[0], reverse=True) # Apply memory type filter if provided memory_type = kwargs.get("memory_type") if memory_type: results = [(s, e) for s, e in results if e.memory_type == memory_type] # Store similarity in metadata for the returned entries # Use a copy of metadata to avoid mutating cached entries output = [] for similarity, entry in results[:limit]: # Create a shallow copy of the entry with updated metadata entry_with_score = VectorIndexEntry( memory_id=entry.memory_id, memory_type=entry.memory_type, embedding=entry.embedding, metadata={**entry.metadata, "similarity": similarity}, ) output.append(entry_with_score) logger.debug(f"Vector search returned {len(output)} results") return output async def clear(self) -> int: """Clear all entries from the index.""" count = len(self._entries) self._entries.clear() logger.info(f"Cleared {count} entries from vector index") return count async def count(self) -> int: """Get the number of entries in the index.""" return len(self._entries) def _cosine_similarity(self, a: list[float], b: list[float]) -> float: """Calculate cosine similarity between two vectors.""" if len(a) != len(b) or len(a) == 0: return 0.0 dot_product = sum(x * y for x, y in zip(a, b, strict=True)) norm_a = sum(x * x for x in a) ** 0.5 norm_b = sum(x * x for x in b) ** 0.5 if norm_a == 0 or norm_b == 0: return 0.0 return dot_product / (norm_a * norm_b) def _get_memory_type(self, item: T) -> MemoryType: """Get the memory type for an item.""" if isinstance(item, Episode): return MemoryType.EPISODIC elif isinstance(item, Fact): return MemoryType.SEMANTIC elif isinstance(item, Procedure): return MemoryType.PROCEDURAL return MemoryType.WORKING class TemporalIndex(MemoryIndex[T]): """ Time-based index for efficient temporal queries. Supports: - Range queries (between timestamps) - Recent items (within last N seconds/hours/days) - Oldest/newest sorting """ def __init__(self) -> None: """Initialize the temporal index.""" self._entries: dict[UUID, TemporalIndexEntry] = {} # Sorted list for efficient range queries self._sorted_entries: list[tuple[datetime, UUID]] = [] logger.info("Initialized TemporalIndex") async def add(self, item: T) -> TemporalIndexEntry: """ Add an item to the temporal index. Args: item: Memory item with timestamp Returns: The created index entry """ # Get timestamp from various possible fields timestamp = self._get_timestamp(item) entry = TemporalIndexEntry( memory_id=item.id, memory_type=self._get_memory_type(item), timestamp=timestamp, ) self._entries[item.id] = entry self._insert_sorted(timestamp, item.id) logger.debug(f"Added {item.id} to temporal index at {timestamp}") return entry async def remove(self, memory_id: UUID) -> bool: """Remove an item from the temporal index.""" if memory_id not in self._entries: return False self._entries.pop(memory_id) self._sorted_entries = [ (ts, mid) for ts, mid in self._sorted_entries if mid != memory_id ] logger.debug(f"Removed {memory_id} from temporal index") return True async def search( # type: ignore[override] self, query: Any, limit: int = 10, start_time: datetime | None = None, end_time: datetime | None = None, recent_seconds: float | None = None, order: str = "desc", **kwargs: Any, ) -> list[TemporalIndexEntry]: """ Search for items by time. Args: query: Ignored for temporal search limit: Maximum results to return start_time: Start of time range end_time: End of time range recent_seconds: Get items from last N seconds order: Sort order ("asc" or "desc") **kwargs: Additional filter parameters Returns: List of matching entries sorted by time """ if recent_seconds is not None: start_time = _utcnow() - timedelta(seconds=recent_seconds) end_time = _utcnow() # Filter by time range results: list[TemporalIndexEntry] = [] for entry in self._entries.values(): if start_time and entry.timestamp < start_time: continue if end_time and entry.timestamp > end_time: continue results.append(entry) # Apply memory type filter if provided memory_type = kwargs.get("memory_type") if memory_type: results = [e for e in results if e.memory_type == memory_type] # Sort by timestamp results.sort(key=lambda e: e.timestamp, reverse=(order == "desc")) logger.debug(f"Temporal search returned {min(len(results), limit)} results") return results[:limit] async def clear(self) -> int: """Clear all entries from the index.""" count = len(self._entries) self._entries.clear() self._sorted_entries.clear() logger.info(f"Cleared {count} entries from temporal index") return count async def count(self) -> int: """Get the number of entries in the index.""" return len(self._entries) def _insert_sorted(self, timestamp: datetime, memory_id: UUID) -> None: """Insert entry maintaining sorted order.""" # Binary search insert for efficiency low, high = 0, len(self._sorted_entries) while low < high: mid = (low + high) // 2 if self._sorted_entries[mid][0] < timestamp: low = mid + 1 else: high = mid self._sorted_entries.insert(low, (timestamp, memory_id)) def _get_timestamp(self, item: T) -> datetime: """Get the relevant timestamp for an item.""" if hasattr(item, "occurred_at"): return item.occurred_at if hasattr(item, "first_learned"): return item.first_learned if hasattr(item, "last_used") and item.last_used: return item.last_used if hasattr(item, "created_at"): return item.created_at return _utcnow() def _get_memory_type(self, item: T) -> MemoryType: """Get the memory type for an item.""" if isinstance(item, Episode): return MemoryType.EPISODIC elif isinstance(item, Fact): return MemoryType.SEMANTIC elif isinstance(item, Procedure): return MemoryType.PROCEDURAL return MemoryType.WORKING class EntityIndex(MemoryIndex[T]): """ Entity-based index for lookups by entities mentioned in memories. Supports: - Single entity lookup - Multi-entity intersection - Entity type filtering """ def __init__(self) -> None: """Initialize the entity index.""" # Main storage self._entries: dict[UUID, EntityIndexEntry] = {} # Inverted index: entity -> set of memory IDs self._entity_to_memories: dict[str, set[UUID]] = {} # Memory to entities mapping self._memory_to_entities: dict[UUID, set[str]] = {} logger.info("Initialized EntityIndex") async def add(self, item: T) -> EntityIndexEntry: """ Add an item to the entity index. Args: item: Memory item with entity information Returns: The created index entry """ entities = self._extract_entities(item) # Create entry for the primary entity (or first one) primary_entity = entities[0] if entities else ("unknown", "unknown") entry = EntityIndexEntry( memory_id=item.id, memory_type=self._get_memory_type(item), entity_type=primary_entity[0], entity_value=primary_entity[1], ) self._entries[item.id] = entry # Update inverted indices entity_keys = {f"{etype}:{evalue}" for etype, evalue in entities} self._memory_to_entities[item.id] = entity_keys for entity_key in entity_keys: if entity_key not in self._entity_to_memories: self._entity_to_memories[entity_key] = set() self._entity_to_memories[entity_key].add(item.id) logger.debug(f"Added {item.id} to entity index with {len(entities)} entities") return entry async def remove(self, memory_id: UUID) -> bool: """Remove an item from the entity index.""" if memory_id not in self._entries: return False # Remove from inverted index if memory_id in self._memory_to_entities: for entity_key in self._memory_to_entities[memory_id]: if entity_key in self._entity_to_memories: self._entity_to_memories[entity_key].discard(memory_id) if not self._entity_to_memories[entity_key]: del self._entity_to_memories[entity_key] del self._memory_to_entities[memory_id] del self._entries[memory_id] logger.debug(f"Removed {memory_id} from entity index") return True async def search( # type: ignore[override] self, query: Any, limit: int = 10, entity_type: str | None = None, entity_value: str | None = None, entities: list[tuple[str, str]] | None = None, match_all: bool = False, **kwargs: Any, ) -> list[EntityIndexEntry]: """ Search for items by entity. Args: query: Entity value to search (if entity_type not specified) limit: Maximum results to return entity_type: Type of entity to filter entity_value: Specific entity value entities: List of (type, value) tuples to match match_all: If True, require all entities to match **kwargs: Additional filter parameters Returns: List of matching entries """ matching_ids: set[UUID] | None = None # Handle single entity query if entity_type and entity_value: entities = [(entity_type, entity_value)] elif entity_value is None and isinstance(query, str): # Search across all entity types entity_value = query if entities: for etype, evalue in entities: entity_key = f"{etype}:{evalue}" if entity_key in self._entity_to_memories: ids = self._entity_to_memories[entity_key] if matching_ids is None: matching_ids = ids.copy() elif match_all: matching_ids &= ids else: matching_ids |= ids elif match_all: # Required entity not found matching_ids = set() break elif entity_value: # Search for value across all types matching_ids = set() for entity_key, ids in self._entity_to_memories.items(): if entity_value.lower() in entity_key.lower(): matching_ids |= ids if matching_ids is None: matching_ids = set(self._entries.keys()) # Apply memory type filter if provided memory_type = kwargs.get("memory_type") results = [] for mid in matching_ids: if mid in self._entries: entry = self._entries[mid] if memory_type and entry.memory_type != memory_type: continue results.append(entry) logger.debug(f"Entity search returned {min(len(results), limit)} results") return results[:limit] async def clear(self) -> int: """Clear all entries from the index.""" count = len(self._entries) self._entries.clear() self._entity_to_memories.clear() self._memory_to_entities.clear() logger.info(f"Cleared {count} entries from entity index") return count async def count(self) -> int: """Get the number of entries in the index.""" return len(self._entries) async def get_entities(self, memory_id: UUID) -> list[tuple[str, str]]: """Get all entities for a memory item.""" if memory_id not in self._memory_to_entities: return [] entities = [] for entity_key in self._memory_to_entities[memory_id]: if ":" in entity_key: etype, evalue = entity_key.split(":", 1) entities.append((etype, evalue)) return entities def _extract_entities(self, item: T) -> list[tuple[str, str]]: """Extract entities from a memory item.""" entities: list[tuple[str, str]] = [] if isinstance(item, Episode): # Extract from task type and context entities.append(("task_type", item.task_type)) if item.project_id: entities.append(("project", str(item.project_id))) if item.agent_instance_id: entities.append(("agent_instance", str(item.agent_instance_id))) if item.agent_type_id: entities.append(("agent_type", str(item.agent_type_id))) elif isinstance(item, Fact): # Subject and object are entities entities.append(("subject", item.subject)) entities.append(("object", item.object)) if item.project_id: entities.append(("project", str(item.project_id))) elif isinstance(item, Procedure): entities.append(("procedure", item.name)) if item.project_id: entities.append(("project", str(item.project_id))) if item.agent_type_id: entities.append(("agent_type", str(item.agent_type_id))) return entities def _get_memory_type(self, item: T) -> MemoryType: """Get the memory type for an item.""" if isinstance(item, Episode): return MemoryType.EPISODIC elif isinstance(item, Fact): return MemoryType.SEMANTIC elif isinstance(item, Procedure): return MemoryType.PROCEDURAL return MemoryType.WORKING class OutcomeIndex(MemoryIndex[T]): """ Outcome-based index for filtering by success/failure. Primarily used for episodes and procedures. """ def __init__(self) -> None: """Initialize the outcome index.""" self._entries: dict[UUID, OutcomeIndexEntry] = {} # Inverted index by outcome self._outcome_to_memories: dict[Outcome, set[UUID]] = { Outcome.SUCCESS: set(), Outcome.FAILURE: set(), Outcome.PARTIAL: set(), } logger.info("Initialized OutcomeIndex") async def add(self, item: T) -> OutcomeIndexEntry: """ Add an item to the outcome index. Args: item: Memory item with outcome information Returns: The created index entry """ outcome = self._get_outcome(item) entry = OutcomeIndexEntry( memory_id=item.id, memory_type=self._get_memory_type(item), outcome=outcome, ) self._entries[item.id] = entry self._outcome_to_memories[outcome].add(item.id) logger.debug(f"Added {item.id} to outcome index with {outcome.value}") return entry async def remove(self, memory_id: UUID) -> bool: """Remove an item from the outcome index.""" if memory_id not in self._entries: return False entry = self._entries.pop(memory_id) self._outcome_to_memories[entry.outcome].discard(memory_id) logger.debug(f"Removed {memory_id} from outcome index") return True async def search( # type: ignore[override] self, query: Any, limit: int = 10, outcome: Outcome | None = None, outcomes: list[Outcome] | None = None, **kwargs: Any, ) -> list[OutcomeIndexEntry]: """ Search for items by outcome. Args: query: Ignored for outcome search limit: Maximum results to return outcome: Single outcome to filter outcomes: Multiple outcomes to filter (OR) **kwargs: Additional filter parameters Returns: List of matching entries """ if outcome: outcomes = [outcome] if outcomes: matching_ids: set[UUID] = set() for o in outcomes: matching_ids |= self._outcome_to_memories.get(o, set()) else: matching_ids = set(self._entries.keys()) # Apply memory type filter if provided memory_type = kwargs.get("memory_type") results = [] for mid in matching_ids: if mid in self._entries: entry = self._entries[mid] if memory_type and entry.memory_type != memory_type: continue results.append(entry) logger.debug(f"Outcome search returned {min(len(results), limit)} results") return results[:limit] async def clear(self) -> int: """Clear all entries from the index.""" count = len(self._entries) self._entries.clear() for outcome in self._outcome_to_memories: self._outcome_to_memories[outcome].clear() logger.info(f"Cleared {count} entries from outcome index") return count async def count(self) -> int: """Get the number of entries in the index.""" return len(self._entries) async def get_outcome_stats(self) -> dict[Outcome, int]: """Get statistics on outcomes.""" return {outcome: len(ids) for outcome, ids in self._outcome_to_memories.items()} def _get_outcome(self, item: T) -> Outcome: """Get the outcome for an item.""" if isinstance(item, Episode): return item.outcome elif isinstance(item, Procedure): # Derive from success rate if item.success_rate >= 0.8: return Outcome.SUCCESS elif item.success_rate <= 0.2: return Outcome.FAILURE return Outcome.PARTIAL return Outcome.SUCCESS def _get_memory_type(self, item: T) -> MemoryType: """Get the memory type for an item.""" if isinstance(item, Episode): return MemoryType.EPISODIC elif isinstance(item, Fact): return MemoryType.SEMANTIC elif isinstance(item, Procedure): return MemoryType.PROCEDURAL return MemoryType.WORKING @dataclass class MemoryIndexer: """ Unified indexer that manages all index types. Provides a single interface for indexing and searching across multiple index types. """ vector_index: VectorIndex[Any] = field(default_factory=VectorIndex) temporal_index: TemporalIndex[Any] = field(default_factory=TemporalIndex) entity_index: EntityIndex[Any] = field(default_factory=EntityIndex) outcome_index: OutcomeIndex[Any] = field(default_factory=OutcomeIndex) async def index(self, item: Episode | Fact | Procedure) -> dict[str, IndexEntry]: """ Index an item across all applicable indices. Args: item: Memory item to index Returns: Dictionary of index type to entry """ results: dict[str, IndexEntry] = {} # Vector index (if embedding present) if getattr(item, "embedding", None): results["vector"] = await self.vector_index.add(item) # Temporal index results["temporal"] = await self.temporal_index.add(item) # Entity index results["entity"] = await self.entity_index.add(item) # Outcome index (for episodes and procedures) if isinstance(item, (Episode, Procedure)): results["outcome"] = await self.outcome_index.add(item) logger.info( f"Indexed {item.id} across {len(results)} indices: {list(results.keys())}" ) return results async def remove(self, memory_id: UUID) -> dict[str, bool]: """ Remove an item from all indices. Args: memory_id: ID of the memory to remove Returns: Dictionary of index type to removal success """ results = { "vector": await self.vector_index.remove(memory_id), "temporal": await self.temporal_index.remove(memory_id), "entity": await self.entity_index.remove(memory_id), "outcome": await self.outcome_index.remove(memory_id), } removed_from = [k for k, v in results.items() if v] if removed_from: logger.info(f"Removed {memory_id} from indices: {removed_from}") return results async def clear_all(self) -> dict[str, int]: """ Clear all indices. Returns: Dictionary of index type to count cleared """ return { "vector": await self.vector_index.clear(), "temporal": await self.temporal_index.clear(), "entity": await self.entity_index.clear(), "outcome": await self.outcome_index.clear(), } async def get_stats(self) -> dict[str, int]: """ Get statistics for all indices. Returns: Dictionary of index type to entry count """ return { "vector": await self.vector_index.count(), "temporal": await self.temporal_index.count(), "entity": await self.entity_index.count(), "outcome": await self.outcome_index.count(), } # Singleton indexer instance _indexer: MemoryIndexer | None = None def get_memory_indexer() -> MemoryIndexer: """Get the singleton memory indexer instance.""" global _indexer if _indexer is None: _indexer = MemoryIndexer() return _indexer