Files
syndarix/backend/app/services/memory/indexing/retrieval.py
Felipe Cardoso 999b7ac03f feat(memory): implement memory indexing and retrieval engine (#94)
Add comprehensive indexing and retrieval system for memory search:
- VectorIndex for semantic similarity search using cosine similarity
- TemporalIndex for time-based queries with range and recency support
- EntityIndex for entity-based lookups with multi-entity intersection
- OutcomeIndex for success/failure filtering on episodes
- MemoryIndexer as unified interface for all index types
- RetrievalEngine with hybrid search combining all indices
- RelevanceScorer for multi-signal relevance scoring
- RetrievalCache for LRU caching of search results

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-05 02:50:13 +01:00

751 lines
23 KiB
Python

# app/services/memory/indexing/retrieval.py
"""
Memory Retrieval Engine.
Provides hybrid retrieval capabilities combining:
- Vector similarity search
- Temporal filtering
- Entity filtering
- Outcome filtering
- Relevance scoring
- Result caching
"""
import hashlib
import logging
from dataclasses import dataclass, field
from datetime import UTC, datetime
from typing import Any, TypeVar
from uuid import UUID
from app.services.memory.types import (
Episode,
Fact,
MemoryType,
Outcome,
Procedure,
RetrievalResult,
)
from .index import (
MemoryIndexer,
get_memory_indexer,
)
# Module-level logger, following the package-wide `__name__` convention.
logger = logging.getLogger(__name__)
# Type variable constrained to the three concrete memory payload types.
# NOTE(review): T is not referenced anywhere in this module -- presumably
# reserved for future generic helpers; confirm or remove.
T = TypeVar("T", Episode, Fact, Procedure)
def _utcnow() -> datetime:
"""Get current UTC time as timezone-aware datetime."""
return datetime.now(UTC)
@dataclass
class RetrievalQuery:
    """Query parameters for memory retrieval."""

    # Text/semantic query
    query_text: str | None = None
    query_embedding: list[float] | None = None
    # Temporal filters
    start_time: datetime | None = None
    end_time: datetime | None = None
    recent_seconds: float | None = None
    # Entity filters
    entities: list[tuple[str, str]] | None = None
    entity_match_all: bool = False
    # Outcome filters
    outcomes: list[Outcome] | None = None
    # Memory type filter
    memory_types: list[MemoryType] | None = None
    # Result options
    limit: int = 10
    min_relevance: float = 0.0
    # Retrieval mode
    use_vector: bool = True
    use_temporal: bool = True
    use_entity: bool = True
    use_outcome: bool = True

    def to_cache_key(self) -> str:
        """Generate a deterministic cache key for this query.

        Every field that can change the result set participates in the key.
        The embedding is folded in via its own digest so two queries that
        differ only in their vectors do not collide (previously the embedding,
        `entity_match_all`, and the `use_*` flags were omitted from the key,
        which allowed distinct queries to share cached results).

        Returns:
            32-hex-character SHA-256 prefix identifying this query.
        """
        if self.query_embedding is not None:
            # Fixed-precision formatting keeps the digest stable across
            # float reprs while still distinguishing real vector changes.
            embedding_digest = hashlib.sha256(
                ",".join(f"{v:.8f}" for v in self.query_embedding).encode()
            ).hexdigest()
        else:
            embedding_digest = ""
        key_parts = [
            self.query_text or "",
            embedding_digest,
            str(self.start_time),
            str(self.end_time),
            str(self.recent_seconds),
            str(self.entities),
            str(self.entity_match_all),
            str(self.outcomes),
            str(self.memory_types),
            str(self.limit),
            str(self.min_relevance),
            str((self.use_vector, self.use_temporal, self.use_entity, self.use_outcome)),
        ]
        key_string = "|".join(key_parts)
        return hashlib.sha256(key_string.encode()).hexdigest()[:32]
@dataclass
class ScoredResult:
    """A retrieval result with relevance score."""

    # ID of the underlying memory record.
    memory_id: UUID
    # Which memory category the record belongs to.
    memory_type: MemoryType
    # Final weighted relevance in [0, 1], as produced by RelevanceScorer.
    relevance_score: float
    # Per-signal components (e.g. "vector", "recency") behind the final score.
    score_breakdown: dict[str, float] = field(default_factory=dict)
    # Free-form extras, e.g. which indexes surfaced this result.
    metadata: dict[str, Any] = field(default_factory=dict)
@dataclass
class CacheEntry:
    """A cached retrieval result."""

    # Scored hits frozen at the moment of caching.
    results: list[ScoredResult]
    # Timestamp when the entry was stored.
    created_at: datetime
    # Lifetime budget, in seconds, for this entry.
    ttl_seconds: float
    # Cache key the entry was stored under.
    query_key: str

    def is_expired(self) -> bool:
        """Check if this cache entry has expired."""
        elapsed = (_utcnow() - self.created_at).total_seconds()
        return elapsed > self.ttl_seconds
class RelevanceScorer:
    """
    Calculates relevance scores for retrieved memories.

    Combines multiple signals:
    - Vector similarity (if available)
    - Temporal recency
    - Entity match count
    - Outcome preference
    - Importance/confidence
    """

    def __init__(
        self,
        vector_weight: float = 0.4,
        recency_weight: float = 0.2,
        entity_weight: float = 0.2,
        outcome_weight: float = 0.1,
        importance_weight: float = 0.1,
    ) -> None:
        """
        Initialize the relevance scorer.

        Weights are normalized to sum to 1, so callers may pass any
        non-negative values in whatever relative proportion they want.

        Args:
            vector_weight: Weight for vector similarity (0-1)
            recency_weight: Weight for temporal recency (0-1)
            entity_weight: Weight for entity matches (0-1)
            outcome_weight: Weight for outcome preference (0-1)
            importance_weight: Weight for importance score (0-1)

        Raises:
            ValueError: If the weights sum to zero or less, which would make
                normalization impossible.
        """
        total = (
            vector_weight
            + recency_weight
            + entity_weight
            + outcome_weight
            + importance_weight
        )
        if total <= 0:
            # Previously this fell through to an opaque ZeroDivisionError on
            # the divisions below; fail loudly with an actionable message.
            raise ValueError(
                "RelevanceScorer weights must sum to a positive value"
            )
        # Normalize weights so the final score stays within [0, 1].
        self.vector_weight = vector_weight / total
        self.recency_weight = recency_weight / total
        self.entity_weight = entity_weight / total
        self.outcome_weight = outcome_weight / total
        self.importance_weight = importance_weight / total

    def score(
        self,
        memory_id: UUID,
        memory_type: MemoryType,
        vector_similarity: float | None = None,
        timestamp: datetime | None = None,
        entity_match_count: int = 0,
        entity_total: int = 1,
        outcome: Outcome | None = None,
        importance: float = 0.5,
        preferred_outcomes: list[Outcome] | None = None,
    ) -> ScoredResult:
        """
        Calculate a relevance score for a memory.

        Args:
            memory_id: ID of the memory
            memory_type: Type of memory
            vector_similarity: Similarity score from vector search (0-1)
            timestamp: Timestamp of the memory
            entity_match_count: Number of matching entities
            entity_total: Total entities in query
            outcome: Outcome of the memory
            importance: Importance score of the memory (0-1)
            preferred_outcomes: Outcomes to prefer

        Returns:
            Scored result with per-signal breakdown attached.
        """
        breakdown: dict[str, float] = {}
        # Vector similarity score; absent similarity is treated as neutral
        # rather than zero so non-vector hits are not unduly punished.
        if vector_similarity is not None:
            breakdown["vector"] = vector_similarity
        else:
            breakdown["vector"] = 0.5  # Neutral if no vector
        # Recency score: exponential decay with a 24-hour half-life.
        if timestamp:
            age_hours = (_utcnow() - timestamp).total_seconds() / 3600
            breakdown["recency"] = 2 ** (-age_hours / 24)
        else:
            breakdown["recency"] = 0.5
        # Entity match score: fraction of queried entities matched.
        if entity_total > 0:
            breakdown["entity"] = entity_match_count / entity_total
        else:
            breakdown["entity"] = 1.0  # No entity filter = full score
        # Outcome score: binary preference match, neutral when unspecified.
        if preferred_outcomes and outcome:
            breakdown["outcome"] = 1.0 if outcome in preferred_outcomes else 0.0
        else:
            breakdown["outcome"] = 0.5  # Neutral if no preference
        # Importance passes through unmodified.
        breakdown["importance"] = importance
        # Weighted sum of all signals (weights were normalized in __init__).
        total_score = (
            breakdown["vector"] * self.vector_weight
            + breakdown["recency"] * self.recency_weight
            + breakdown["entity"] * self.entity_weight
            + breakdown["outcome"] * self.outcome_weight
            + breakdown["importance"] * self.importance_weight
        )
        return ScoredResult(
            memory_id=memory_id,
            memory_type=memory_type,
            relevance_score=total_score,
            score_breakdown=breakdown,
        )
class RetrievalCache:
    """
    In-memory cache for retrieval results.

    Supports TTL-based expiration and LRU eviction.
    """

    def __init__(
        self,
        max_entries: int = 1000,
        default_ttl_seconds: float = 300,
    ) -> None:
        """
        Initialize the cache.

        Args:
            max_entries: Maximum cache entries
            default_ttl_seconds: Default TTL for entries
        """
        self._cache: dict[str, CacheEntry] = {}
        self._max_entries = max_entries
        self._default_ttl = default_ttl_seconds
        # Keys ordered least- to most-recently used; kept duplicate-free.
        self._access_order: list[str] = []
        logger.info(
            f"Initialized RetrievalCache with max_entries={max_entries}, "
            f"ttl={default_ttl_seconds}s"
        )

    def get(self, query_key: str) -> list[ScoredResult] | None:
        """
        Get cached results for a query.

        Args:
            query_key: Cache key for the query

        Returns:
            Cached results or None if not found/expired
        """
        if query_key not in self._cache:
            return None
        entry = self._cache[query_key]
        if entry.is_expired():
            # Lazily drop expired entries on access.
            del self._cache[query_key]
            if query_key in self._access_order:
                self._access_order.remove(query_key)
            return None
        # Update access order (LRU): move key to the most-recent end.
        if query_key in self._access_order:
            self._access_order.remove(query_key)
        self._access_order.append(query_key)
        logger.debug(f"Cache hit for {query_key}")
        return entry.results

    def put(
        self,
        query_key: str,
        results: list[ScoredResult],
        ttl_seconds: float | None = None,
    ) -> None:
        """
        Cache results for a query.

        Args:
            query_key: Cache key for the query
            results: Results to cache
            ttl_seconds: TTL for this entry (default TTL when None)
        """
        # Bug fix: re-putting an existing key previously appended a duplicate
        # to _access_order, growing it without bound and letting a stale
        # position evict a recently refreshed key prematurely. Remove any
        # existing position before re-appending below.
        if query_key in self._access_order:
            self._access_order.remove(query_key)
        # Evict least-recently-used entries while at capacity.
        while len(self._cache) >= self._max_entries and self._access_order:
            oldest_key = self._access_order.pop(0)
            if oldest_key in self._cache:
                del self._cache[oldest_key]
        entry = CacheEntry(
            results=results,
            created_at=_utcnow(),
            # Bug fix: the previous `ttl_seconds or default` coerced an
            # explicit ttl_seconds=0 ("expire immediately") into the default.
            ttl_seconds=self._default_ttl if ttl_seconds is None else ttl_seconds,
            query_key=query_key,
        )
        self._cache[query_key] = entry
        self._access_order.append(query_key)
        logger.debug(f"Cached {len(results)} results for {query_key}")

    def invalidate(self, query_key: str) -> bool:
        """
        Invalidate a specific cache entry.

        Args:
            query_key: Cache key to invalidate

        Returns:
            True if entry was found and removed
        """
        if query_key in self._cache:
            del self._cache[query_key]
            if query_key in self._access_order:
                self._access_order.remove(query_key)
            return True
        return False

    def invalidate_by_memory(self, memory_id: UUID) -> int:
        """
        Invalidate all cache entries containing a specific memory.

        Args:
            memory_id: Memory ID to invalidate

        Returns:
            Number of entries invalidated
        """
        # Collect first, then delete, to avoid mutating while iterating.
        keys_to_remove = []
        for key, entry in self._cache.items():
            if any(r.memory_id == memory_id for r in entry.results):
                keys_to_remove.append(key)
        for key in keys_to_remove:
            self.invalidate(key)
        if keys_to_remove:
            logger.debug(
                f"Invalidated {len(keys_to_remove)} cache entries for {memory_id}"
            )
        return len(keys_to_remove)

    def clear(self) -> int:
        """
        Clear all cache entries.

        Returns:
            Number of entries cleared
        """
        count = len(self._cache)
        self._cache.clear()
        self._access_order.clear()
        logger.info(f"Cleared {count} cache entries")
        return count

    def get_stats(self) -> dict[str, Any]:
        """Get cache statistics (expired entries are counted, not purged)."""
        expired_count = sum(1 for e in self._cache.values() if e.is_expired())
        return {
            "total_entries": len(self._cache),
            "expired_entries": expired_count,
            "max_entries": self._max_entries,
            "default_ttl_seconds": self._default_ttl,
        }
class RetrievalEngine:
    """
    Hybrid retrieval engine for memory search.

    Combines multiple index types for comprehensive retrieval:
    - Vector search for semantic similarity
    - Temporal index for time-based filtering
    - Entity index for entity-based lookups
    - Outcome index for success/failure filtering

    Results are scored and ranked using relevance scoring.
    """

    def __init__(
        self,
        indexer: MemoryIndexer | None = None,
        scorer: RelevanceScorer | None = None,
        cache: RetrievalCache | None = None,
        enable_cache: bool = True,
    ) -> None:
        """
        Initialize the retrieval engine.

        Args:
            indexer: Memory indexer (defaults to singleton)
            scorer: Relevance scorer (defaults to new instance)
            cache: Retrieval cache (defaults to new instance)
            enable_cache: Whether to enable result caching
        """
        self._indexer = indexer or get_memory_indexer()
        self._scorer = scorer or RelevanceScorer()
        # NOTE(review): precedence here is `(cache or RetrievalCache()) if
        # enable_cache else None`, so a caller-supplied cache is silently
        # discarded when enable_cache=False -- confirm that is intended and
        # add explicit parentheses.
        self._cache = cache or RetrievalCache() if enable_cache else None
        self._enable_cache = enable_cache
        logger.info(f"Initialized RetrievalEngine with cache={enable_cache}")

    async def retrieve(
        self,
        query: RetrievalQuery,
        use_cache: bool = True,
    ) -> RetrievalResult[ScoredResult]:
        """
        Retrieve relevant memories using hybrid search.

        Candidates are gathered from each enabled index, merged by memory
        ID, scored with the RelevanceScorer, filtered by `min_relevance`,
        sorted, truncated to `query.limit`, and (optionally) cached.

        Args:
            query: Retrieval query parameters
            use_cache: Whether to use cached results

        Returns:
            Retrieval result with scored items
        """
        start_time = _utcnow()
        # Check cache. Truthiness is safe here: empty result lists are never
        # cached (see the `and final_results` guard before `put` below).
        cache_key = query.to_cache_key()
        if use_cache and self._cache:
            cached = self._cache.get(cache_key)
            if cached:
                latency = (_utcnow() - start_time).total_seconds() * 1000
                return RetrievalResult(
                    items=cached,
                    total_count=len(cached),
                    query=query.query_text or "",
                    retrieval_type="cached",
                    latency_ms=latency,
                    metadata={"cache_hit": True},
                )
        # Collect candidates from each index, merged by memory_id. Each
        # index contributes its own signal fields plus a "sources" tag.
        candidates: dict[UUID, dict[str, Any]] = {}
        # Vector search.
        # NOTE(review): every index search below filters on
        # `query.memory_types[0]` only -- additional entries in memory_types
        # are neither used at the index nor post-filtered here. Confirm
        # multi-type queries behave as intended.
        if query.use_vector and query.query_embedding:
            vector_results = await self._indexer.vector_index.search(
                query=query.query_embedding,
                limit=query.limit * 3,  # Over-fetch to survive later filtering
                min_similarity=query.min_relevance,
                memory_type=query.memory_types[0] if query.memory_types else None,
            )
            for entry in vector_results:
                if entry.memory_id not in candidates:
                    candidates[entry.memory_id] = {
                        "memory_type": entry.memory_type,
                        "sources": [],
                    }
                # Missing similarity metadata falls back to a neutral 0.5.
                candidates[entry.memory_id]["vector_similarity"] = entry.metadata.get(
                    "similarity", 0.5
                )
                candidates[entry.memory_id]["sources"].append("vector")
        # Temporal search: only runs when at least one time filter is set.
        if query.use_temporal and (
            query.start_time or query.end_time or query.recent_seconds
        ):
            temporal_results = await self._indexer.temporal_index.search(
                query=None,
                limit=query.limit * 3,
                start_time=query.start_time,
                end_time=query.end_time,
                recent_seconds=query.recent_seconds,
                memory_type=query.memory_types[0] if query.memory_types else None,
            )
            for temporal_entry in temporal_results:
                if temporal_entry.memory_id not in candidates:
                    candidates[temporal_entry.memory_id] = {
                        "memory_type": temporal_entry.memory_type,
                        "sources": [],
                    }
                # Timestamp feeds the scorer's recency decay.
                candidates[temporal_entry.memory_id]["timestamp"] = (
                    temporal_entry.timestamp
                )
                candidates[temporal_entry.memory_id]["sources"].append("temporal")
        # Entity search.
        if query.use_entity and query.entities:
            entity_results = await self._indexer.entity_index.search(
                query=None,
                limit=query.limit * 3,
                entities=query.entities,
                match_all=query.entity_match_all,
                memory_type=query.memory_types[0] if query.memory_types else None,
            )
            for entity_entry in entity_results:
                if entity_entry.memory_id not in candidates:
                    candidates[entity_entry.memory_id] = {
                        "memory_type": entity_entry.memory_type,
                        "sources": [],
                    }
                # Count entity matches -- incremented once per returned entry.
                # NOTE(review): this assumes the entity index returns one entry
                # per (memory, matched entity); if it deduplicates per memory,
                # the count saturates at 1 -- verify against the index impl.
                entity_count = candidates[entity_entry.memory_id].get(
                    "entity_match_count", 0
                )
                candidates[entity_entry.memory_id]["entity_match_count"] = (
                    entity_count + 1
                )
                candidates[entity_entry.memory_id]["sources"].append("entity")
        # Outcome search.
        if query.use_outcome and query.outcomes:
            outcome_results = await self._indexer.outcome_index.search(
                query=None,
                limit=query.limit * 3,
                outcomes=query.outcomes,
                memory_type=query.memory_types[0] if query.memory_types else None,
            )
            for outcome_entry in outcome_results:
                if outcome_entry.memory_id not in candidates:
                    candidates[outcome_entry.memory_id] = {
                        "memory_type": outcome_entry.memory_type,
                        "sources": [],
                    }
                candidates[outcome_entry.memory_id]["outcome"] = outcome_entry.outcome
                candidates[outcome_entry.memory_id]["sources"].append("outcome")
        # Score and rank candidates. Signals an index did not supply default
        # to neutral values inside the scorer; `importance` is never passed,
        # so it always contributes its 0.5 default.
        scored_results: list[ScoredResult] = []
        entity_total = len(query.entities) if query.entities else 1
        for memory_id, data in candidates.items():
            scored = self._scorer.score(
                memory_id=memory_id,
                memory_type=data["memory_type"],
                vector_similarity=data.get("vector_similarity"),
                timestamp=data.get("timestamp"),
                entity_match_count=data.get("entity_match_count", 0),
                entity_total=entity_total,
                outcome=data.get("outcome"),
                preferred_outcomes=query.outcomes,
            )
            scored.metadata["sources"] = data.get("sources", [])
            # Filter by minimum relevance (applied to the combined score,
            # in addition to its earlier use as the vector-index cutoff).
            if scored.relevance_score >= query.min_relevance:
                scored_results.append(scored)
        # Sort by relevance score, best first.
        scored_results.sort(key=lambda x: x.relevance_score, reverse=True)
        # Apply limit.
        final_results = scored_results[: query.limit]
        # Cache results -- note empty result sets are deliberately not cached.
        if use_cache and self._cache and final_results:
            self._cache.put(cache_key, final_results)
        latency = (_utcnow() - start_time).total_seconds() * 1000
        logger.info(
            f"Retrieved {len(final_results)} results from {len(candidates)} candidates "
            f"in {latency:.2f}ms"
        )
        return RetrievalResult(
            items=final_results,
            total_count=len(candidates),
            query=query.query_text or "",
            retrieval_type="hybrid",
            latency_ms=latency,
            metadata={
                "cache_hit": False,
                "candidates_count": len(candidates),
                "filtered_count": len(scored_results),
            },
        )

    async def retrieve_similar(
        self,
        embedding: list[float],
        limit: int = 10,
        min_similarity: float = 0.5,
        memory_types: list[MemoryType] | None = None,
    ) -> RetrievalResult[ScoredResult]:
        """
        Retrieve memories similar to a given embedding.

        Vector-only convenience wrapper around `retrieve`. Note that
        `min_similarity` is reused as the final `min_relevance` cutoff.

        Args:
            embedding: Query embedding
            limit: Maximum results
            min_similarity: Minimum similarity threshold
            memory_types: Filter by memory types

        Returns:
            Retrieval result with scored items
        """
        query = RetrievalQuery(
            query_embedding=embedding,
            limit=limit,
            min_relevance=min_similarity,
            memory_types=memory_types,
            use_temporal=False,
            use_entity=False,
            use_outcome=False,
        )
        return await self.retrieve(query)

    async def retrieve_recent(
        self,
        hours: float = 24,
        limit: int = 10,
        memory_types: list[MemoryType] | None = None,
    ) -> RetrievalResult[ScoredResult]:
        """
        Retrieve recent memories (temporal-only wrapper around `retrieve`).

        Args:
            hours: Number of hours to look back
            limit: Maximum results
            memory_types: Filter by memory types

        Returns:
            Retrieval result with scored items
        """
        query = RetrievalQuery(
            recent_seconds=hours * 3600,
            limit=limit,
            memory_types=memory_types,
            use_vector=False,
            use_entity=False,
            use_outcome=False,
        )
        return await self.retrieve(query)

    async def retrieve_by_entity(
        self,
        entity_type: str,
        entity_value: str,
        limit: int = 10,
        memory_types: list[MemoryType] | None = None,
    ) -> RetrievalResult[ScoredResult]:
        """
        Retrieve memories by a single entity (entity-only wrapper).

        Args:
            entity_type: Type of entity
            entity_value: Entity value
            limit: Maximum results
            memory_types: Filter by memory types

        Returns:
            Retrieval result with scored items
        """
        query = RetrievalQuery(
            entities=[(entity_type, entity_value)],
            limit=limit,
            memory_types=memory_types,
            use_vector=False,
            use_temporal=False,
            use_outcome=False,
        )
        return await self.retrieve(query)

    async def retrieve_successful(
        self,
        limit: int = 10,
        memory_types: list[MemoryType] | None = None,
    ) -> RetrievalResult[ScoredResult]:
        """
        Retrieve memories whose outcome is SUCCESS (outcome-only wrapper).

        Args:
            limit: Maximum results
            memory_types: Filter by memory types

        Returns:
            Retrieval result with scored items
        """
        query = RetrievalQuery(
            outcomes=[Outcome.SUCCESS],
            limit=limit,
            memory_types=memory_types,
            use_vector=False,
            use_temporal=False,
            use_entity=False,
        )
        return await self.retrieve(query)

    def invalidate_cache(self) -> int:
        """
        Invalidate all cached results.

        Returns:
            Number of entries invalidated (0 when caching is disabled)
        """
        if self._cache:
            return self._cache.clear()
        return 0

    def invalidate_cache_for_memory(self, memory_id: UUID) -> int:
        """
        Invalidate cache entries containing a specific memory.

        Args:
            memory_id: Memory ID to invalidate

        Returns:
            Number of entries invalidated (0 when caching is disabled)
        """
        if self._cache:
            return self._cache.invalidate_by_memory(memory_id)
        return 0

    def get_cache_stats(self) -> dict[str, Any]:
        """Get cache statistics, or {"enabled": False} if caching is off."""
        if self._cache:
            return self._cache.get_stats()
        return {"enabled": False}
# Process-wide singleton, created lazily by get_retrieval_engine().
_engine: RetrievalEngine | None = None


def get_retrieval_engine() -> RetrievalEngine:
    """Return the singleton RetrievalEngine, constructing it on first call."""
    global _engine
    engine = _engine
    if engine is None:
        engine = RetrievalEngine()
        _engine = engine
    return engine