# app/services/memory/indexing/retrieval.py
"""
Memory Retrieval Engine.

Provides hybrid retrieval capabilities combining:
- Vector similarity search
- Temporal filtering
- Entity filtering
- Outcome filtering
- Relevance scoring
- Result caching
"""

import hashlib
import logging
from collections import OrderedDict
from dataclasses import dataclass, field
from datetime import UTC, datetime
from typing import Any, TypeVar
from uuid import UUID

from app.services.memory.types import (
    Episode,
    Fact,
    MemoryType,
    Outcome,
    Procedure,
    RetrievalResult,
)

from .index import (
    MemoryIndexer,
    get_memory_indexer,
)

logger = logging.getLogger(__name__)

T = TypeVar("T", Episode, Fact, Procedure)


def _utcnow() -> datetime:
    """Get the current UTC time as a timezone-aware datetime."""
    return datetime.now(UTC)


@dataclass
class RetrievalQuery:
    """Query parameters for memory retrieval."""

    # Text/semantic query
    query_text: str | None = None
    query_embedding: list[float] | None = None

    # Temporal filters
    start_time: datetime | None = None
    end_time: datetime | None = None
    recent_seconds: float | None = None

    # Entity filters
    entities: list[tuple[str, str]] | None = None
    entity_match_all: bool = False

    # Outcome filters
    outcomes: list[Outcome] | None = None

    # Memory type filter
    memory_types: list[MemoryType] | None = None

    # Result options
    limit: int = 10
    min_relevance: float = 0.0

    # Retrieval mode
    use_vector: bool = True
    use_temporal: bool = True
    use_entity: bool = True
    use_outcome: bool = True

    def to_cache_key(self) -> str:
        """Generate a cache key for this query.

        Every field that can change the result set participates in the key;
        omitting one would let distinct queries collide in the cache.
        """
        key_parts = [
            self.query_text or "",
            str(self.query_embedding),
            str(self.start_time),
            str(self.end_time),
            str(self.recent_seconds),
            str(self.entities),
            str(self.entity_match_all),
            str(self.outcomes),
            str(self.memory_types),
            str(self.limit),
            str(self.min_relevance),
            str((self.use_vector, self.use_temporal, self.use_entity, self.use_outcome)),
        ]
        key_string = "|".join(key_parts)
        return hashlib.sha256(key_string.encode()).hexdigest()[:32]


@dataclass
class ScoredResult:
    """A retrieval result with relevance score."""

    memory_id: UUID
    memory_type: MemoryType
    relevance_score: float
    score_breakdown: dict[str, float] = field(default_factory=dict)
    metadata: dict[str, Any] = field(default_factory=dict)


@dataclass
class CacheEntry:
    """A cached retrieval result."""

    results: list[ScoredResult]
    created_at: datetime
    ttl_seconds: float
    query_key: str

    def is_expired(self) -> bool:
        """Check whether this cache entry has expired."""
        age = (_utcnow() - self.created_at).total_seconds()
        return age > self.ttl_seconds
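

# Usage sketch (illustrative; the query values are assumptions, not canonical
# data). Equal parameters always produce the same 32-character key, which is
# what lets RetrievalCache serve repeated queries:
#
#   >>> q1 = RetrievalQuery(query_text="deploy failures", limit=5)
#   >>> q2 = RetrievalQuery(query_text="deploy failures", limit=5)
#   >>> q1.to_cache_key() == q2.to_cache_key()
#   True
#   >>> len(q1.to_cache_key())
#   32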


class RelevanceScorer:
    """
    Calculates relevance scores for retrieved memories.

    Combines multiple signals:
    - Vector similarity (if available)
    - Temporal recency
    - Entity match count
    - Outcome preference
    - Importance/confidence
    """

    def __init__(
        self,
        vector_weight: float = 0.4,
        recency_weight: float = 0.2,
        entity_weight: float = 0.2,
        outcome_weight: float = 0.1,
        importance_weight: float = 0.1,
    ) -> None:
        """
        Initialize the relevance scorer.

        Args:
            vector_weight: Weight for vector similarity (0-1)
            recency_weight: Weight for temporal recency (0-1)
            entity_weight: Weight for entity matches (0-1)
            outcome_weight: Weight for outcome preference (0-1)
            importance_weight: Weight for importance score (0-1)
        """
        total = (
            vector_weight
            + recency_weight
            + entity_weight
            + outcome_weight
            + importance_weight
        )
        if total <= 0:
            raise ValueError("At least one scoring weight must be positive")

        # Normalize weights so they always sum to 1.0
        self.vector_weight = vector_weight / total
        self.recency_weight = recency_weight / total
        self.entity_weight = entity_weight / total
        self.outcome_weight = outcome_weight / total
        self.importance_weight = importance_weight / total

    def score(
        self,
        memory_id: UUID,
        memory_type: MemoryType,
        vector_similarity: float | None = None,
        timestamp: datetime | None = None,
        entity_match_count: int = 0,
        entity_total: int = 1,
        outcome: Outcome | None = None,
        importance: float = 0.5,
        preferred_outcomes: list[Outcome] | None = None,
    ) -> ScoredResult:
        """
        Calculate a relevance score for a memory.

        Args:
            memory_id: ID of the memory
            memory_type: Type of memory
            vector_similarity: Similarity score from vector search (0-1)
            timestamp: Timestamp of the memory
            entity_match_count: Number of matching entities
            entity_total: Total entities in the query (0 = no entity filter)
            outcome: Outcome of the memory
            importance: Importance score of the memory (0-1)
            preferred_outcomes: Outcomes to prefer

        Returns:
            Scored result with breakdown
        """
        breakdown: dict[str, float] = {}

        # Vector similarity score
        if vector_similarity is not None:
            breakdown["vector"] = vector_similarity
        else:
            breakdown["vector"] = 0.5  # Neutral if no vector

        # Recency score (exponential decay with a half-life of 24 hours)
        if timestamp:
            age_hours = (_utcnow() - timestamp).total_seconds() / 3600
            breakdown["recency"] = 2 ** (-age_hours / 24)
        else:
            breakdown["recency"] = 0.5

        # Entity match score
        if entity_total > 0:
            breakdown["entity"] = entity_match_count / entity_total
        else:
            breakdown["entity"] = 1.0  # No entity filter = full score

        # Outcome score
        if preferred_outcomes and outcome:
            breakdown["outcome"] = 1.0 if outcome in preferred_outcomes else 0.0
        else:
            breakdown["outcome"] = 0.5  # Neutral if no preference

        # Importance score
        breakdown["importance"] = importance

        # Calculate the weighted sum
        total_score = (
            breakdown["vector"] * self.vector_weight
            + breakdown["recency"] * self.recency_weight
            + breakdown["entity"] * self.entity_weight
            + breakdown["outcome"] * self.outcome_weight
            + breakdown["importance"] * self.importance_weight
        )

        return ScoredResult(
            memory_id=memory_id,
            memory_type=memory_type,
            relevance_score=total_score,
            score_breakdown=breakdown,
        )
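

# Worked example (illustrative numbers): with the default weights
# (0.4 / 0.2 / 0.2 / 0.1 / 0.1, already summing to 1.0), a memory with
# vector_similarity=0.9, a 24-hour-old timestamp (recency = 2**-1 = 0.5),
# 1 of 2 query entities matched (0.5), a preferred outcome (1.0), and the
# default importance of 0.5 scores:
#
#   0.9*0.4 + 0.5*0.2 + 0.5*0.2 + 1.0*0.1 + 0.5*0.1 = 0.71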


class RetrievalCache:
    """
    In-memory cache for retrieval results.

    Supports TTL-based expiration and LRU eviction with O(1) operations.
    Uses OrderedDict for efficient LRU tracking.
    """

    def __init__(
        self,
        max_entries: int = 1000,
        default_ttl_seconds: float = 300,
    ) -> None:
        """
        Initialize the cache.

        Args:
            max_entries: Maximum cache entries
            default_ttl_seconds: Default TTL for entries
        """
        # OrderedDict maintains insertion order; we use move_to_end for O(1) LRU
        self._cache: OrderedDict[str, CacheEntry] = OrderedDict()
        self._max_entries = max_entries
        self._default_ttl = default_ttl_seconds

        logger.info(
            f"Initialized RetrievalCache with max_entries={max_entries}, "
            f"ttl={default_ttl_seconds}s"
        )

    def get(self, query_key: str) -> list[ScoredResult] | None:
        """
        Get cached results for a query.

        Args:
            query_key: Cache key for the query

        Returns:
            Cached results, or None if not found or expired
        """
        if query_key not in self._cache:
            return None

        entry = self._cache[query_key]
        if entry.is_expired():
            del self._cache[query_key]
            return None

        # Update access order (LRU) - O(1) with OrderedDict
        self._cache.move_to_end(query_key)
        logger.debug(f"Cache hit for {query_key}")
        return entry.results

    def put(
        self,
        query_key: str,
        results: list[ScoredResult],
        ttl_seconds: float | None = None,
    ) -> None:
        """
        Cache results for a query.

        Args:
            query_key: Cache key for the query
            results: Results to cache
            ttl_seconds: TTL for this entry (or the default)
        """
        # Evict oldest entries if at capacity - O(1) with popitem(last=False)
        while len(self._cache) >= self._max_entries:
            self._cache.popitem(last=False)

        entry = CacheEntry(
            results=results,
            created_at=_utcnow(),
            # "is not None" so an explicit ttl_seconds=0 is honored rather
            # than silently falling back to the default
            ttl_seconds=ttl_seconds if ttl_seconds is not None else self._default_ttl,
            query_key=query_key,
        )
        self._cache[query_key] = entry
        # Refresh the LRU position in case the key already existed
        self._cache.move_to_end(query_key)
        logger.debug(f"Cached {len(results)} results for {query_key}")

    def invalidate(self, query_key: str) -> bool:
        """
        Invalidate a specific cache entry.

        Args:
            query_key: Cache key to invalidate

        Returns:
            True if the entry was found and removed
        """
        if query_key in self._cache:
            del self._cache[query_key]
            return True
        return False

    def invalidate_by_memory(self, memory_id: UUID) -> int:
        """
        Invalidate all cache entries containing a specific memory.

        Args:
            memory_id: Memory ID to invalidate

        Returns:
            Number of entries invalidated
        """
        keys_to_remove = [
            key
            for key, entry in self._cache.items()
            if any(r.memory_id == memory_id for r in entry.results)
        ]
        for key in keys_to_remove:
            self.invalidate(key)

        if keys_to_remove:
            logger.debug(
                f"Invalidated {len(keys_to_remove)} cache entries for {memory_id}"
            )
        return len(keys_to_remove)

    def clear(self) -> int:
        """
        Clear all cache entries.

        Returns:
            Number of entries cleared
        """
        count = len(self._cache)
        self._cache.clear()
        logger.info(f"Cleared {count} cache entries")
        return count

    def get_stats(self) -> dict[str, Any]:
        """Get cache statistics."""
        expired_count = sum(1 for e in self._cache.values() if e.is_expired())
        return {
            "total_entries": len(self._cache),
            "expired_entries": expired_count,
            "max_entries": self._max_entries,
            "default_ttl_seconds": self._default_ttl,
        }
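

# Usage sketch (illustrative; keys and TTLs are assumptions). Entries are
# keyed by RetrievalQuery.to_cache_key(), and invalidate_by_memory() is the
# hook to call whenever an indexed memory changes:
#
#   >>> cache = RetrievalCache(max_entries=2, default_ttl_seconds=60)
#   >>> cache.put("key-a", [])
#   >>> cache.get("key-a")
#   []
#   >>> cache.get("key-b") is None
#   True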


class RetrievalEngine:
    """
    Hybrid retrieval engine for memory search.

    Combines multiple index types for comprehensive retrieval:
    - Vector search for semantic similarity
    - Temporal index for time-based filtering
    - Entity index for entity-based lookups
    - Outcome index for success/failure filtering

    Results are scored and ranked using relevance scoring.
    """

    def __init__(
        self,
        indexer: MemoryIndexer | None = None,
        scorer: RelevanceScorer | None = None,
        cache: RetrievalCache | None = None,
        enable_cache: bool = True,
    ) -> None:
        """
        Initialize the retrieval engine.

        Args:
            indexer: Memory indexer (defaults to singleton)
            scorer: Relevance scorer (defaults to a new instance)
            cache: Retrieval cache (defaults to a new instance)
            enable_cache: Whether to enable result caching
        """
        self._indexer = indexer or get_memory_indexer()
        self._scorer = scorer or RelevanceScorer()
        # Parenthesized for clarity: caching disabled means no cache at all,
        # even if one was passed in
        self._cache = (cache or RetrievalCache()) if enable_cache else None
        self._enable_cache = enable_cache

        logger.info(f"Initialized RetrievalEngine with cache={enable_cache}")

    async def retrieve(
        self,
        query: RetrievalQuery,
        use_cache: bool = True,
    ) -> RetrievalResult[ScoredResult]:
        """
        Retrieve relevant memories using hybrid search.

        Args:
            query: Retrieval query parameters
            use_cache: Whether to use cached results

        Returns:
            Retrieval result with scored items
        """
        start_time = _utcnow()

        # Check the cache first
        cache_key = query.to_cache_key()
        if use_cache and self._cache:
            cached = self._cache.get(cache_key)
            if cached:
                latency = (_utcnow() - start_time).total_seconds() * 1000
                return RetrievalResult(
                    items=cached,
                    total_count=len(cached),
                    query=query.query_text or "",
                    retrieval_type="cached",
                    latency_ms=latency,
                    metadata={"cache_hit": True},
                )

        # Collect candidates from each index
        candidates: dict[UUID, dict[str, Any]] = {}

        # The per-index filters accept a single memory type, so only the
        # first entry of query.memory_types is pushed down to the indexes
        memory_type = query.memory_types[0] if query.memory_types else None

        # Vector search
        if query.use_vector and query.query_embedding:
            vector_results = await self._indexer.vector_index.search(
                query=query.query_embedding,
                limit=query.limit * 3,  # Over-fetch to survive later filtering
                min_similarity=query.min_relevance,
                memory_type=memory_type,
            )
            for entry in vector_results:
                if entry.memory_id not in candidates:
                    candidates[entry.memory_id] = {
                        "memory_type": entry.memory_type,
                        "sources": [],
                    }
                candidates[entry.memory_id]["vector_similarity"] = entry.metadata.get(
                    "similarity", 0.5
                )
                candidates[entry.memory_id]["sources"].append("vector")

        # Temporal search
        if query.use_temporal and (
            query.start_time or query.end_time or query.recent_seconds
        ):
            temporal_results = await self._indexer.temporal_index.search(
                query=None,
                limit=query.limit * 3,
                start_time=query.start_time,
                end_time=query.end_time,
                recent_seconds=query.recent_seconds,
                memory_type=memory_type,
            )
            for temporal_entry in temporal_results:
                if temporal_entry.memory_id not in candidates:
                    candidates[temporal_entry.memory_id] = {
                        "memory_type": temporal_entry.memory_type,
                        "sources": [],
                    }
                candidates[temporal_entry.memory_id]["timestamp"] = (
                    temporal_entry.timestamp
                )
                candidates[temporal_entry.memory_id]["sources"].append("temporal")

        # Entity search
        if query.use_entity and query.entities:
            entity_results = await self._indexer.entity_index.search(
                query=None,
                limit=query.limit * 3,
                entities=query.entities,
                match_all=query.entity_match_all,
                memory_type=memory_type,
            )
            for entity_entry in entity_results:
                if entity_entry.memory_id not in candidates:
                    candidates[entity_entry.memory_id] = {
                        "memory_type": entity_entry.memory_type,
                        "sources": [],
                    }
                # Count entity matches
                entity_count = candidates[entity_entry.memory_id].get(
                    "entity_match_count", 0
                )
                candidates[entity_entry.memory_id]["entity_match_count"] = (
                    entity_count + 1
                )
                candidates[entity_entry.memory_id]["sources"].append("entity")

        # Outcome search
        if query.use_outcome and query.outcomes:
            outcome_results = await self._indexer.outcome_index.search(
                query=None,
                limit=query.limit * 3,
                outcomes=query.outcomes,
                memory_type=memory_type,
            )
            for outcome_entry in outcome_results:
                if outcome_entry.memory_id not in candidates:
                    candidates[outcome_entry.memory_id] = {
                        "memory_type": outcome_entry.memory_type,
                        "sources": [],
                    }
                candidates[outcome_entry.memory_id]["outcome"] = outcome_entry.outcome
                candidates[outcome_entry.memory_id]["sources"].append("outcome")
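
        # Fusion note: a candidate found by several indexes accumulates more
        # signal fields (similarity, timestamp, entity hits, outcome), so it
        # naturally outscores single-source hits below. Signals an index did
        # not supply fall back to the scorer's neutral defaults; importance
        # is never surfaced by these indexes, so it stays at its 0.5 default.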

        # Score and rank candidates
        scored_results: list[ScoredResult] = []
        # 0 signals "no entity filter" so the scorer grants full entity
        # credit instead of zeroing the entity component for every result
        entity_total = len(query.entities) if query.entities else 0

        for memory_id, data in candidates.items():
            scored = self._scorer.score(
                memory_id=memory_id,
                memory_type=data["memory_type"],
                vector_similarity=data.get("vector_similarity"),
                timestamp=data.get("timestamp"),
                entity_match_count=data.get("entity_match_count", 0),
                entity_total=entity_total,
                outcome=data.get("outcome"),
                preferred_outcomes=query.outcomes,
            )
            scored.metadata["sources"] = data.get("sources", [])

            # Filter by minimum relevance
            if scored.relevance_score >= query.min_relevance:
                scored_results.append(scored)

        # Sort by relevance score, best first
        scored_results.sort(key=lambda x: x.relevance_score, reverse=True)

        # Apply the limit
        final_results = scored_results[: query.limit]

        # Cache the results
        if use_cache and self._cache and final_results:
            self._cache.put(cache_key, final_results)

        latency = (_utcnow() - start_time).total_seconds() * 1000

        logger.info(
            f"Retrieved {len(final_results)} results from {len(candidates)} candidates "
            f"in {latency:.2f}ms"
        )

        return RetrievalResult(
            items=final_results,
            total_count=len(candidates),
            query=query.query_text or "",
            retrieval_type="hybrid",
            latency_ms=latency,
            metadata={
                "cache_hit": False,
                "candidates_count": len(candidates),
                "filtered_count": len(scored_results),
            },
        )

    async def retrieve_similar(
        self,
        embedding: list[float],
        limit: int = 10,
        min_similarity: float = 0.5,
        memory_types: list[MemoryType] | None = None,
    ) -> RetrievalResult[ScoredResult]:
        """
        Retrieve memories similar to a given embedding.

        Args:
            embedding: Query embedding
            limit: Maximum results
            min_similarity: Minimum similarity threshold
            memory_types: Filter by memory types

        Returns:
            Retrieval result with scored items
        """
        # min_similarity also becomes the floor for the combined relevance
        # score, since retrieve() reuses query.min_relevance for both
        query = RetrievalQuery(
            query_embedding=embedding,
            limit=limit,
            min_relevance=min_similarity,
            memory_types=memory_types,
            use_temporal=False,
            use_entity=False,
            use_outcome=False,
        )
        return await self.retrieve(query)

    async def retrieve_recent(
        self,
        hours: float = 24,
        limit: int = 10,
        memory_types: list[MemoryType] | None = None,
    ) -> RetrievalResult[ScoredResult]:
        """
        Retrieve recent memories.

        Args:
            hours: Number of hours to look back
            limit: Maximum results
            memory_types: Filter by memory types

        Returns:
            Retrieval result with scored items
        """
        query = RetrievalQuery(
            recent_seconds=hours * 3600,
            limit=limit,
            memory_types=memory_types,
            use_vector=False,
            use_entity=False,
            use_outcome=False,
        )
        return await self.retrieve(query)

    async def retrieve_by_entity(
        self,
        entity_type: str,
        entity_value: str,
        limit: int = 10,
        memory_types: list[MemoryType] | None = None,
    ) -> RetrievalResult[ScoredResult]:
        """
        Retrieve memories by entity.

        Args:
            entity_type: Type of entity
            entity_value: Entity value
            limit: Maximum results
            memory_types: Filter by memory types

        Returns:
            Retrieval result with scored items
        """
        query = RetrievalQuery(
            entities=[(entity_type, entity_value)],
            limit=limit,
            memory_types=memory_types,
            use_vector=False,
            use_temporal=False,
            use_outcome=False,
        )
        return await self.retrieve(query)

    async def retrieve_successful(
        self,
        limit: int = 10,
        memory_types: list[MemoryType] | None = None,
    ) -> RetrievalResult[ScoredResult]:
        """
        Retrieve memories with successful outcomes.

        Args:
            limit: Maximum results
            memory_types: Filter by memory types

        Returns:
            Retrieval result with scored items
        """
        query = RetrievalQuery(
            outcomes=[Outcome.SUCCESS],
            limit=limit,
            memory_types=memory_types,
            use_vector=False,
            use_temporal=False,
            use_entity=False,
        )
        return await self.retrieve(query)

    def invalidate_cache(self) -> int:
        """
        Invalidate all cached results.

        Returns:
            Number of entries invalidated
        """
        if self._cache:
            return self._cache.clear()
        return 0
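
    # Usage sketch (illustrative; entity names are assumptions). Each
    # convenience wrapper above builds a single-signal RetrievalQuery and
    # funnels it through retrieve(), so caching and scoring behave identically:
    #
    #   >>> engine = get_retrieval_engine()
    #   >>> await engine.retrieve_recent(hours=6, limit=5)    # doctest: +SKIP
    #   >>> await engine.retrieve_by_entity("user", "alice")  # doctest: +SKIP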

    def invalidate_cache_for_memory(self, memory_id: UUID) -> int:
        """
        Invalidate cache entries containing a specific memory.

        Args:
            memory_id: Memory ID to invalidate

        Returns:
            Number of entries invalidated
        """
        if self._cache:
            return self._cache.invalidate_by_memory(memory_id)
        return 0

    def get_cache_stats(self) -> dict[str, Any]:
        """Get cache statistics."""
        if self._cache:
            return self._cache.get_stats()
        return {"enabled": False}


# Singleton retrieval engine instance
_engine: RetrievalEngine | None = None


def get_retrieval_engine() -> RetrievalEngine:
    """Get the singleton retrieval engine instance."""
    global _engine
    if _engine is None:
        _engine = RetrievalEngine()
    return _engine
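

# Minimal end-to-end sketch (illustrative): wiring the singleton engine to a
# recency-only query. Assumes the indexer singleton has already been populated
# elsewhere; that is an assumption about the surrounding app, not this module.
if __name__ == "__main__":  # pragma: no cover
    import asyncio

    async def _demo() -> None:
        engine = get_retrieval_engine()
        result = await engine.retrieve_recent(hours=24, limit=5)
        for item in result.items:
            print(item.memory_id, f"{item.relevance_score:.3f}", item.score_breakdown)

    asyncio.run(_demo())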