# app/services/memory/indexing/retrieval.py
"""
Memory Retrieval Engine.

Provides hybrid retrieval capabilities combining:
- Vector similarity search
- Temporal filtering
- Entity filtering
- Outcome filtering
- Relevance scoring
- Result caching
"""

import hashlib
import logging
from dataclasses import dataclass, field
from datetime import UTC, datetime
from typing import Any, TypeVar
from uuid import UUID

from app.services.memory.types import (
    Episode,
    Fact,
    MemoryType,
    Outcome,
    Procedure,
    RetrievalResult,
)
from .index import (
    MemoryIndexer,
    get_memory_indexer,
)

logger = logging.getLogger(__name__)

T = TypeVar("T", Episode, Fact, Procedure)


def _utcnow() -> datetime:
    """Get current UTC time as timezone-aware datetime."""
    return datetime.now(UTC)


@dataclass
class RetrievalQuery:
    """Query parameters for memory retrieval."""

    # Text/semantic query
    query_text: str | None = None
    query_embedding: list[float] | None = None

    # Temporal filters
    start_time: datetime | None = None
    end_time: datetime | None = None
    recent_seconds: float | None = None

    # Entity filters
    entities: list[tuple[str, str]] | None = None
    entity_match_all: bool = False

    # Outcome filters
    outcomes: list[Outcome] | None = None

    # Memory type filter
    memory_types: list[MemoryType] | None = None

    # Result options
    limit: int = 10
    min_relevance: float = 0.0

    # Retrieval mode
    use_vector: bool = True
    use_temporal: bool = True
    use_entity: bool = True
    use_outcome: bool = True

    def to_cache_key(self) -> str:
        """Generate a cache key for this query."""
        key_parts = [
            self.query_text or "",
            str(self.start_time),
            str(self.end_time),
            str(self.recent_seconds),
            str(self.entities),
            str(self.outcomes),
            str(self.memory_types),
            str(self.limit),
            str(self.min_relevance),
        ]
        key_string = "|".join(key_parts)
        return hashlib.sha256(key_string.encode()).hexdigest()[:32]


@dataclass
class ScoredResult:
    """A retrieval result with relevance score."""

    memory_id: UUID
    memory_type: MemoryType
    relevance_score: float
    score_breakdown: dict[str, float] = field(default_factory=dict)
    metadata: dict[str, Any] = field(default_factory=dict)


@dataclass
class CacheEntry:
    """A cached retrieval result."""

    results: list[ScoredResult]
    created_at: datetime
    ttl_seconds: float
    query_key: str

    def is_expired(self) -> bool:
        """Check if this cache entry has expired."""
        age = (_utcnow() - self.created_at).total_seconds()
        return age > self.ttl_seconds
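
# Illustrative sketch (not part of the engine's API): build a query and derive
# its cache key. The field values here are hypothetical; identical filter
# settings always map to the same key, which is what makes result caching work.
def _example_query_cache_key() -> str:
    query = RetrievalQuery(
        query_text="recent deployment issues",
        recent_seconds=3600.0,
        outcomes=[Outcome.SUCCESS],
        limit=5,
    )
    # A 32-character SHA-256 prefix over the query's filter fields; repeating
    # the identical query later will hit the RetrievalEngine cache.
    return query.to_cache_key()
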

class RelevanceScorer:
    """
    Calculates relevance scores for retrieved memories.

    Combines multiple signals:
    - Vector similarity (if available)
    - Temporal recency
    - Entity match count
    - Outcome preference
    - Importance/confidence
    """

    def __init__(
        self,
        vector_weight: float = 0.4,
        recency_weight: float = 0.2,
        entity_weight: float = 0.2,
        outcome_weight: float = 0.1,
        importance_weight: float = 0.1,
    ) -> None:
        """
        Initialize the relevance scorer.

        Args:
            vector_weight: Weight for vector similarity (0-1)
            recency_weight: Weight for temporal recency (0-1)
            entity_weight: Weight for entity matches (0-1)
            outcome_weight: Weight for outcome preference (0-1)
            importance_weight: Weight for importance score (0-1)
        """
        total = (
            vector_weight
            + recency_weight
            + entity_weight
            + outcome_weight
            + importance_weight
        )
        # Normalize weights
        self.vector_weight = vector_weight / total
        self.recency_weight = recency_weight / total
        self.entity_weight = entity_weight / total
        self.outcome_weight = outcome_weight / total
        self.importance_weight = importance_weight / total

    def score(
        self,
        memory_id: UUID,
        memory_type: MemoryType,
        vector_similarity: float | None = None,
        timestamp: datetime | None = None,
        entity_match_count: int = 0,
        entity_total: int = 1,
        outcome: Outcome | None = None,
        importance: float = 0.5,
        preferred_outcomes: list[Outcome] | None = None,
    ) -> ScoredResult:
        """
        Calculate a relevance score for a memory.

        Args:
            memory_id: ID of the memory
            memory_type: Type of memory
            vector_similarity: Similarity score from vector search (0-1)
            timestamp: Timestamp of the memory
            entity_match_count: Number of matching entities
            entity_total: Total entities in query
            outcome: Outcome of the memory
            importance: Importance score of the memory (0-1)
            preferred_outcomes: Outcomes to prefer

        Returns:
            Scored result with breakdown
        """
        breakdown: dict[str, float] = {}

        # Vector similarity score
        if vector_similarity is not None:
            breakdown["vector"] = vector_similarity
        else:
            breakdown["vector"] = 0.5  # Neutral if no vector

        # Recency score (exponential decay)
        if timestamp:
            age_hours = (_utcnow() - timestamp).total_seconds() / 3600
            # Decay with half-life of 24 hours
            breakdown["recency"] = 2 ** (-age_hours / 24)
        else:
            breakdown["recency"] = 0.5

        # Entity match score
        if entity_total > 0:
            breakdown["entity"] = entity_match_count / entity_total
        else:
            breakdown["entity"] = 1.0  # No entity filter = full score

        # Outcome score
        if preferred_outcomes and outcome:
            breakdown["outcome"] = 1.0 if outcome in preferred_outcomes else 0.0
        else:
            breakdown["outcome"] = 0.5  # Neutral if no preference

        # Importance score
        breakdown["importance"] = importance

        # Calculate weighted sum
        total_score = (
            breakdown["vector"] * self.vector_weight
            + breakdown["recency"] * self.recency_weight
            + breakdown["entity"] * self.entity_weight
            + breakdown["outcome"] * self.outcome_weight
            + breakdown["importance"] * self.importance_weight
        )

        return ScoredResult(
            memory_id=memory_id,
            memory_type=memory_type,
            relevance_score=total_score,
            score_breakdown=breakdown,
        )
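
# Illustrative sketch of the scoring math with the default weights. The memory
# ID and timestamp are made up, and MemoryType.EPISODIC is a hypothetical enum
# member; substitute whichever member your MemoryType actually defines.
def _example_relevance_scoring() -> ScoredResult:
    from datetime import timedelta
    from uuid import uuid4

    scorer = RelevanceScorer()
    result = scorer.score(
        memory_id=uuid4(),
        memory_type=MemoryType.EPISODIC,  # hypothetical member
        vector_similarity=0.8,
        timestamp=_utcnow() - timedelta(hours=24),  # exactly one half-life old
        entity_match_count=1,
        entity_total=2,
        importance=0.7,
    )
    # Expected breakdown: vector=0.8, recency≈0.5 (2**(-24/24)), entity=0.5,
    # outcome=0.5 (no preference given), importance=0.7. The final score is
    # the weighted sum: 0.8*0.4 + 0.5*0.2 + 0.5*0.2 + 0.5*0.1 + 0.7*0.1 ≈ 0.64.
    return result
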

class RetrievalCache:
    """
    In-memory cache for retrieval results.

    Supports TTL-based expiration and LRU eviction.
    """

    def __init__(
        self,
        max_entries: int = 1000,
        default_ttl_seconds: float = 300,
    ) -> None:
        """
        Initialize the cache.

        Args:
            max_entries: Maximum cache entries
            default_ttl_seconds: Default TTL for entries
        """
        self._cache: dict[str, CacheEntry] = {}
        self._max_entries = max_entries
        self._default_ttl = default_ttl_seconds
        self._access_order: list[str] = []
        logger.info(
            f"Initialized RetrievalCache with max_entries={max_entries}, "
            f"ttl={default_ttl_seconds}s"
        )

    def get(self, query_key: str) -> list[ScoredResult] | None:
        """
        Get cached results for a query.

        Args:
            query_key: Cache key for the query

        Returns:
            Cached results or None if not found/expired
        """
        if query_key not in self._cache:
            return None

        entry = self._cache[query_key]
        if entry.is_expired():
            del self._cache[query_key]
            if query_key in self._access_order:
                self._access_order.remove(query_key)
            return None

        # Update access order (LRU)
        if query_key in self._access_order:
            self._access_order.remove(query_key)
        self._access_order.append(query_key)

        logger.debug(f"Cache hit for {query_key}")
        return entry.results

    def put(
        self,
        query_key: str,
        results: list[ScoredResult],
        ttl_seconds: float | None = None,
    ) -> None:
        """
        Cache results for a query.

        Args:
            query_key: Cache key for the query
            results: Results to cache
            ttl_seconds: TTL for this entry (or default)
        """
        # Evict if at capacity
        while len(self._cache) >= self._max_entries and self._access_order:
            oldest_key = self._access_order.pop(0)
            if oldest_key in self._cache:
                del self._cache[oldest_key]

        entry = CacheEntry(
            results=results,
            created_at=_utcnow(),
            ttl_seconds=ttl_seconds or self._default_ttl,
            query_key=query_key,
        )
        self._cache[query_key] = entry
        self._access_order.append(query_key)

        logger.debug(f"Cached {len(results)} results for {query_key}")

    def invalidate(self, query_key: str) -> bool:
        """
        Invalidate a specific cache entry.

        Args:
            query_key: Cache key to invalidate

        Returns:
            True if entry was found and removed
        """
        if query_key in self._cache:
            del self._cache[query_key]
            if query_key in self._access_order:
                self._access_order.remove(query_key)
            return True
        return False

    def invalidate_by_memory(self, memory_id: UUID) -> int:
        """
        Invalidate all cache entries containing a specific memory.

        Args:
            memory_id: Memory ID to invalidate

        Returns:
            Number of entries invalidated
        """
        keys_to_remove = []
        for key, entry in self._cache.items():
            if any(r.memory_id == memory_id for r in entry.results):
                keys_to_remove.append(key)

        for key in keys_to_remove:
            self.invalidate(key)

        if keys_to_remove:
            logger.debug(
                f"Invalidated {len(keys_to_remove)} cache entries for {memory_id}"
            )
        return len(keys_to_remove)

    def clear(self) -> int:
        """
        Clear all cache entries.

        Returns:
            Number of entries cleared
        """
        count = len(self._cache)
        self._cache.clear()
        self._access_order.clear()
        logger.info(f"Cleared {count} cache entries")
        return count

    def get_stats(self) -> dict[str, Any]:
        """Get cache statistics."""
        expired_count = sum(1 for e in self._cache.values() if e.is_expired())
        return {
            "total_entries": len(self._cache),
            "expired_entries": expired_count,
            "max_entries": self._max_entries,
            "default_ttl_seconds": self._default_ttl,
        }
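
# Illustrative sketch of a cache round trip: store, hit, then invalidate when
# the underlying memory changes. The key is hypothetical, and
# MemoryType.EPISODIC is again a stand-in for a real enum member.
def _example_cache_roundtrip() -> None:
    from uuid import uuid4

    cache = RetrievalCache(max_entries=10, default_ttl_seconds=60)
    memory_id = uuid4()
    results = [
        ScoredResult(
            memory_id=memory_id,
            memory_type=MemoryType.EPISODIC,  # hypothetical member
            relevance_score=0.9,
        )
    ]
    cache.put("example-key", results)
    assert cache.get("example-key") == results  # hit until the 60s TTL lapses
    cache.invalidate_by_memory(memory_id)  # drops every entry containing it
    assert cache.get("example-key") is None
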

class RetrievalEngine:
    """
    Hybrid retrieval engine for memory search.

    Combines multiple index types for comprehensive retrieval:
    - Vector search for semantic similarity
    - Temporal index for time-based filtering
    - Entity index for entity-based lookups
    - Outcome index for success/failure filtering

    Results are scored and ranked using relevance scoring.
    """

    def __init__(
        self,
        indexer: MemoryIndexer | None = None,
        scorer: RelevanceScorer | None = None,
        cache: RetrievalCache | None = None,
        enable_cache: bool = True,
    ) -> None:
        """
        Initialize the retrieval engine.

        Args:
            indexer: Memory indexer (defaults to singleton)
            scorer: Relevance scorer (defaults to new instance)
            cache: Retrieval cache (defaults to new instance)
            enable_cache: Whether to enable result caching
        """
        self._indexer = indexer or get_memory_indexer()
        self._scorer = scorer or RelevanceScorer()
        self._cache = (cache or RetrievalCache()) if enable_cache else None
        self._enable_cache = enable_cache
        logger.info(f"Initialized RetrievalEngine with cache={enable_cache}")

    async def retrieve(
        self,
        query: RetrievalQuery,
        use_cache: bool = True,
    ) -> RetrievalResult[ScoredResult]:
        """
        Retrieve relevant memories using hybrid search.

        Args:
            query: Retrieval query parameters
            use_cache: Whether to use cached results

        Returns:
            Retrieval result with scored items
        """
        start_time = _utcnow()

        # Check cache
        cache_key = query.to_cache_key()
        if use_cache and self._cache:
            cached = self._cache.get(cache_key)
            if cached:
                latency = (_utcnow() - start_time).total_seconds() * 1000
                return RetrievalResult(
                    items=cached,
                    total_count=len(cached),
                    query=query.query_text or "",
                    retrieval_type="cached",
                    latency_ms=latency,
                    metadata={"cache_hit": True},
                )

        # Collect candidates from each index
        candidates: dict[UUID, dict[str, Any]] = {}

        # Vector search
        if query.use_vector and query.query_embedding:
            vector_results = await self._indexer.vector_index.search(
                query=query.query_embedding,
                limit=query.limit * 3,  # Get more for filtering
                min_similarity=query.min_relevance,
                memory_type=query.memory_types[0] if query.memory_types else None,
            )
            for entry in vector_results:
                if entry.memory_id not in candidates:
                    candidates[entry.memory_id] = {
                        "memory_type": entry.memory_type,
                        "sources": [],
                    }
                candidates[entry.memory_id]["vector_similarity"] = entry.metadata.get(
                    "similarity", 0.5
                )
                candidates[entry.memory_id]["sources"].append("vector")

        # Temporal search
        if query.use_temporal and (
            query.start_time or query.end_time or query.recent_seconds
        ):
            temporal_results = await self._indexer.temporal_index.search(
                query=None,
                limit=query.limit * 3,
                start_time=query.start_time,
                end_time=query.end_time,
                recent_seconds=query.recent_seconds,
                memory_type=query.memory_types[0] if query.memory_types else None,
            )
            for temporal_entry in temporal_results:
                if temporal_entry.memory_id not in candidates:
                    candidates[temporal_entry.memory_id] = {
                        "memory_type": temporal_entry.memory_type,
                        "sources": [],
                    }
                candidates[temporal_entry.memory_id]["timestamp"] = (
                    temporal_entry.timestamp
                )
                candidates[temporal_entry.memory_id]["sources"].append("temporal")

        # Entity search
        if query.use_entity and query.entities:
            entity_results = await self._indexer.entity_index.search(
                query=None,
                limit=query.limit * 3,
                entities=query.entities,
                match_all=query.entity_match_all,
                memory_type=query.memory_types[0] if query.memory_types else None,
            )
            for entity_entry in entity_results:
                if entity_entry.memory_id not in candidates:
                    candidates[entity_entry.memory_id] = {
                        "memory_type": entity_entry.memory_type,
                        "sources": [],
                    }
                # Count entity matches
                entity_count = candidates[entity_entry.memory_id].get(
                    "entity_match_count", 0
                )
                candidates[entity_entry.memory_id]["entity_match_count"] = (
                    entity_count + 1
                )
                candidates[entity_entry.memory_id]["sources"].append("entity")

        # Outcome search
        if query.use_outcome and query.outcomes:
            outcome_results = await self._indexer.outcome_index.search(
                query=None,
                limit=query.limit * 3,
                outcomes=query.outcomes,
                memory_type=query.memory_types[0] if query.memory_types else None,
            )
            for outcome_entry in outcome_results:
                if outcome_entry.memory_id not in candidates:
                    candidates[outcome_entry.memory_id] = {
                        "memory_type": outcome_entry.memory_type,
                        "sources": [],
                    }
                candidates[outcome_entry.memory_id]["outcome"] = outcome_entry.outcome
                candidates[outcome_entry.memory_id]["sources"].append("outcome")

        # Score and rank candidates
        scored_results: list[ScoredResult] = []
        entity_total = len(query.entities) if query.entities else 1

        for memory_id, data in candidates.items():
            scored = self._scorer.score(
                memory_id=memory_id,
                memory_type=data["memory_type"],
                vector_similarity=data.get("vector_similarity"),
                timestamp=data.get("timestamp"),
                entity_match_count=data.get("entity_match_count", 0),
                entity_total=entity_total,
                outcome=data.get("outcome"),
                preferred_outcomes=query.outcomes,
            )
            scored.metadata["sources"] = data.get("sources", [])

            # Filter by minimum relevance
            if scored.relevance_score >= query.min_relevance:
                scored_results.append(scored)

        # Sort by relevance score
        scored_results.sort(key=lambda x: x.relevance_score, reverse=True)

        # Apply limit
        final_results = scored_results[: query.limit]

        # Cache results
        if use_cache and self._cache and final_results:
            self._cache.put(cache_key, final_results)

        latency = (_utcnow() - start_time).total_seconds() * 1000
        logger.info(
            f"Retrieved {len(final_results)} results from {len(candidates)} candidates "
            f"in {latency:.2f}ms"
        )

        return RetrievalResult(
            items=final_results,
            total_count=len(candidates),
            query=query.query_text or "",
            retrieval_type="hybrid",
            latency_ms=latency,
            metadata={
                "cache_hit": False,
                "candidates_count": len(candidates),
                "filtered_count": len(scored_results),
            },
        )
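
    # Illustrative usage sketch (hypothetical values; call from async code):
    #
    #     engine = RetrievalEngine()
    #     result = await engine.retrieve(
    #         RetrievalQuery(
    #             query_embedding=embedding,  # vector from your embedding model
    #             entities=[("service", "payments")],
    #             outcomes=[Outcome.SUCCESS],
    #             limit=5,
    #         )
    #     )
    #     for item in result.items:
    #         print(item.relevance_score, item.score_breakdown, item.metadata["sources"])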

    async def retrieve_similar(
        self,
        embedding: list[float],
        limit: int = 10,
        min_similarity: float = 0.5,
        memory_types: list[MemoryType] | None = None,
    ) -> RetrievalResult[ScoredResult]:
        """
        Retrieve memories similar to a given embedding.

        Args:
            embedding: Query embedding
            limit: Maximum results
            min_similarity: Minimum similarity threshold
            memory_types: Filter by memory types

        Returns:
            Retrieval result with scored items
        """
        query = RetrievalQuery(
            query_embedding=embedding,
            limit=limit,
            min_relevance=min_similarity,
            memory_types=memory_types,
            use_temporal=False,
            use_entity=False,
            use_outcome=False,
        )
        return await self.retrieve(query)

    async def retrieve_recent(
        self,
        hours: float = 24,
        limit: int = 10,
        memory_types: list[MemoryType] | None = None,
    ) -> RetrievalResult[ScoredResult]:
        """
        Retrieve recent memories.

        Args:
            hours: Number of hours to look back
            limit: Maximum results
            memory_types: Filter by memory types

        Returns:
            Retrieval result with scored items
        """
        query = RetrievalQuery(
            recent_seconds=hours * 3600,
            limit=limit,
            memory_types=memory_types,
            use_vector=False,
            use_entity=False,
            use_outcome=False,
        )
        return await self.retrieve(query)

    async def retrieve_by_entity(
        self,
        entity_type: str,
        entity_value: str,
        limit: int = 10,
        memory_types: list[MemoryType] | None = None,
    ) -> RetrievalResult[ScoredResult]:
        """
        Retrieve memories by entity.

        Args:
            entity_type: Type of entity
            entity_value: Entity value
            limit: Maximum results
            memory_types: Filter by memory types

        Returns:
            Retrieval result with scored items
        """
        query = RetrievalQuery(
            entities=[(entity_type, entity_value)],
            limit=limit,
            memory_types=memory_types,
            use_vector=False,
            use_temporal=False,
            use_outcome=False,
        )
        return await self.retrieve(query)

    async def retrieve_successful(
        self,
        limit: int = 10,
        memory_types: list[MemoryType] | None = None,
    ) -> RetrievalResult[ScoredResult]:
        """
        Retrieve successful memories.

        Args:
            limit: Maximum results
            memory_types: Filter by memory types

        Returns:
            Retrieval result with scored items
        """
        query = RetrievalQuery(
            outcomes=[Outcome.SUCCESS],
            limit=limit,
            memory_types=memory_types,
            use_vector=False,
            use_temporal=False,
            use_entity=False,
        )
        return await self.retrieve(query)
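
    # Each convenience wrapper above enables a single index and routes through
    # retrieve(). Sketch (hypothetical values, from async code):
    #
    #     recent = await engine.retrieve_recent(hours=6, limit=20)
    #     by_host = await engine.retrieve_by_entity("host", "web-01")
    #     wins = await engine.retrieve_successful(limit=10)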

    def invalidate_cache(self) -> int:
        """
        Invalidate all cached results.

        Returns:
            Number of entries invalidated
        """
        if self._cache:
            return self._cache.clear()
        return 0

    def invalidate_cache_for_memory(self, memory_id: UUID) -> int:
        """
        Invalidate cache entries containing a specific memory.

        Args:
            memory_id: Memory ID to invalidate

        Returns:
            Number of entries invalidated
        """
        if self._cache:
            return self._cache.invalidate_by_memory(memory_id)
        return 0

    def get_cache_stats(self) -> dict[str, Any]:
        """Get cache statistics."""
        if self._cache:
            return self._cache.get_stats()
        return {"enabled": False}


# Singleton retrieval engine instance
_engine: RetrievalEngine | None = None


def get_retrieval_engine() -> RetrievalEngine:
    """Get the singleton retrieval engine instance."""
    global _engine
    if _engine is None:
        _engine = RetrievalEngine()
    return _engine
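
# Illustrative sketch of a typical application flow with the singleton engine.
# The lookback window and limit are hypothetical.
async def _example_singleton_usage() -> None:
    engine = get_retrieval_engine()
    result = await engine.retrieve_recent(hours=24, limit=10)
    for item in result.items:
        logger.debug("hit %s score=%.3f", item.memory_id, item.relevance_score)
    # After a memory is updated or deleted elsewhere, evict stale cached
    # results that reference it:
    # engine.invalidate_cache_for_memory(memory_id)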