diff --git a/backend/app/services/memory/indexing/__init__.py b/backend/app/services/memory/indexing/__init__.py index 5e2e3d8..e87ce71 100644 --- a/backend/app/services/memory/indexing/__init__.py +++ b/backend/app/services/memory/indexing/__init__.py @@ -1,7 +1,56 @@ +# app/services/memory/indexing/__init__.py """ -Memory Indexing +Memory Indexing & Retrieval. -Vector embeddings and retrieval engine for memory search. +Provides vector embeddings and multiple index types for efficient memory search: +- Vector index for semantic similarity +- Temporal index for time-based queries +- Entity index for entity lookups +- Outcome index for success/failure filtering """ -# Will be populated in #94 +from .index import ( + EntityIndex, + EntityIndexEntry, + IndexEntry, + MemoryIndex, + MemoryIndexer, + OutcomeIndex, + OutcomeIndexEntry, + TemporalIndex, + TemporalIndexEntry, + VectorIndex, + VectorIndexEntry, + get_memory_indexer, +) +from .retrieval import ( + CacheEntry, + RelevanceScorer, + RetrievalCache, + RetrievalEngine, + RetrievalQuery, + ScoredResult, + get_retrieval_engine, +) + +__all__ = [ + "CacheEntry", + "EntityIndex", + "EntityIndexEntry", + "IndexEntry", + "MemoryIndex", + "MemoryIndexer", + "OutcomeIndex", + "OutcomeIndexEntry", + "RelevanceScorer", + "RetrievalCache", + "RetrievalEngine", + "RetrievalQuery", + "ScoredResult", + "TemporalIndex", + "TemporalIndexEntry", + "VectorIndex", + "VectorIndexEntry", + "get_memory_indexer", + "get_retrieval_engine", +] diff --git a/backend/app/services/memory/indexing/index.py b/backend/app/services/memory/indexing/index.py new file mode 100644 index 0000000..9f3a02c --- /dev/null +++ b/backend/app/services/memory/indexing/index.py @@ -0,0 +1,851 @@ +# app/services/memory/indexing/index.py +""" +Memory Indexing. 
+ +Provides multiple indexing strategies for efficient memory retrieval: +- Vector embeddings for semantic search +- Temporal index for time-based queries +- Entity index for entity-based lookups +- Outcome index for success/failure filtering +""" + +import logging +from abc import ABC, abstractmethod +from dataclasses import dataclass, field +from datetime import UTC, datetime, timedelta +from typing import Any, TypeVar +from uuid import UUID + +from app.services.memory.types import Episode, Fact, MemoryType, Outcome, Procedure + +logger = logging.getLogger(__name__) + +T = TypeVar("T", Episode, Fact, Procedure) + + +def _utcnow() -> datetime: + """Get current UTC time as timezone-aware datetime.""" + return datetime.now(UTC) + + +@dataclass +class IndexEntry: + """A single entry in an index.""" + + memory_id: UUID + memory_type: MemoryType + indexed_at: datetime = field(default_factory=_utcnow) + metadata: dict[str, Any] = field(default_factory=dict) + + +@dataclass +class VectorIndexEntry(IndexEntry): + """An entry with vector embedding.""" + + embedding: list[float] = field(default_factory=list) + dimension: int = 0 + + def __post_init__(self) -> None: + """Set dimension from embedding.""" + if self.embedding: + self.dimension = len(self.embedding) + + +@dataclass +class TemporalIndexEntry(IndexEntry): + """An entry indexed by time.""" + + timestamp: datetime = field(default_factory=_utcnow) + + +@dataclass +class EntityIndexEntry(IndexEntry): + """An entry indexed by entity.""" + + entity_type: str = "" + entity_value: str = "" + + +@dataclass +class OutcomeIndexEntry(IndexEntry): + """An entry indexed by outcome.""" + + outcome: Outcome = Outcome.SUCCESS + + +class MemoryIndex[T](ABC): + """Abstract base class for memory indices.""" + + @abstractmethod + async def add(self, item: T) -> IndexEntry: + """Add an item to the index.""" + ... + + @abstractmethod + async def remove(self, memory_id: UUID) -> bool: + """Remove an item from the index.""" + ... + + @abstractmethod + async def search( + self, + query: Any, + limit: int = 10, + **kwargs: Any, + ) -> list[IndexEntry]: + """Search the index.""" + ... + + @abstractmethod + async def clear(self) -> int: + """Clear all entries from the index.""" + ... + + @abstractmethod + async def count(self) -> int: + """Get the number of entries in the index.""" + ... + + +class VectorIndex(MemoryIndex[T]): + """ + Vector-based index using embeddings for semantic similarity search. + + Uses cosine similarity for matching. + """ + + def __init__(self, dimension: int = 1536) -> None: + """ + Initialize the vector index. + + Args: + dimension: Embedding dimension (default 1536 for OpenAI) + """ + self._dimension = dimension + self._entries: dict[UUID, VectorIndexEntry] = {} + logger.info(f"Initialized VectorIndex with dimension={dimension}") + + async def add(self, item: T) -> VectorIndexEntry: + """ + Add an item to the vector index. 
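+
+        Example (illustrative sketch; assumes an ``Episode`` whose
+        ``embedding`` was computed upstream)::
+
+            entry = await index.add(episode)
+            # dimension reflects the stored embedding length
+            assert entry.dimension == len(episode.embedding)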
+ + Args: + item: Memory item with embedding + + Returns: + The created index entry + """ + embedding = getattr(item, "embedding", None) or [] + + entry = VectorIndexEntry( + memory_id=item.id, + memory_type=self._get_memory_type(item), + embedding=embedding, + dimension=len(embedding), + ) + + self._entries[item.id] = entry + logger.debug(f"Added {item.id} to vector index") + return entry + + async def remove(self, memory_id: UUID) -> bool: + """Remove an item from the vector index.""" + if memory_id in self._entries: + del self._entries[memory_id] + logger.debug(f"Removed {memory_id} from vector index") + return True + return False + + async def search( # type: ignore[override] + self, + query: Any, + limit: int = 10, + min_similarity: float = 0.0, + **kwargs: Any, + ) -> list[VectorIndexEntry]: + """ + Search for similar items using vector similarity. + + Args: + query: Query embedding vector + limit: Maximum results to return + min_similarity: Minimum similarity threshold (0-1) + **kwargs: Additional filter parameters + + Returns: + List of matching entries sorted by similarity + """ + if not isinstance(query, list) or not query: + return [] + + results: list[tuple[float, VectorIndexEntry]] = [] + + for entry in self._entries.values(): + if not entry.embedding: + continue + + similarity = self._cosine_similarity(query, entry.embedding) + if similarity >= min_similarity: + results.append((similarity, entry)) + + # Sort by similarity descending + results.sort(key=lambda x: x[0], reverse=True) + + # Apply memory type filter if provided + memory_type = kwargs.get("memory_type") + if memory_type: + results = [(s, e) for s, e in results if e.memory_type == memory_type] + + # Store similarity in metadata for the returned entries + output = [] + for similarity, entry in results[:limit]: + entry.metadata["similarity"] = similarity + output.append(entry) + + logger.debug(f"Vector search returned {len(output)} results") + return output + + async def clear(self) -> int: + """Clear all entries from the index.""" + count = len(self._entries) + self._entries.clear() + logger.info(f"Cleared {count} entries from vector index") + return count + + async def count(self) -> int: + """Get the number of entries in the index.""" + return len(self._entries) + + def _cosine_similarity(self, a: list[float], b: list[float]) -> float: + """Calculate cosine similarity between two vectors.""" + if len(a) != len(b) or len(a) == 0: + return 0.0 + + dot_product = sum(x * y for x, y in zip(a, b, strict=True)) + norm_a = sum(x * x for x in a) ** 0.5 + norm_b = sum(x * x for x in b) ** 0.5 + + if norm_a == 0 or norm_b == 0: + return 0.0 + + return dot_product / (norm_a * norm_b) + + def _get_memory_type(self, item: T) -> MemoryType: + """Get the memory type for an item.""" + if isinstance(item, Episode): + return MemoryType.EPISODIC + elif isinstance(item, Fact): + return MemoryType.SEMANTIC + elif isinstance(item, Procedure): + return MemoryType.PROCEDURAL + return MemoryType.WORKING + + +class TemporalIndex(MemoryIndex[T]): + """ + Time-based index for efficient temporal queries. 
+ + Supports: + - Range queries (between timestamps) + - Recent items (within last N seconds/hours/days) + - Oldest/newest sorting + """ + + def __init__(self) -> None: + """Initialize the temporal index.""" + self._entries: dict[UUID, TemporalIndexEntry] = {} + # Sorted list for efficient range queries + self._sorted_entries: list[tuple[datetime, UUID]] = [] + logger.info("Initialized TemporalIndex") + + async def add(self, item: T) -> TemporalIndexEntry: + """ + Add an item to the temporal index. + + Args: + item: Memory item with timestamp + + Returns: + The created index entry + """ + # Get timestamp from various possible fields + timestamp = self._get_timestamp(item) + + entry = TemporalIndexEntry( + memory_id=item.id, + memory_type=self._get_memory_type(item), + timestamp=timestamp, + ) + + self._entries[item.id] = entry + self._insert_sorted(timestamp, item.id) + + logger.debug(f"Added {item.id} to temporal index at {timestamp}") + return entry + + async def remove(self, memory_id: UUID) -> bool: + """Remove an item from the temporal index.""" + if memory_id not in self._entries: + return False + + self._entries.pop(memory_id) + self._sorted_entries = [ + (ts, mid) for ts, mid in self._sorted_entries if mid != memory_id + ] + + logger.debug(f"Removed {memory_id} from temporal index") + return True + + async def search( # type: ignore[override] + self, + query: Any, + limit: int = 10, + start_time: datetime | None = None, + end_time: datetime | None = None, + recent_seconds: float | None = None, + order: str = "desc", + **kwargs: Any, + ) -> list[TemporalIndexEntry]: + """ + Search for items by time. + + Args: + query: Ignored for temporal search + limit: Maximum results to return + start_time: Start of time range + end_time: End of time range + recent_seconds: Get items from last N seconds + order: Sort order ("asc" or "desc") + **kwargs: Additional filter parameters + + Returns: + List of matching entries sorted by time + """ + if recent_seconds is not None: + start_time = _utcnow() - timedelta(seconds=recent_seconds) + end_time = _utcnow() + + # Filter by time range + results: list[TemporalIndexEntry] = [] + for entry in self._entries.values(): + if start_time and entry.timestamp < start_time: + continue + if end_time and entry.timestamp > end_time: + continue + results.append(entry) + + # Apply memory type filter if provided + memory_type = kwargs.get("memory_type") + if memory_type: + results = [e for e in results if e.memory_type == memory_type] + + # Sort by timestamp + results.sort(key=lambda e: e.timestamp, reverse=(order == "desc")) + + logger.debug(f"Temporal search returned {min(len(results), limit)} results") + return results[:limit] + + async def clear(self) -> int: + """Clear all entries from the index.""" + count = len(self._entries) + self._entries.clear() + self._sorted_entries.clear() + logger.info(f"Cleared {count} entries from temporal index") + return count + + async def count(self) -> int: + """Get the number of entries in the index.""" + return len(self._entries) + + def _insert_sorted(self, timestamp: datetime, memory_id: UUID) -> None: + """Insert entry maintaining sorted order.""" + # Binary search insert for efficiency + low, high = 0, len(self._sorted_entries) + while low < high: + mid = (low + high) // 2 + if self._sorted_entries[mid][0] < timestamp: + low = mid + 1 + else: + high = mid + self._sorted_entries.insert(low, (timestamp, memory_id)) + + def _get_timestamp(self, item: T) -> datetime: + """Get the relevant timestamp for an item.""" + if 
hasattr(item, "occurred_at"): + return item.occurred_at + if hasattr(item, "first_learned"): + return item.first_learned + if hasattr(item, "last_used") and item.last_used: + return item.last_used + if hasattr(item, "created_at"): + return item.created_at + return _utcnow() + + def _get_memory_type(self, item: T) -> MemoryType: + """Get the memory type for an item.""" + if isinstance(item, Episode): + return MemoryType.EPISODIC + elif isinstance(item, Fact): + return MemoryType.SEMANTIC + elif isinstance(item, Procedure): + return MemoryType.PROCEDURAL + return MemoryType.WORKING + + +class EntityIndex(MemoryIndex[T]): + """ + Entity-based index for lookups by entities mentioned in memories. + + Supports: + - Single entity lookup + - Multi-entity intersection + - Entity type filtering + """ + + def __init__(self) -> None: + """Initialize the entity index.""" + # Main storage + self._entries: dict[UUID, EntityIndexEntry] = {} + # Inverted index: entity -> set of memory IDs + self._entity_to_memories: dict[str, set[UUID]] = {} + # Memory to entities mapping + self._memory_to_entities: dict[UUID, set[str]] = {} + logger.info("Initialized EntityIndex") + + async def add(self, item: T) -> EntityIndexEntry: + """ + Add an item to the entity index. + + Args: + item: Memory item with entity information + + Returns: + The created index entry + """ + entities = self._extract_entities(item) + + # Create entry for the primary entity (or first one) + primary_entity = entities[0] if entities else ("unknown", "unknown") + + entry = EntityIndexEntry( + memory_id=item.id, + memory_type=self._get_memory_type(item), + entity_type=primary_entity[0], + entity_value=primary_entity[1], + ) + + self._entries[item.id] = entry + + # Update inverted indices + entity_keys = {f"{etype}:{evalue}" for etype, evalue in entities} + self._memory_to_entities[item.id] = entity_keys + + for entity_key in entity_keys: + if entity_key not in self._entity_to_memories: + self._entity_to_memories[entity_key] = set() + self._entity_to_memories[entity_key].add(item.id) + + logger.debug(f"Added {item.id} to entity index with {len(entities)} entities") + return entry + + async def remove(self, memory_id: UUID) -> bool: + """Remove an item from the entity index.""" + if memory_id not in self._entries: + return False + + # Remove from inverted index + if memory_id in self._memory_to_entities: + for entity_key in self._memory_to_entities[memory_id]: + if entity_key in self._entity_to_memories: + self._entity_to_memories[entity_key].discard(memory_id) + if not self._entity_to_memories[entity_key]: + del self._entity_to_memories[entity_key] + del self._memory_to_entities[memory_id] + + del self._entries[memory_id] + logger.debug(f"Removed {memory_id} from entity index") + return True + + async def search( # type: ignore[override] + self, + query: Any, + limit: int = 10, + entity_type: str | None = None, + entity_value: str | None = None, + entities: list[tuple[str, str]] | None = None, + match_all: bool = False, + **kwargs: Any, + ) -> list[EntityIndexEntry]: + """ + Search for items by entity. 
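+
+        Example (sketch; assumes the facts were indexed beforehand)::
+
+            # Union: anything mentioning the "user" subject entity
+            hits = await index.search(query=None, entities=[("subject", "user")])
+
+            # Intersection: must mention both entities
+            hits = await index.search(
+                query=None,
+                entities=[("subject", "user"), ("object", "admin")],
+                match_all=True,
+            )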
+ + Args: + query: Entity value to search (if entity_type not specified) + limit: Maximum results to return + entity_type: Type of entity to filter + entity_value: Specific entity value + entities: List of (type, value) tuples to match + match_all: If True, require all entities to match + **kwargs: Additional filter parameters + + Returns: + List of matching entries + """ + matching_ids: set[UUID] | None = None + + # Handle single entity query + if entity_type and entity_value: + entities = [(entity_type, entity_value)] + elif entity_value is None and isinstance(query, str): + # Search across all entity types + entity_value = query + + if entities: + for etype, evalue in entities: + entity_key = f"{etype}:{evalue}" + if entity_key in self._entity_to_memories: + ids = self._entity_to_memories[entity_key] + if matching_ids is None: + matching_ids = ids.copy() + elif match_all: + matching_ids &= ids + else: + matching_ids |= ids + elif match_all: + # Required entity not found + matching_ids = set() + break + elif entity_value: + # Search for value across all types + matching_ids = set() + for entity_key, ids in self._entity_to_memories.items(): + if entity_value.lower() in entity_key.lower(): + matching_ids |= ids + + if matching_ids is None: + matching_ids = set(self._entries.keys()) + + # Apply memory type filter if provided + memory_type = kwargs.get("memory_type") + results = [] + for mid in matching_ids: + if mid in self._entries: + entry = self._entries[mid] + if memory_type and entry.memory_type != memory_type: + continue + results.append(entry) + + logger.debug(f"Entity search returned {min(len(results), limit)} results") + return results[:limit] + + async def clear(self) -> int: + """Clear all entries from the index.""" + count = len(self._entries) + self._entries.clear() + self._entity_to_memories.clear() + self._memory_to_entities.clear() + logger.info(f"Cleared {count} entries from entity index") + return count + + async def count(self) -> int: + """Get the number of entries in the index.""" + return len(self._entries) + + async def get_entities(self, memory_id: UUID) -> list[tuple[str, str]]: + """Get all entities for a memory item.""" + if memory_id not in self._memory_to_entities: + return [] + + entities = [] + for entity_key in self._memory_to_entities[memory_id]: + if ":" in entity_key: + etype, evalue = entity_key.split(":", 1) + entities.append((etype, evalue)) + return entities + + def _extract_entities(self, item: T) -> list[tuple[str, str]]: + """Extract entities from a memory item.""" + entities: list[tuple[str, str]] = [] + + if isinstance(item, Episode): + # Extract from task type and context + entities.append(("task_type", item.task_type)) + if item.project_id: + entities.append(("project", str(item.project_id))) + if item.agent_instance_id: + entities.append(("agent_instance", str(item.agent_instance_id))) + if item.agent_type_id: + entities.append(("agent_type", str(item.agent_type_id))) + + elif isinstance(item, Fact): + # Subject and object are entities + entities.append(("subject", item.subject)) + entities.append(("object", item.object)) + if item.project_id: + entities.append(("project", str(item.project_id))) + + elif isinstance(item, Procedure): + entities.append(("procedure", item.name)) + if item.project_id: + entities.append(("project", str(item.project_id))) + if item.agent_type_id: + entities.append(("agent_type", str(item.agent_type_id))) + + return entities + + def _get_memory_type(self, item: T) -> MemoryType: + """Get the memory type for an 
item.""" + if isinstance(item, Episode): + return MemoryType.EPISODIC + elif isinstance(item, Fact): + return MemoryType.SEMANTIC + elif isinstance(item, Procedure): + return MemoryType.PROCEDURAL + return MemoryType.WORKING + + +class OutcomeIndex(MemoryIndex[T]): + """ + Outcome-based index for filtering by success/failure. + + Primarily used for episodes and procedures. + """ + + def __init__(self) -> None: + """Initialize the outcome index.""" + self._entries: dict[UUID, OutcomeIndexEntry] = {} + # Inverted index by outcome + self._outcome_to_memories: dict[Outcome, set[UUID]] = { + Outcome.SUCCESS: set(), + Outcome.FAILURE: set(), + Outcome.PARTIAL: set(), + } + logger.info("Initialized OutcomeIndex") + + async def add(self, item: T) -> OutcomeIndexEntry: + """ + Add an item to the outcome index. + + Args: + item: Memory item with outcome information + + Returns: + The created index entry + """ + outcome = self._get_outcome(item) + + entry = OutcomeIndexEntry( + memory_id=item.id, + memory_type=self._get_memory_type(item), + outcome=outcome, + ) + + self._entries[item.id] = entry + self._outcome_to_memories[outcome].add(item.id) + + logger.debug(f"Added {item.id} to outcome index with {outcome.value}") + return entry + + async def remove(self, memory_id: UUID) -> bool: + """Remove an item from the outcome index.""" + if memory_id not in self._entries: + return False + + entry = self._entries.pop(memory_id) + self._outcome_to_memories[entry.outcome].discard(memory_id) + + logger.debug(f"Removed {memory_id} from outcome index") + return True + + async def search( # type: ignore[override] + self, + query: Any, + limit: int = 10, + outcome: Outcome | None = None, + outcomes: list[Outcome] | None = None, + **kwargs: Any, + ) -> list[OutcomeIndexEntry]: + """ + Search for items by outcome. 
+ + Args: + query: Ignored for outcome search + limit: Maximum results to return + outcome: Single outcome to filter + outcomes: Multiple outcomes to filter (OR) + **kwargs: Additional filter parameters + + Returns: + List of matching entries + """ + if outcome: + outcomes = [outcome] + + if outcomes: + matching_ids: set[UUID] = set() + for o in outcomes: + matching_ids |= self._outcome_to_memories.get(o, set()) + else: + matching_ids = set(self._entries.keys()) + + # Apply memory type filter if provided + memory_type = kwargs.get("memory_type") + results = [] + for mid in matching_ids: + if mid in self._entries: + entry = self._entries[mid] + if memory_type and entry.memory_type != memory_type: + continue + results.append(entry) + + logger.debug(f"Outcome search returned {min(len(results), limit)} results") + return results[:limit] + + async def clear(self) -> int: + """Clear all entries from the index.""" + count = len(self._entries) + self._entries.clear() + for outcome in self._outcome_to_memories: + self._outcome_to_memories[outcome].clear() + logger.info(f"Cleared {count} entries from outcome index") + return count + + async def count(self) -> int: + """Get the number of entries in the index.""" + return len(self._entries) + + async def get_outcome_stats(self) -> dict[Outcome, int]: + """Get statistics on outcomes.""" + return {outcome: len(ids) for outcome, ids in self._outcome_to_memories.items()} + + def _get_outcome(self, item: T) -> Outcome: + """Get the outcome for an item.""" + if isinstance(item, Episode): + return item.outcome + elif isinstance(item, Procedure): + # Derive from success rate + if item.success_rate >= 0.8: + return Outcome.SUCCESS + elif item.success_rate <= 0.2: + return Outcome.FAILURE + return Outcome.PARTIAL + return Outcome.SUCCESS + + def _get_memory_type(self, item: T) -> MemoryType: + """Get the memory type for an item.""" + if isinstance(item, Episode): + return MemoryType.EPISODIC + elif isinstance(item, Fact): + return MemoryType.SEMANTIC + elif isinstance(item, Procedure): + return MemoryType.PROCEDURAL + return MemoryType.WORKING + + +@dataclass +class MemoryIndexer: + """ + Unified indexer that manages all index types. + + Provides a single interface for indexing and searching across + multiple index types. + """ + + vector_index: VectorIndex[Any] = field(default_factory=VectorIndex) + temporal_index: TemporalIndex[Any] = field(default_factory=TemporalIndex) + entity_index: EntityIndex[Any] = field(default_factory=EntityIndex) + outcome_index: OutcomeIndex[Any] = field(default_factory=OutcomeIndex) + + async def index(self, item: Episode | Fact | Procedure) -> dict[str, IndexEntry]: + """ + Index an item across all applicable indices. + + Args: + item: Memory item to index + + Returns: + Dictionary of index type to entry + """ + results: dict[str, IndexEntry] = {} + + # Vector index (if embedding present) + if getattr(item, "embedding", None): + results["vector"] = await self.vector_index.add(item) + + # Temporal index + results["temporal"] = await self.temporal_index.add(item) + + # Entity index + results["entity"] = await self.entity_index.add(item) + + # Outcome index (for episodes and procedures) + if isinstance(item, (Episode, Procedure)): + results["outcome"] = await self.outcome_index.add(item) + + logger.info( + f"Indexed {item.id} across {len(results)} indices: {list(results.keys())}" + ) + return results + + async def remove(self, memory_id: UUID) -> dict[str, bool]: + """ + Remove an item from all indices. 
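+
+        Example (sketch; assumes ``episode`` was previously passed to
+        ``index()``)::
+
+            removed = await indexer.remove(episode.id)
+            # Facts never enter the outcome index, so this flag is
+            # False when removing a Fact
+            assert removed["outcome"] is True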
+ + Args: + memory_id: ID of the memory to remove + + Returns: + Dictionary of index type to removal success + """ + results = { + "vector": await self.vector_index.remove(memory_id), + "temporal": await self.temporal_index.remove(memory_id), + "entity": await self.entity_index.remove(memory_id), + "outcome": await self.outcome_index.remove(memory_id), + } + + removed_from = [k for k, v in results.items() if v] + if removed_from: + logger.info(f"Removed {memory_id} from indices: {removed_from}") + + return results + + async def clear_all(self) -> dict[str, int]: + """ + Clear all indices. + + Returns: + Dictionary of index type to count cleared + """ + return { + "vector": await self.vector_index.clear(), + "temporal": await self.temporal_index.clear(), + "entity": await self.entity_index.clear(), + "outcome": await self.outcome_index.clear(), + } + + async def get_stats(self) -> dict[str, int]: + """ + Get statistics for all indices. + + Returns: + Dictionary of index type to entry count + """ + return { + "vector": await self.vector_index.count(), + "temporal": await self.temporal_index.count(), + "entity": await self.entity_index.count(), + "outcome": await self.outcome_index.count(), + } + + +# Singleton indexer instance +_indexer: MemoryIndexer | None = None + + +def get_memory_indexer() -> MemoryIndexer: + """Get the singleton memory indexer instance.""" + global _indexer + if _indexer is None: + _indexer = MemoryIndexer() + return _indexer diff --git a/backend/app/services/memory/indexing/retrieval.py b/backend/app/services/memory/indexing/retrieval.py new file mode 100644 index 0000000..a2f60f4 --- /dev/null +++ b/backend/app/services/memory/indexing/retrieval.py @@ -0,0 +1,750 @@ +# app/services/memory/indexing/retrieval.py +""" +Memory Retrieval Engine. 
+ +Provides hybrid retrieval capabilities combining: +- Vector similarity search +- Temporal filtering +- Entity filtering +- Outcome filtering +- Relevance scoring +- Result caching +""" + +import hashlib +import logging +from dataclasses import dataclass, field +from datetime import UTC, datetime +from typing import Any, TypeVar +from uuid import UUID + +from app.services.memory.types import ( + Episode, + Fact, + MemoryType, + Outcome, + Procedure, + RetrievalResult, +) + +from .index import ( + MemoryIndexer, + get_memory_indexer, +) + +logger = logging.getLogger(__name__) + +T = TypeVar("T", Episode, Fact, Procedure) + + +def _utcnow() -> datetime: + """Get current UTC time as timezone-aware datetime.""" + return datetime.now(UTC) + + +@dataclass +class RetrievalQuery: + """Query parameters for memory retrieval.""" + + # Text/semantic query + query_text: str | None = None + query_embedding: list[float] | None = None + + # Temporal filters + start_time: datetime | None = None + end_time: datetime | None = None + recent_seconds: float | None = None + + # Entity filters + entities: list[tuple[str, str]] | None = None + entity_match_all: bool = False + + # Outcome filters + outcomes: list[Outcome] | None = None + + # Memory type filter + memory_types: list[MemoryType] | None = None + + # Result options + limit: int = 10 + min_relevance: float = 0.0 + + # Retrieval mode + use_vector: bool = True + use_temporal: bool = True + use_entity: bool = True + use_outcome: bool = True + + def to_cache_key(self) -> str: + """Generate a cache key for this query.""" + key_parts = [ + self.query_text or "", + str(self.start_time), + str(self.end_time), + str(self.recent_seconds), + str(self.entities), + str(self.outcomes), + str(self.memory_types), + str(self.limit), + str(self.min_relevance), + ] + key_string = "|".join(key_parts) + return hashlib.sha256(key_string.encode()).hexdigest()[:32] + + +@dataclass +class ScoredResult: + """A retrieval result with relevance score.""" + + memory_id: UUID + memory_type: MemoryType + relevance_score: float + score_breakdown: dict[str, float] = field(default_factory=dict) + metadata: dict[str, Any] = field(default_factory=dict) + + +@dataclass +class CacheEntry: + """A cached retrieval result.""" + + results: list[ScoredResult] + created_at: datetime + ttl_seconds: float + query_key: str + + def is_expired(self) -> bool: + """Check if this cache entry has expired.""" + age = (_utcnow() - self.created_at).total_seconds() + return age > self.ttl_seconds + + +class RelevanceScorer: + """ + Calculates relevance scores for retrieved memories. + + Combines multiple signals: + - Vector similarity (if available) + - Temporal recency + - Entity match count + - Outcome preference + - Importance/confidence + """ + + def __init__( + self, + vector_weight: float = 0.4, + recency_weight: float = 0.2, + entity_weight: float = 0.2, + outcome_weight: float = 0.1, + importance_weight: float = 0.1, + ) -> None: + """ + Initialize the relevance scorer. 
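+
+        Weights are normalized to sum to 1, so only their ratios matter.
+        Example (sketch; equivalent to the defaults, a 4:2:2:1:1 ratio)::
+
+            scorer = RelevanceScorer(
+                vector_weight=4.0,
+                recency_weight=2.0,
+                entity_weight=2.0,
+                outcome_weight=1.0,
+                importance_weight=1.0,
+            )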
+ + Args: + vector_weight: Weight for vector similarity (0-1) + recency_weight: Weight for temporal recency (0-1) + entity_weight: Weight for entity matches (0-1) + outcome_weight: Weight for outcome preference (0-1) + importance_weight: Weight for importance score (0-1) + """ + total = ( + vector_weight + + recency_weight + + entity_weight + + outcome_weight + + importance_weight + ) + # Normalize weights + self.vector_weight = vector_weight / total + self.recency_weight = recency_weight / total + self.entity_weight = entity_weight / total + self.outcome_weight = outcome_weight / total + self.importance_weight = importance_weight / total + + def score( + self, + memory_id: UUID, + memory_type: MemoryType, + vector_similarity: float | None = None, + timestamp: datetime | None = None, + entity_match_count: int = 0, + entity_total: int = 1, + outcome: Outcome | None = None, + importance: float = 0.5, + preferred_outcomes: list[Outcome] | None = None, + ) -> ScoredResult: + """ + Calculate a relevance score for a memory. + + Args: + memory_id: ID of the memory + memory_type: Type of memory + vector_similarity: Similarity score from vector search (0-1) + timestamp: Timestamp of the memory + entity_match_count: Number of matching entities + entity_total: Total entities in query + outcome: Outcome of the memory + importance: Importance score of the memory (0-1) + preferred_outcomes: Outcomes to prefer + + Returns: + Scored result with breakdown + """ + breakdown: dict[str, float] = {} + + # Vector similarity score + if vector_similarity is not None: + breakdown["vector"] = vector_similarity + else: + breakdown["vector"] = 0.5 # Neutral if no vector + + # Recency score (exponential decay) + if timestamp: + age_hours = (_utcnow() - timestamp).total_seconds() / 3600 + # Decay with half-life of 24 hours + breakdown["recency"] = 2 ** (-age_hours / 24) + else: + breakdown["recency"] = 0.5 + + # Entity match score + if entity_total > 0: + breakdown["entity"] = entity_match_count / entity_total + else: + breakdown["entity"] = 1.0 # No entity filter = full score + + # Outcome score + if preferred_outcomes and outcome: + breakdown["outcome"] = 1.0 if outcome in preferred_outcomes else 0.0 + else: + breakdown["outcome"] = 0.5 # Neutral if no preference + + # Importance score + breakdown["importance"] = importance + + # Calculate weighted sum + total_score = ( + breakdown["vector"] * self.vector_weight + + breakdown["recency"] * self.recency_weight + + breakdown["entity"] * self.entity_weight + + breakdown["outcome"] * self.outcome_weight + + breakdown["importance"] * self.importance_weight + ) + + return ScoredResult( + memory_id=memory_id, + memory_type=memory_type, + relevance_score=total_score, + score_breakdown=breakdown, + ) + + +class RetrievalCache: + """ + In-memory cache for retrieval results. + + Supports TTL-based expiration and LRU eviction. + """ + + def __init__( + self, + max_entries: int = 1000, + default_ttl_seconds: float = 300, + ) -> None: + """ + Initialize the cache. + + Args: + max_entries: Maximum cache entries + default_ttl_seconds: Default TTL for entries + """ + self._cache: dict[str, CacheEntry] = {} + self._max_entries = max_entries + self._default_ttl = default_ttl_seconds + self._access_order: list[str] = [] + logger.info( + f"Initialized RetrievalCache with max_entries={max_entries}, " + f"ttl={default_ttl_seconds}s" + ) + + def get(self, query_key: str) -> list[ScoredResult] | None: + """ + Get cached results for a query. 
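+
+        Example (sketch; ``results`` is a previously computed list of
+        ``ScoredResult``)::
+
+            key = query.to_cache_key()
+            cache.put(key, results)
+            cached = cache.get(key)  # None once the TTL has elapsed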
+ + Args: + query_key: Cache key for the query + + Returns: + Cached results or None if not found/expired + """ + if query_key not in self._cache: + return None + + entry = self._cache[query_key] + if entry.is_expired(): + del self._cache[query_key] + if query_key in self._access_order: + self._access_order.remove(query_key) + return None + + # Update access order (LRU) + if query_key in self._access_order: + self._access_order.remove(query_key) + self._access_order.append(query_key) + + logger.debug(f"Cache hit for {query_key}") + return entry.results + + def put( + self, + query_key: str, + results: list[ScoredResult], + ttl_seconds: float | None = None, + ) -> None: + """ + Cache results for a query. + + Args: + query_key: Cache key for the query + results: Results to cache + ttl_seconds: TTL for this entry (or default) + """ + # Evict if at capacity + while len(self._cache) >= self._max_entries and self._access_order: + oldest_key = self._access_order.pop(0) + if oldest_key in self._cache: + del self._cache[oldest_key] + + entry = CacheEntry( + results=results, + created_at=_utcnow(), + ttl_seconds=ttl_seconds or self._default_ttl, + query_key=query_key, + ) + + self._cache[query_key] = entry + self._access_order.append(query_key) + logger.debug(f"Cached {len(results)} results for {query_key}") + + def invalidate(self, query_key: str) -> bool: + """ + Invalidate a specific cache entry. + + Args: + query_key: Cache key to invalidate + + Returns: + True if entry was found and removed + """ + if query_key in self._cache: + del self._cache[query_key] + if query_key in self._access_order: + self._access_order.remove(query_key) + return True + return False + + def invalidate_by_memory(self, memory_id: UUID) -> int: + """ + Invalidate all cache entries containing a specific memory. + + Args: + memory_id: Memory ID to invalidate + + Returns: + Number of entries invalidated + """ + keys_to_remove = [] + for key, entry in self._cache.items(): + if any(r.memory_id == memory_id for r in entry.results): + keys_to_remove.append(key) + + for key in keys_to_remove: + self.invalidate(key) + + if keys_to_remove: + logger.debug( + f"Invalidated {len(keys_to_remove)} cache entries for {memory_id}" + ) + return len(keys_to_remove) + + def clear(self) -> int: + """ + Clear all cache entries. + + Returns: + Number of entries cleared + """ + count = len(self._cache) + self._cache.clear() + self._access_order.clear() + logger.info(f"Cleared {count} cache entries") + return count + + def get_stats(self) -> dict[str, Any]: + """Get cache statistics.""" + expired_count = sum(1 for e in self._cache.values() if e.is_expired()) + return { + "total_entries": len(self._cache), + "expired_entries": expired_count, + "max_entries": self._max_entries, + "default_ttl_seconds": self._default_ttl, + } + + +class RetrievalEngine: + """ + Hybrid retrieval engine for memory search. + + Combines multiple index types for comprehensive retrieval: + - Vector search for semantic similarity + - Temporal index for time-based filtering + - Entity index for entity-based lookups + - Outcome index for success/failure filtering + + Results are scored and ranked using relevance scoring. + """ + + def __init__( + self, + indexer: MemoryIndexer | None = None, + scorer: RelevanceScorer | None = None, + cache: RetrievalCache | None = None, + enable_cache: bool = True, + ) -> None: + """ + Initialize the retrieval engine. 
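+
+        Example (sketch using the module defaults: the singleton indexer,
+        a default scorer, and a fresh cache)::
+
+            engine = RetrievalEngine()
+            result = await engine.retrieve(RetrievalQuery(limit=5))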
+ + Args: + indexer: Memory indexer (defaults to singleton) + scorer: Relevance scorer (defaults to new instance) + cache: Retrieval cache (defaults to new instance) + enable_cache: Whether to enable result caching + """ + self._indexer = indexer or get_memory_indexer() + self._scorer = scorer or RelevanceScorer() + self._cache = cache or RetrievalCache() if enable_cache else None + self._enable_cache = enable_cache + logger.info(f"Initialized RetrievalEngine with cache={enable_cache}") + + async def retrieve( + self, + query: RetrievalQuery, + use_cache: bool = True, + ) -> RetrievalResult[ScoredResult]: + """ + Retrieve relevant memories using hybrid search. + + Args: + query: Retrieval query parameters + use_cache: Whether to use cached results + + Returns: + Retrieval result with scored items + """ + start_time = _utcnow() + + # Check cache + cache_key = query.to_cache_key() + if use_cache and self._cache: + cached = self._cache.get(cache_key) + if cached: + latency = (_utcnow() - start_time).total_seconds() * 1000 + return RetrievalResult( + items=cached, + total_count=len(cached), + query=query.query_text or "", + retrieval_type="cached", + latency_ms=latency, + metadata={"cache_hit": True}, + ) + + # Collect candidates from each index + candidates: dict[UUID, dict[str, Any]] = {} + + # Vector search + if query.use_vector and query.query_embedding: + vector_results = await self._indexer.vector_index.search( + query=query.query_embedding, + limit=query.limit * 3, # Get more for filtering + min_similarity=query.min_relevance, + memory_type=query.memory_types[0] if query.memory_types else None, + ) + for entry in vector_results: + if entry.memory_id not in candidates: + candidates[entry.memory_id] = { + "memory_type": entry.memory_type, + "sources": [], + } + candidates[entry.memory_id]["vector_similarity"] = entry.metadata.get( + "similarity", 0.5 + ) + candidates[entry.memory_id]["sources"].append("vector") + + # Temporal search + if query.use_temporal and ( + query.start_time or query.end_time or query.recent_seconds + ): + temporal_results = await self._indexer.temporal_index.search( + query=None, + limit=query.limit * 3, + start_time=query.start_time, + end_time=query.end_time, + recent_seconds=query.recent_seconds, + memory_type=query.memory_types[0] if query.memory_types else None, + ) + for temporal_entry in temporal_results: + if temporal_entry.memory_id not in candidates: + candidates[temporal_entry.memory_id] = { + "memory_type": temporal_entry.memory_type, + "sources": [], + } + candidates[temporal_entry.memory_id]["timestamp"] = ( + temporal_entry.timestamp + ) + candidates[temporal_entry.memory_id]["sources"].append("temporal") + + # Entity search + if query.use_entity and query.entities: + entity_results = await self._indexer.entity_index.search( + query=None, + limit=query.limit * 3, + entities=query.entities, + match_all=query.entity_match_all, + memory_type=query.memory_types[0] if query.memory_types else None, + ) + for entity_entry in entity_results: + if entity_entry.memory_id not in candidates: + candidates[entity_entry.memory_id] = { + "memory_type": entity_entry.memory_type, + "sources": [], + } + # Count entity matches + entity_count = candidates[entity_entry.memory_id].get( + "entity_match_count", 0 + ) + candidates[entity_entry.memory_id]["entity_match_count"] = ( + entity_count + 1 + ) + candidates[entity_entry.memory_id]["sources"].append("entity") + + # Outcome search + if query.use_outcome and query.outcomes: + outcome_results = await 
self._indexer.outcome_index.search( + query=None, + limit=query.limit * 3, + outcomes=query.outcomes, + memory_type=query.memory_types[0] if query.memory_types else None, + ) + for outcome_entry in outcome_results: + if outcome_entry.memory_id not in candidates: + candidates[outcome_entry.memory_id] = { + "memory_type": outcome_entry.memory_type, + "sources": [], + } + candidates[outcome_entry.memory_id]["outcome"] = outcome_entry.outcome + candidates[outcome_entry.memory_id]["sources"].append("outcome") + + # Score and rank candidates + scored_results: list[ScoredResult] = [] + entity_total = len(query.entities) if query.entities else 1 + + for memory_id, data in candidates.items(): + scored = self._scorer.score( + memory_id=memory_id, + memory_type=data["memory_type"], + vector_similarity=data.get("vector_similarity"), + timestamp=data.get("timestamp"), + entity_match_count=data.get("entity_match_count", 0), + entity_total=entity_total, + outcome=data.get("outcome"), + preferred_outcomes=query.outcomes, + ) + scored.metadata["sources"] = data.get("sources", []) + + # Filter by minimum relevance + if scored.relevance_score >= query.min_relevance: + scored_results.append(scored) + + # Sort by relevance score + scored_results.sort(key=lambda x: x.relevance_score, reverse=True) + + # Apply limit + final_results = scored_results[: query.limit] + + # Cache results + if use_cache and self._cache and final_results: + self._cache.put(cache_key, final_results) + + latency = (_utcnow() - start_time).total_seconds() * 1000 + + logger.info( + f"Retrieved {len(final_results)} results from {len(candidates)} candidates " + f"in {latency:.2f}ms" + ) + + return RetrievalResult( + items=final_results, + total_count=len(candidates), + query=query.query_text or "", + retrieval_type="hybrid", + latency_ms=latency, + metadata={ + "cache_hit": False, + "candidates_count": len(candidates), + "filtered_count": len(scored_results), + }, + ) + + async def retrieve_similar( + self, + embedding: list[float], + limit: int = 10, + min_similarity: float = 0.5, + memory_types: list[MemoryType] | None = None, + ) -> RetrievalResult[ScoredResult]: + """ + Retrieve memories similar to a given embedding. + + Args: + embedding: Query embedding + limit: Maximum results + min_similarity: Minimum similarity threshold + memory_types: Filter by memory types + + Returns: + Retrieval result with scored items + """ + query = RetrievalQuery( + query_embedding=embedding, + limit=limit, + min_relevance=min_similarity, + memory_types=memory_types, + use_temporal=False, + use_entity=False, + use_outcome=False, + ) + return await self.retrieve(query) + + async def retrieve_recent( + self, + hours: float = 24, + limit: int = 10, + memory_types: list[MemoryType] | None = None, + ) -> RetrievalResult[ScoredResult]: + """ + Retrieve recent memories. + + Args: + hours: Number of hours to look back + limit: Maximum results + memory_types: Filter by memory types + + Returns: + Retrieval result with scored items + """ + query = RetrievalQuery( + recent_seconds=hours * 3600, + limit=limit, + memory_types=memory_types, + use_vector=False, + use_entity=False, + use_outcome=False, + ) + return await self.retrieve(query) + + async def retrieve_by_entity( + self, + entity_type: str, + entity_value: str, + limit: int = 10, + memory_types: list[MemoryType] | None = None, + ) -> RetrievalResult[ScoredResult]: + """ + Retrieve memories by entity. 
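+
+        Example (sketch; "task_type" matches the entity types emitted by
+        ``EntityIndex._extract_entities``, while "deploy" is a hypothetical
+        value)::
+
+            result = await engine.retrieve_by_entity(
+                entity_type="task_type",
+                entity_value="deploy",
+                limit=5,
+            )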
+ + Args: + entity_type: Type of entity + entity_value: Entity value + limit: Maximum results + memory_types: Filter by memory types + + Returns: + Retrieval result with scored items + """ + query = RetrievalQuery( + entities=[(entity_type, entity_value)], + limit=limit, + memory_types=memory_types, + use_vector=False, + use_temporal=False, + use_outcome=False, + ) + return await self.retrieve(query) + + async def retrieve_successful( + self, + limit: int = 10, + memory_types: list[MemoryType] | None = None, + ) -> RetrievalResult[ScoredResult]: + """ + Retrieve successful memories. + + Args: + limit: Maximum results + memory_types: Filter by memory types + + Returns: + Retrieval result with scored items + """ + query = RetrievalQuery( + outcomes=[Outcome.SUCCESS], + limit=limit, + memory_types=memory_types, + use_vector=False, + use_temporal=False, + use_entity=False, + ) + return await self.retrieve(query) + + def invalidate_cache(self) -> int: + """ + Invalidate all cached results. + + Returns: + Number of entries invalidated + """ + if self._cache: + return self._cache.clear() + return 0 + + def invalidate_cache_for_memory(self, memory_id: UUID) -> int: + """ + Invalidate cache entries containing a specific memory. + + Args: + memory_id: Memory ID to invalidate + + Returns: + Number of entries invalidated + """ + if self._cache: + return self._cache.invalidate_by_memory(memory_id) + return 0 + + def get_cache_stats(self) -> dict[str, Any]: + """Get cache statistics.""" + if self._cache: + return self._cache.get_stats() + return {"enabled": False} + + +# Singleton retrieval engine instance +_engine: RetrievalEngine | None = None + + +def get_retrieval_engine() -> RetrievalEngine: + """Get the singleton retrieval engine instance.""" + global _engine + if _engine is None: + _engine = RetrievalEngine() + return _engine diff --git a/backend/tests/unit/services/memory/indexing/__init__.py b/backend/tests/unit/services/memory/indexing/__init__.py new file mode 100644 index 0000000..415145e --- /dev/null +++ b/backend/tests/unit/services/memory/indexing/__init__.py @@ -0,0 +1,2 @@ +# tests/unit/services/memory/indexing/__init__.py +"""Unit tests for memory indexing.""" diff --git a/backend/tests/unit/services/memory/indexing/test_index.py b/backend/tests/unit/services/memory/indexing/test_index.py new file mode 100644 index 0000000..214bdd7 --- /dev/null +++ b/backend/tests/unit/services/memory/indexing/test_index.py @@ -0,0 +1,497 @@ +# tests/unit/services/memory/indexing/test_index.py +"""Unit tests for memory indexing.""" + +from datetime import UTC, datetime, timedelta +from uuid import uuid4 + +import pytest + +from app.services.memory.indexing.index import ( + EntityIndex, + MemoryIndexer, + OutcomeIndex, + TemporalIndex, + VectorIndex, + get_memory_indexer, +) +from app.services.memory.types import Episode, Fact, MemoryType, Outcome, Procedure + + +def _utcnow() -> datetime: + """Get current UTC time.""" + return datetime.now(UTC) + + +def make_episode( + embedding: list[float] | None = None, + outcome: Outcome = Outcome.SUCCESS, + occurred_at: datetime | None = None, +) -> Episode: + """Create a test episode.""" + return Episode( + id=uuid4(), + project_id=uuid4(), + agent_instance_id=uuid4(), + agent_type_id=uuid4(), + session_id="test-session", + task_type="test_task", + task_description="Test task description", + actions=[{"action": "test"}], + context_summary="Test context", + outcome=outcome, + outcome_details="Test outcome", + duration_seconds=10.0, + tokens_used=100, + 
lessons_learned=["lesson1"], + importance_score=0.8, + embedding=embedding, + occurred_at=occurred_at or _utcnow(), + created_at=_utcnow(), + updated_at=_utcnow(), + ) + + +def make_fact( + embedding: list[float] | None = None, + subject: str = "test_subject", + predicate: str = "has_property", + obj: str = "test_value", +) -> Fact: + """Create a test fact.""" + return Fact( + id=uuid4(), + project_id=uuid4(), + subject=subject, + predicate=predicate, + object=obj, + confidence=0.9, + source_episode_ids=[uuid4()], + first_learned=_utcnow(), + last_reinforced=_utcnow(), + reinforcement_count=1, + embedding=embedding, + created_at=_utcnow(), + updated_at=_utcnow(), + ) + + +def make_procedure( + embedding: list[float] | None = None, + success_count: int = 8, + failure_count: int = 2, +) -> Procedure: + """Create a test procedure.""" + return Procedure( + id=uuid4(), + project_id=uuid4(), + agent_type_id=uuid4(), + name="test_procedure", + trigger_pattern="test.*", + steps=[{"step": 1, "action": "test"}], + success_count=success_count, + failure_count=failure_count, + last_used=_utcnow(), + embedding=embedding, + created_at=_utcnow(), + updated_at=_utcnow(), + ) + + +class TestVectorIndex: + """Tests for VectorIndex.""" + + @pytest.fixture + def index(self) -> VectorIndex[Episode]: + """Create a vector index.""" + return VectorIndex[Episode](dimension=4) + + @pytest.mark.asyncio + async def test_add_item(self, index: VectorIndex[Episode]) -> None: + """Test adding an item to the index.""" + episode = make_episode(embedding=[1.0, 0.0, 0.0, 0.0]) + + entry = await index.add(episode) + + assert entry.memory_id == episode.id + assert entry.memory_type == MemoryType.EPISODIC + assert entry.dimension == 4 + assert await index.count() == 1 + + @pytest.mark.asyncio + async def test_remove_item(self, index: VectorIndex[Episode]) -> None: + """Test removing an item from the index.""" + episode = make_episode(embedding=[1.0, 0.0, 0.0, 0.0]) + await index.add(episode) + + result = await index.remove(episode.id) + + assert result is True + assert await index.count() == 0 + + @pytest.mark.asyncio + async def test_remove_nonexistent(self, index: VectorIndex[Episode]) -> None: + """Test removing a nonexistent item.""" + result = await index.remove(uuid4()) + assert result is False + + @pytest.mark.asyncio + async def test_search_similar(self, index: VectorIndex[Episode]) -> None: + """Test searching for similar items.""" + # Add items with different embeddings + e1 = make_episode(embedding=[1.0, 0.0, 0.0, 0.0]) + e2 = make_episode(embedding=[0.9, 0.1, 0.0, 0.0]) + e3 = make_episode(embedding=[0.0, 1.0, 0.0, 0.0]) + + await index.add(e1) + await index.add(e2) + await index.add(e3) + + # Search for similar to first + results = await index.search([1.0, 0.0, 0.0, 0.0], limit=2) + + assert len(results) == 2 + # First result should be most similar + assert results[0].memory_id == e1.id + + @pytest.mark.asyncio + async def test_search_min_similarity(self, index: VectorIndex[Episode]) -> None: + """Test minimum similarity threshold.""" + e1 = make_episode(embedding=[1.0, 0.0, 0.0, 0.0]) + e2 = make_episode(embedding=[0.0, 1.0, 0.0, 0.0]) # Orthogonal + + await index.add(e1) + await index.add(e2) + + # Search with high threshold + results = await index.search([1.0, 0.0, 0.0, 0.0], min_similarity=0.9) + + assert len(results) == 1 + assert results[0].memory_id == e1.id + + @pytest.mark.asyncio + async def test_search_empty_query(self, index: VectorIndex[Episode]) -> None: + """Test search with empty query.""" + e1 = 
make_episode(embedding=[1.0, 0.0, 0.0, 0.0]) + await index.add(e1) + + results = await index.search([], limit=10) + assert len(results) == 0 + + @pytest.mark.asyncio + async def test_clear(self, index: VectorIndex[Episode]) -> None: + """Test clearing the index.""" + await index.add(make_episode(embedding=[1.0, 0.0, 0.0, 0.0])) + await index.add(make_episode(embedding=[0.0, 1.0, 0.0, 0.0])) + + count = await index.clear() + + assert count == 2 + assert await index.count() == 0 + + +class TestTemporalIndex: + """Tests for TemporalIndex.""" + + @pytest.fixture + def index(self) -> TemporalIndex[Episode]: + """Create a temporal index.""" + return TemporalIndex[Episode]() + + @pytest.mark.asyncio + async def test_add_item(self, index: TemporalIndex[Episode]) -> None: + """Test adding an item.""" + episode = make_episode() + entry = await index.add(episode) + + assert entry.memory_id == episode.id + assert await index.count() == 1 + + @pytest.mark.asyncio + async def test_search_by_time_range(self, index: TemporalIndex[Episode]) -> None: + """Test searching by time range.""" + now = _utcnow() + old = make_episode(occurred_at=now - timedelta(hours=2)) + recent = make_episode(occurred_at=now - timedelta(hours=1)) + newest = make_episode(occurred_at=now) + + await index.add(old) + await index.add(recent) + await index.add(newest) + + # Search last hour + results = await index.search( + query=None, + start_time=now - timedelta(hours=1, minutes=30), + end_time=now, + ) + + assert len(results) == 2 + + @pytest.mark.asyncio + async def test_search_recent(self, index: TemporalIndex[Episode]) -> None: + """Test searching for recent items.""" + now = _utcnow() + old = make_episode(occurred_at=now - timedelta(hours=2)) + recent = make_episode(occurred_at=now - timedelta(minutes=30)) + + await index.add(old) + await index.add(recent) + + # Search last hour (3600 seconds) + results = await index.search(query=None, recent_seconds=3600) + + assert len(results) == 1 + assert results[0].memory_id == recent.id + + @pytest.mark.asyncio + async def test_search_order(self, index: TemporalIndex[Episode]) -> None: + """Test result ordering.""" + now = _utcnow() + e1 = make_episode(occurred_at=now - timedelta(hours=2)) + e2 = make_episode(occurred_at=now - timedelta(hours=1)) + e3 = make_episode(occurred_at=now) + + await index.add(e1) + await index.add(e2) + await index.add(e3) + + # Descending order (newest first) + results_desc = await index.search(query=None, order="desc", limit=10) + assert results_desc[0].memory_id == e3.id + + # Ascending order (oldest first) + results_asc = await index.search(query=None, order="asc", limit=10) + assert results_asc[0].memory_id == e1.id + + +class TestEntityIndex: + """Tests for EntityIndex.""" + + @pytest.fixture + def index(self) -> EntityIndex[Fact]: + """Create an entity index.""" + return EntityIndex[Fact]() + + @pytest.mark.asyncio + async def test_add_item(self, index: EntityIndex[Fact]) -> None: + """Test adding an item.""" + fact = make_fact(subject="user", obj="admin") + entry = await index.add(fact) + + assert entry.memory_id == fact.id + assert await index.count() == 1 + + @pytest.mark.asyncio + async def test_search_by_entity(self, index: EntityIndex[Fact]) -> None: + """Test searching by entity.""" + f1 = make_fact(subject="user", obj="admin") + f2 = make_fact(subject="system", obj="config") + + await index.add(f1) + await index.add(f2) + + results = await index.search( + query=None, + entity_type="subject", + entity_value="user", + ) + + assert len(results) == 1 
+ assert results[0].memory_id == f1.id + + @pytest.mark.asyncio + async def test_search_multiple_entities(self, index: EntityIndex[Fact]) -> None: + """Test searching with multiple entities.""" + f1 = make_fact(subject="user", obj="admin") + f2 = make_fact(subject="user", obj="guest") + + await index.add(f1) + await index.add(f2) + + # Search for facts about "user" subject + results = await index.search( + query=None, + entities=[("subject", "user")], + ) + + assert len(results) == 2 + + @pytest.mark.asyncio + async def test_search_match_all(self, index: EntityIndex[Fact]) -> None: + """Test matching all entities.""" + f1 = make_fact(subject="user", obj="admin") + f2 = make_fact(subject="user", obj="guest") + + await index.add(f1) + await index.add(f2) + + # Search for user+admin (match all) + results = await index.search( + query=None, + entities=[("subject", "user"), ("object", "admin")], + match_all=True, + ) + + assert len(results) == 1 + assert results[0].memory_id == f1.id + + @pytest.mark.asyncio + async def test_get_entities(self, index: EntityIndex[Fact]) -> None: + """Test getting entities for a memory.""" + fact = make_fact(subject="user", obj="admin") + await index.add(fact) + + entities = await index.get_entities(fact.id) + + assert ("subject", "user") in entities + assert ("object", "admin") in entities + + +class TestOutcomeIndex: + """Tests for OutcomeIndex.""" + + @pytest.fixture + def index(self) -> OutcomeIndex[Episode]: + """Create an outcome index.""" + return OutcomeIndex[Episode]() + + @pytest.mark.asyncio + async def test_add_item(self, index: OutcomeIndex[Episode]) -> None: + """Test adding an item.""" + episode = make_episode(outcome=Outcome.SUCCESS) + entry = await index.add(episode) + + assert entry.memory_id == episode.id + assert entry.outcome == Outcome.SUCCESS + assert await index.count() == 1 + + @pytest.mark.asyncio + async def test_search_by_outcome(self, index: OutcomeIndex[Episode]) -> None: + """Test searching by outcome.""" + success = make_episode(outcome=Outcome.SUCCESS) + failure = make_episode(outcome=Outcome.FAILURE) + + await index.add(success) + await index.add(failure) + + results = await index.search(query=None, outcome=Outcome.SUCCESS) + + assert len(results) == 1 + assert results[0].memory_id == success.id + + @pytest.mark.asyncio + async def test_search_multiple_outcomes(self, index: OutcomeIndex[Episode]) -> None: + """Test searching with multiple outcomes.""" + success = make_episode(outcome=Outcome.SUCCESS) + partial = make_episode(outcome=Outcome.PARTIAL) + failure = make_episode(outcome=Outcome.FAILURE) + + await index.add(success) + await index.add(partial) + await index.add(failure) + + results = await index.search( + query=None, + outcomes=[Outcome.SUCCESS, Outcome.PARTIAL], + ) + + assert len(results) == 2 + + @pytest.mark.asyncio + async def test_get_outcome_stats(self, index: OutcomeIndex[Episode]) -> None: + """Test getting outcome statistics.""" + await index.add(make_episode(outcome=Outcome.SUCCESS)) + await index.add(make_episode(outcome=Outcome.SUCCESS)) + await index.add(make_episode(outcome=Outcome.FAILURE)) + + stats = await index.get_outcome_stats() + + assert stats[Outcome.SUCCESS] == 2 + assert stats[Outcome.FAILURE] == 1 + assert stats[Outcome.PARTIAL] == 0 + + +class TestMemoryIndexer: + """Tests for MemoryIndexer.""" + + @pytest.fixture + def indexer(self) -> MemoryIndexer: + """Create a memory indexer.""" + return MemoryIndexer() + + @pytest.mark.asyncio + async def test_index_episode(self, indexer: MemoryIndexer) 
-> None: + """Test indexing an episode.""" + episode = make_episode(embedding=[1.0, 0.0, 0.0, 0.0]) + + results = await indexer.index(episode) + + assert "vector" in results + assert "temporal" in results + assert "entity" in results + assert "outcome" in results + + @pytest.mark.asyncio + async def test_index_fact(self, indexer: MemoryIndexer) -> None: + """Test indexing a fact.""" + fact = make_fact(embedding=[1.0, 0.0, 0.0, 0.0]) + + results = await indexer.index(fact) + + # Facts don't have outcomes + assert "vector" in results + assert "temporal" in results + assert "entity" in results + assert "outcome" not in results + + @pytest.mark.asyncio + async def test_remove_from_all(self, indexer: MemoryIndexer) -> None: + """Test removing from all indices.""" + episode = make_episode(embedding=[1.0, 0.0, 0.0, 0.0]) + await indexer.index(episode) + + results = await indexer.remove(episode.id) + + assert results["vector"] is True + assert results["temporal"] is True + assert results["entity"] is True + assert results["outcome"] is True + + @pytest.mark.asyncio + async def test_clear_all(self, indexer: MemoryIndexer) -> None: + """Test clearing all indices.""" + await indexer.index(make_episode(embedding=[1.0, 0.0, 0.0, 0.0])) + await indexer.index(make_episode(embedding=[0.0, 1.0, 0.0, 0.0])) + + counts = await indexer.clear_all() + + assert counts["vector"] == 2 + assert counts["temporal"] == 2 + + @pytest.mark.asyncio + async def test_get_stats(self, indexer: MemoryIndexer) -> None: + """Test getting index statistics.""" + await indexer.index(make_episode(embedding=[1.0, 0.0, 0.0, 0.0])) + + stats = await indexer.get_stats() + + assert stats["vector"] == 1 + assert stats["temporal"] == 1 + assert stats["entity"] == 1 + assert stats["outcome"] == 1 + + +class TestGetMemoryIndexer: + """Tests for singleton getter.""" + + def test_returns_instance(self) -> None: + """Test that getter returns instance.""" + indexer = get_memory_indexer() + assert indexer is not None + assert isinstance(indexer, MemoryIndexer) + + def test_returns_same_instance(self) -> None: + """Test that getter returns same instance.""" + indexer1 = get_memory_indexer() + indexer2 = get_memory_indexer() + assert indexer1 is indexer2 diff --git a/backend/tests/unit/services/memory/indexing/test_retrieval.py b/backend/tests/unit/services/memory/indexing/test_retrieval.py new file mode 100644 index 0000000..91ae357 --- /dev/null +++ b/backend/tests/unit/services/memory/indexing/test_retrieval.py @@ -0,0 +1,450 @@ +# tests/unit/services/memory/indexing/test_retrieval.py +"""Unit tests for memory retrieval.""" + +from datetime import UTC, datetime, timedelta +from uuid import uuid4 + +import pytest + +from app.services.memory.indexing.index import MemoryIndexer +from app.services.memory.indexing.retrieval import ( + RelevanceScorer, + RetrievalCache, + RetrievalEngine, + RetrievalQuery, + ScoredResult, + get_retrieval_engine, +) +from app.services.memory.types import Episode, MemoryType, Outcome + + +def _utcnow() -> datetime: + """Get current UTC time.""" + return datetime.now(UTC) + + +def make_episode( + embedding: list[float] | None = None, + outcome: Outcome = Outcome.SUCCESS, + occurred_at: datetime | None = None, + task_type: str = "test_task", +) -> Episode: + """Create a test episode.""" + return Episode( + id=uuid4(), + project_id=uuid4(), + agent_instance_id=uuid4(), + agent_type_id=uuid4(), + session_id="test-session", + task_type=task_type, + task_description="Test task description", + actions=[{"action": "test"}], + 
diff --git a/backend/tests/unit/services/memory/indexing/test_retrieval.py b/backend/tests/unit/services/memory/indexing/test_retrieval.py
new file mode 100644
index 0000000..91ae357
--- /dev/null
+++ b/backend/tests/unit/services/memory/indexing/test_retrieval.py
@@ -0,0 +1,450 @@
+# tests/unit/services/memory/indexing/test_retrieval.py
+"""Unit tests for memory retrieval."""
+
+from datetime import UTC, datetime, timedelta
+from uuid import uuid4
+
+import pytest
+
+from app.services.memory.indexing.index import MemoryIndexer
+from app.services.memory.indexing.retrieval import (
+    RelevanceScorer,
+    RetrievalCache,
+    RetrievalEngine,
+    RetrievalQuery,
+    ScoredResult,
+    get_retrieval_engine,
+)
+from app.services.memory.types import Episode, MemoryType, Outcome
+
+
+def _utcnow() -> datetime:
+    """Get current UTC time."""
+    return datetime.now(UTC)
+
+
+def make_episode(
+    embedding: list[float] | None = None,
+    outcome: Outcome = Outcome.SUCCESS,
+    occurred_at: datetime | None = None,
+    task_type: str = "test_task",
+) -> Episode:
+    """Create a test episode."""
+    return Episode(
+        id=uuid4(),
+        project_id=uuid4(),
+        agent_instance_id=uuid4(),
+        agent_type_id=uuid4(),
+        session_id="test-session",
+        task_type=task_type,
+        task_description="Test task description",
+        actions=[{"action": "test"}],
+        context_summary="Test context",
+        outcome=outcome,
+        outcome_details="Test outcome",
+        duration_seconds=10.0,
+        tokens_used=100,
+        lessons_learned=["lesson1"],
+        importance_score=0.8,
+        embedding=embedding,
+        occurred_at=occurred_at or _utcnow(),
+        created_at=_utcnow(),
+        updated_at=_utcnow(),
+    )
+
+
+class TestRetrievalQuery:
+    """Tests for RetrievalQuery."""
+
+    def test_default_values(self) -> None:
+        """Test default query values."""
+        query = RetrievalQuery()
+
+        assert query.query_text is None
+        assert query.limit == 10
+        assert query.min_relevance == 0.0
+        assert query.use_vector is True
+        assert query.use_temporal is True
+
+    def test_cache_key_generation(self) -> None:
+        """Test cache key generation."""
+        query1 = RetrievalQuery(query_text="test", limit=10)
+        query2 = RetrievalQuery(query_text="test", limit=10)
+        query3 = RetrievalQuery(query_text="different", limit=10)
+
+        # Same queries should have same key
+        assert query1.to_cache_key() == query2.to_cache_key()
+        # Different queries should have different keys
+        assert query1.to_cache_key() != query3.to_cache_key()
+
+
+class TestScoredResult:
+    """Tests for ScoredResult."""
+
+    def test_creation(self) -> None:
+        """Test creating a scored result."""
+        result = ScoredResult(
+            memory_id=uuid4(),
+            memory_type=MemoryType.EPISODIC,
+            relevance_score=0.85,
+            score_breakdown={"vector": 0.9, "recency": 0.8},
+        )
+
+        assert result.relevance_score == 0.85
+        assert result.score_breakdown["vector"] == 0.9
+
+
+class TestRelevanceScorer:
+    """Tests for RelevanceScorer."""
+
+    @pytest.fixture
+    def scorer(self) -> RelevanceScorer:
+        """Create a relevance scorer."""
+        return RelevanceScorer()
+
+    def test_score_with_vector(self, scorer: RelevanceScorer) -> None:
+        """Test scoring with vector similarity."""
+        result = scorer.score(
+            memory_id=uuid4(),
+            memory_type=MemoryType.EPISODIC,
+            vector_similarity=0.9,
+        )
+
+        assert result.relevance_score > 0
+        assert result.score_breakdown["vector"] == 0.9
+
+    def test_score_with_recency(self, scorer: RelevanceScorer) -> None:
+        """Test scoring with recency."""
+        recent_result = scorer.score(
+            memory_id=uuid4(),
+            memory_type=MemoryType.EPISODIC,
+            timestamp=_utcnow(),
+        )
+
+        old_result = scorer.score(
+            memory_id=uuid4(),
+            memory_type=MemoryType.EPISODIC,
+            timestamp=_utcnow() - timedelta(days=7),
+        )
+
+        # Recent should have higher recency score
+        assert (
+            recent_result.score_breakdown["recency"]
+            > old_result.score_breakdown["recency"]
+        )
+
+    def test_score_with_outcome_preference(self, scorer: RelevanceScorer) -> None:
+        """Test scoring with outcome preference."""
+        success_result = scorer.score(
+            memory_id=uuid4(),
+            memory_type=MemoryType.EPISODIC,
+            outcome=Outcome.SUCCESS,
+            preferred_outcomes=[Outcome.SUCCESS],
+        )
+
+        failure_result = scorer.score(
+            memory_id=uuid4(),
+            memory_type=MemoryType.EPISODIC,
+            outcome=Outcome.FAILURE,
+            preferred_outcomes=[Outcome.SUCCESS],
+        )
+
+        assert success_result.score_breakdown["outcome"] == 1.0
+        assert failure_result.score_breakdown["outcome"] == 0.0
+
+    def test_score_with_entity_match(self, scorer: RelevanceScorer) -> None:
+        """Test scoring with entity matches."""
+        full_match = scorer.score(
+            memory_id=uuid4(),
+            memory_type=MemoryType.EPISODIC,
+            entity_match_count=3,
+            entity_total=3,
+        )
+
+        partial_match = scorer.score(
+            memory_id=uuid4(),
+            memory_type=MemoryType.EPISODIC,
+            entity_match_count=1,
+            entity_total=3,
+        )
+
+        assert (
+            full_match.score_breakdown["entity"]
+            > partial_match.score_breakdown["entity"]
+        )
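The RelevanceScorer's internals are not shown in this diff; the tests only constrain the per-signal behaviour (recency ordering, the 1.0/0.0 outcome split, entity-match monotonicity). The sketch below is consistent with those assertions; the half-life, the equal-weight combination, and all function names are assumptions rather than the PR's implementation.

# scoring_sketch.py -- illustrative only; weights and decay are assumed
from datetime import UTC, datetime


def recency_score(timestamp: datetime, half_life_hours: float = 24.0) -> float:
    """Exponential decay: a fresh memory scores near 1.0, a week-old one near 0.0."""
    age_hours = (datetime.now(UTC) - timestamp).total_seconds() / 3600.0
    return 0.5 ** (age_hours / half_life_hours)


def outcome_score(outcome: object, preferred: list[object] | None) -> float:
    """Exactly 1.0 for a preferred outcome, 0.0 otherwise, as the tests assert."""
    return 1.0 if preferred and outcome in preferred else 0.0


def entity_score(match_count: int, total: int) -> float:
    """Fraction of queried entities that matched; monotonic in match_count."""
    return match_count / total if total else 0.0


def combine(breakdown: dict[str, float]) -> float:
    """Equal-weight average as a placeholder; the real weights are not shown."""
    return sum(breakdown.values()) / len(breakdown) if breakdown else 0.0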
+
+
+class TestRetrievalCache:
+    """Tests for RetrievalCache."""
+
+    @pytest.fixture
+    def cache(self) -> RetrievalCache:
+        """Create a retrieval cache."""
+        return RetrievalCache(max_entries=10, default_ttl_seconds=60)
+
+    def test_put_and_get(self, cache: RetrievalCache) -> None:
+        """Test putting and getting from cache."""
+        results = [
+            ScoredResult(
+                memory_id=uuid4(),
+                memory_type=MemoryType.EPISODIC,
+                relevance_score=0.8,
+            )
+        ]
+
+        cache.put("test_key", results)
+        cached = cache.get("test_key")
+
+        assert cached is not None
+        assert len(cached) == 1
+
+    def test_get_nonexistent(self, cache: RetrievalCache) -> None:
+        """Test getting nonexistent entry."""
+        result = cache.get("nonexistent")
+        assert result is None
+
+    def test_lru_eviction(self) -> None:
+        """Test LRU eviction when at capacity."""
+        cache = RetrievalCache(max_entries=2, default_ttl_seconds=60)
+
+        results = [
+            ScoredResult(
+                memory_id=uuid4(),
+                memory_type=MemoryType.EPISODIC,
+                relevance_score=0.8,
+            )
+        ]
+
+        cache.put("key1", results)
+        cache.put("key2", results)
+        cache.put("key3", results)  # Should evict key1
+
+        assert cache.get("key1") is None
+        assert cache.get("key2") is not None
+        assert cache.get("key3") is not None
+
+    def test_invalidate(self, cache: RetrievalCache) -> None:
+        """Test invalidating a cache entry."""
+        results = [
+            ScoredResult(
+                memory_id=uuid4(),
+                memory_type=MemoryType.EPISODIC,
+                relevance_score=0.8,
+            )
+        ]
+
+        cache.put("test_key", results)
+        removed = cache.invalidate("test_key")
+
+        assert removed is True
+        assert cache.get("test_key") is None
+
+    def test_invalidate_by_memory(self, cache: RetrievalCache) -> None:
+        """Test invalidating by memory ID."""
+        memory_id = uuid4()
+        results = [
+            ScoredResult(
+                memory_id=memory_id,
+                memory_type=MemoryType.EPISODIC,
+                relevance_score=0.8,
+            )
+        ]
+
+        cache.put("key1", results)
+        cache.put("key2", results)
+
+        count = cache.invalidate_by_memory(memory_id)
+
+        assert count == 2
+        assert cache.get("key1") is None
+        assert cache.get("key2") is None
+
+    def test_clear(self, cache: RetrievalCache) -> None:
+        """Test clearing the cache."""
+        results = [
+            ScoredResult(
+                memory_id=uuid4(),
+                memory_type=MemoryType.EPISODIC,
+                relevance_score=0.8,
+            )
+        ]
+
+        cache.put("key1", results)
+        cache.put("key2", results)
+
+        count = cache.clear()
+
+        assert count == 2
+        assert cache.get("key1") is None
+
+    def test_get_stats(self, cache: RetrievalCache) -> None:
+        """Test getting cache statistics."""
+        stats = cache.get_stats()
+
+        assert "total_entries" in stats
+        assert "max_entries" in stats
+        assert stats["max_entries"] == 10
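The tests above fix RetrievalCache's observable contract: capped size with least-recently-used eviction, per-key and per-memory invalidation with hit counts, and get_stats exposing total_entries/max_entries. A compact sketch of one way to satisfy that contract follows, built on collections.OrderedDict; the class name, lazy TTL expiry, and internal layout are assumptions, not the PR's code.

# lru_cache_sketch.py -- illustrative only, not part of this changeset
import time
from collections import OrderedDict
from typing import Any


class LruTtlCacheSketch:
    """Sketch of an LRU cache with per-entry TTL (assumed design)."""

    def __init__(self, max_entries: int = 10, default_ttl_seconds: int = 60) -> None:
        self._max = max_entries
        self._ttl = default_ttl_seconds
        # key -> (expiry time, cached results); dict order doubles as recency order
        self._data: OrderedDict[str, tuple[float, list[Any]]] = OrderedDict()

    def put(self, key: str, results: list[Any]) -> None:
        if key in self._data:
            self._data.move_to_end(key)
        self._data[key] = (time.monotonic() + self._ttl, results)
        while len(self._data) > self._max:
            self._data.popitem(last=False)  # evict least recently used

    def get(self, key: str) -> list[Any] | None:
        entry = self._data.get(key)
        if entry is None or entry[0] < time.monotonic():
            self._data.pop(key, None)  # drop expired entries lazily
            return None
        self._data.move_to_end(key)  # a hit refreshes recency
        return entry[1]

    def invalidate(self, key: str) -> bool:
        """Remove one entry; True if it was present."""
        return self._data.pop(key, None) is not None

    def invalidate_by_memory(self, memory_id: Any) -> int:
        """Drop every cached result list that references memory_id."""
        stale = [
            key
            for key, (_, results) in self._data.items()
            if any(r.memory_id == memory_id for r in results)
        ]
        for key in stale:
            del self._data[key]
        return len(stale)

    def clear(self) -> int:
        count = len(self._data)
        self._data.clear()
        return count

    def get_stats(self) -> dict[str, int]:
        return {"total_entries": len(self._data), "max_entries": self._max}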
+
+
+class TestRetrievalEngine:
+    """Tests for RetrievalEngine."""
+
+    @pytest.fixture
+    def indexer(self) -> MemoryIndexer:
+        """Create a memory indexer."""
+        return MemoryIndexer()
+
+    @pytest.fixture
+    def engine(self, indexer: MemoryIndexer) -> RetrievalEngine:
+        """Create a retrieval engine."""
+        return RetrievalEngine(indexer=indexer, enable_cache=True)
+
+    @pytest.mark.asyncio
+    async def test_retrieve_by_vector(
+        self, engine: RetrievalEngine, indexer: MemoryIndexer
+    ) -> None:
+        """Test retrieval by vector similarity."""
+        e1 = make_episode(embedding=[1.0, 0.0, 0.0, 0.0])
+        e2 = make_episode(embedding=[0.9, 0.1, 0.0, 0.0])
+        e3 = make_episode(embedding=[0.0, 1.0, 0.0, 0.0])
+
+        await indexer.index(e1)
+        await indexer.index(e2)
+        await indexer.index(e3)
+
+        query = RetrievalQuery(
+            query_embedding=[1.0, 0.0, 0.0, 0.0],
+            limit=2,
+            use_temporal=False,
+            use_entity=False,
+            use_outcome=False,
+        )
+
+        result = await engine.retrieve(query)
+
+        assert len(result.items) > 0
+        assert result.retrieval_type == "hybrid"
+
+    @pytest.mark.asyncio
+    async def test_retrieve_recent(
+        self, engine: RetrievalEngine, indexer: MemoryIndexer
+    ) -> None:
+        """Test retrieval of recent items."""
+        now = _utcnow()
+        old = make_episode(occurred_at=now - timedelta(hours=2))
+        recent = make_episode(occurred_at=now - timedelta(minutes=30))
+
+        await indexer.index(old)
+        await indexer.index(recent)
+
+        result = await engine.retrieve_recent(hours=1)
+
+        assert len(result.items) == 1
+
+    @pytest.mark.asyncio
+    async def test_retrieve_by_entity(
+        self, engine: RetrievalEngine, indexer: MemoryIndexer
+    ) -> None:
+        """Test retrieval by entity."""
+        e1 = make_episode(task_type="deploy")
+        e2 = make_episode(task_type="test")
+
+        await indexer.index(e1)
+        await indexer.index(e2)
+
+        result = await engine.retrieve_by_entity("task_type", "deploy")
+
+        assert len(result.items) == 1
+
+    @pytest.mark.asyncio
+    async def test_retrieve_successful(
+        self, engine: RetrievalEngine, indexer: MemoryIndexer
+    ) -> None:
+        """Test retrieval of successful items."""
+        success = make_episode(outcome=Outcome.SUCCESS)
+        failure = make_episode(outcome=Outcome.FAILURE)
+
+        await indexer.index(success)
+        await indexer.index(failure)
+
+        result = await engine.retrieve_successful()
+
+        assert len(result.items) == 1
+        # Only the successful episode should come back
+        assert result.items[0].memory_id == success.id
+
+    @pytest.mark.asyncio
+    async def test_retrieve_with_cache(
+        self, engine: RetrievalEngine, indexer: MemoryIndexer
+    ) -> None:
+        """Test that retrieval uses the cache."""
+        episode = make_episode(embedding=[1.0, 0.0, 0.0, 0.0])
+        await indexer.index(episode)
+
+        query = RetrievalQuery(
+            query_embedding=[1.0, 0.0, 0.0, 0.0],
+            limit=10,
+        )
+
+        # First retrieval misses the cache
+        result1 = await engine.retrieve(query)
+        assert result1.metadata.get("cache_hit") is False
+
+        # Second retrieval should be served from the cache
+        result2 = await engine.retrieve(query)
+        assert result2.metadata.get("cache_hit") is True
+
+    @pytest.mark.asyncio
+    async def test_invalidate_cache(
+        self, engine: RetrievalEngine, indexer: MemoryIndexer
+    ) -> None:
+        """Test cache invalidation."""
+        episode = make_episode(embedding=[1.0, 0.0, 0.0, 0.0])
+        await indexer.index(episode)
+
+        query = RetrievalQuery(query_embedding=[1.0, 0.0, 0.0, 0.0])
+        await engine.retrieve(query)
+
+        count = engine.invalidate_cache()
+
+        assert count > 0
+
+    @pytest.mark.asyncio
+    async def test_retrieve_similar(
+        self, engine: RetrievalEngine, indexer: MemoryIndexer
+    ) -> None:
+        """Test the retrieve_similar convenience method."""
+        e1 = make_episode(embedding=[1.0, 0.0, 0.0, 0.0])
+        e2 = make_episode(embedding=[0.0, 1.0, 0.0, 0.0])
+
+        await indexer.index(e1)
+        await indexer.index(e2)
+
+        result = await engine.retrieve_similar(
+            embedding=[1.0, 0.0, 0.0, 0.0],
+            limit=1,
+        )
+
+        assert len(result.items) == 1
+
+    def test_get_cache_stats(self, engine: RetrievalEngine) -> None:
+        """Test getting cache statistics."""
+        stats = engine.get_cache_stats()
+
+        assert "total_entries" in stats
+
+
+class TestGetRetrievalEngine:
+    """Tests for singleton getter."""
+
+    def test_returns_instance(self) -> None:
+        """Test that getter returns instance."""
+        engine = get_retrieval_engine()
+        assert engine is not None
+        assert isinstance(engine, RetrievalEngine)
+
+    def test_returns_same_instance(self) -> None:
+        """Test that getter returns same instance."""
+        engine1 = get_retrieval_engine()
+        engine2 = get_retrieval_engine()
+        assert engine1 is engine2
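Taken together, the fixtures above imply the following composition of the public API: build a MemoryIndexer, wrap it in a RetrievalEngine with caching enabled, index memories, and retrieve with a RetrievalQuery. This is a usage sketch inferred from the tests, not documented behaviour; the 4-dimensional embedding mirrors the test data, and whether retrieving against an empty index succeeds without error is an assumption.

# usage_sketch.py -- composition inferred from the fixtures above
import asyncio

from app.services.memory.indexing import (
    MemoryIndexer,
    RetrievalEngine,
    RetrievalQuery,
)


async def demo() -> None:
    indexer = MemoryIndexer()
    engine = RetrievalEngine(indexer=indexer, enable_cache=True)

    # Index Episode/Fact/Procedure objects here (see make_episode above);
    # each indexer.index(item) call fans out to all applicable indices.

    query = RetrievalQuery(query_embedding=[1.0, 0.0, 0.0, 0.0], limit=5)
    result = await engine.retrieve(query)  # hybrid retrieval, cached thereafter
    print(result.retrieval_type, engine.get_cache_stats())


if __name__ == "__main__":
    asyncio.run(demo())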