Files
syndarix/backend/app/services/memory/indexing/retrieval.py
Felipe Cardoso 3edce9cd26 fix(memory): address critical bugs from multi-agent review
Bug Fixes:
- Remove singleton pattern from consolidation/reflection services to
  prevent stale database session bugs (session is now passed per-request)
- Add LRU eviction to MemoryToolService._working dict (max 1000 sessions)
  to prevent unbounded memory growth
- Replace O(n) list.remove() with O(1) OrderedDict.move_to_end() in
  RetrievalCache for better performance under load
- Use deque with maxlen for metrics histograms to prevent unbounded
  memory growth (circular buffer with 10k max samples)
- Use full UUID for checkpoint IDs instead of 8-char prefix to avoid
  collision risk at scale (birthday paradox at ~50k checkpoints)

Test Updates:
- Update checkpoint test to expect 36-char UUID
- Update reflection singleton tests to expect new factory behavior
- Add reset_memory_reflection() no-op for backwards compatibility

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-05 18:55:32 +01:00

743 lines
23 KiB
Python

# app/services/memory/indexing/retrieval.py
"""
Memory Retrieval Engine.
Provides hybrid retrieval capabilities combining:
- Vector similarity search
- Temporal filtering
- Entity filtering
- Outcome filtering
- Relevance scoring
- Result caching
"""
import hashlib
import logging
from collections import OrderedDict
from dataclasses import dataclass, field
from datetime import UTC, datetime
from typing import Any, TypeVar
from uuid import UUID
from app.services.memory.types import (
Episode,
Fact,
MemoryType,
Outcome,
Procedure,
RetrievalResult,
)
from .index import (
MemoryIndexer,
get_memory_indexer,
)
logger = logging.getLogger(__name__)
T = TypeVar("T", Episode, Fact, Procedure)
def _utcnow() -> datetime:
"""Get current UTC time as timezone-aware datetime."""
return datetime.now(UTC)
@dataclass
class RetrievalQuery:
    """
    Query parameters for memory retrieval.

    Every field is optional; a field left at its default simply does not
    constrain the search. The ``use_*`` flags switch individual retrieval
    strategies on or off independently of whether their filter fields are set.
    """

    # Text/semantic query
    query_text: str | None = None
    query_embedding: list[float] | None = None
    # Temporal filters
    start_time: datetime | None = None
    end_time: datetime | None = None
    recent_seconds: float | None = None
    # Entity filters: (entity_type, entity_value) pairs
    entities: list[tuple[str, str]] | None = None
    entity_match_all: bool = False
    # Outcome filters
    outcomes: list[Outcome] | None = None
    # Memory type filter
    memory_types: list[MemoryType] | None = None
    # Result options
    limit: int = 10
    min_relevance: float = 0.0
    # Retrieval mode
    use_vector: bool = True
    use_temporal: bool = True
    use_entity: bool = True
    use_outcome: bool = True

    def to_cache_key(self) -> str:
        """
        Generate a cache key for this query.

        The key covers every field that can change the result set, including
        the embedding, ``entity_match_all`` and the ``use_*`` mode flags, so
        that two queries only share a key when they are interchangeable.

        Returns:
            A 32-character hexadecimal digest.
        """
        # The embedding can be very long, so it is folded into its own digest
        # first. Without it, embedding-only queries (query_text is None) would
        # all collide on the same key.
        if self.query_embedding is None:
            embedding_digest = ""
        else:
            embedding_digest = hashlib.sha256(
                repr(list(self.query_embedding)).encode()
            ).hexdigest()

        key_parts = [
            self.query_text or "",
            embedding_digest,
            str(self.start_time),
            str(self.end_time),
            str(self.recent_seconds),
            str(self.entities),
            str(self.entity_match_all),
            str(self.outcomes),
            str(self.memory_types),
            str(self.limit),
            str(self.min_relevance),
            str(self.use_vector),
            str(self.use_temporal),
            str(self.use_entity),
            str(self.use_outcome),
        ]
        # repr() of the list quotes and escapes each part, so a separator
        # character inside query_text cannot make two different queries
        # serialize to the same string.
        key_string = repr(key_parts)
        return hashlib.sha256(key_string.encode()).hexdigest()[:32]
@dataclass
class ScoredResult:
    """
    One retrieved memory together with its computed relevance.

    ``score_breakdown`` and ``metadata`` default to fresh, empty dicts per
    instance, so callers may mutate them without affecting other results.
    """

    # Identifier of the memory this result refers to
    memory_id: UUID
    # Kind of memory the identifier belongs to
    memory_type: MemoryType
    # Combined relevance score for the memory
    relevance_score: float
    # Per-signal scores that were combined into relevance_score
    score_breakdown: dict[str, float] = field(default_factory=dict)
    # Free-form extra information attached to the result
    metadata: dict[str, Any] = field(default_factory=dict)
@dataclass
class CacheEntry:
"""A cached retrieval result."""
results: list[ScoredResult]
created_at: datetime
ttl_seconds: float
query_key: str
def is_expired(self) -> bool:
"""Check if this cache entry has expired."""
age = (_utcnow() - self.created_at).total_seconds()
return age > self.ttl_seconds
class RelevanceScorer:
"""
Calculates relevance scores for retrieved memories.
Combines multiple signals:
- Vector similarity (if available)
- Temporal recency
- Entity match count
- Outcome preference
- Importance/confidence
"""
def __init__(
self,
vector_weight: float = 0.4,
recency_weight: float = 0.2,
entity_weight: float = 0.2,
outcome_weight: float = 0.1,
importance_weight: float = 0.1,
) -> None:
"""
Initialize the relevance scorer.
Args:
vector_weight: Weight for vector similarity (0-1)
recency_weight: Weight for temporal recency (0-1)
entity_weight: Weight for entity matches (0-1)
outcome_weight: Weight for outcome preference (0-1)
importance_weight: Weight for importance score (0-1)
"""
total = (
vector_weight
+ recency_weight
+ entity_weight
+ outcome_weight
+ importance_weight
)
# Normalize weights
self.vector_weight = vector_weight / total
self.recency_weight = recency_weight / total
self.entity_weight = entity_weight / total
self.outcome_weight = outcome_weight / total
self.importance_weight = importance_weight / total
def score(
self,
memory_id: UUID,
memory_type: MemoryType,
vector_similarity: float | None = None,
timestamp: datetime | None = None,
entity_match_count: int = 0,
entity_total: int = 1,
outcome: Outcome | None = None,
importance: float = 0.5,
preferred_outcomes: list[Outcome] | None = None,
) -> ScoredResult:
"""
Calculate a relevance score for a memory.
Args:
memory_id: ID of the memory
memory_type: Type of memory
vector_similarity: Similarity score from vector search (0-1)
timestamp: Timestamp of the memory
entity_match_count: Number of matching entities
entity_total: Total entities in query
outcome: Outcome of the memory
importance: Importance score of the memory (0-1)
preferred_outcomes: Outcomes to prefer
Returns:
Scored result with breakdown
"""
breakdown: dict[str, float] = {}
# Vector similarity score
if vector_similarity is not None:
breakdown["vector"] = vector_similarity
else:
breakdown["vector"] = 0.5 # Neutral if no vector
# Recency score (exponential decay)
if timestamp:
age_hours = (_utcnow() - timestamp).total_seconds() / 3600
# Decay with half-life of 24 hours
breakdown["recency"] = 2 ** (-age_hours / 24)
else:
breakdown["recency"] = 0.5
# Entity match score
if entity_total > 0:
breakdown["entity"] = entity_match_count / entity_total
else:
breakdown["entity"] = 1.0 # No entity filter = full score
# Outcome score
if preferred_outcomes and outcome:
breakdown["outcome"] = 1.0 if outcome in preferred_outcomes else 0.0
else:
breakdown["outcome"] = 0.5 # Neutral if no preference
# Importance score
breakdown["importance"] = importance
# Calculate weighted sum
total_score = (
breakdown["vector"] * self.vector_weight
+ breakdown["recency"] * self.recency_weight
+ breakdown["entity"] * self.entity_weight
+ breakdown["outcome"] * self.outcome_weight
+ breakdown["importance"] * self.importance_weight
)
return ScoredResult(
memory_id=memory_id,
memory_type=memory_type,
relevance_score=total_score,
score_breakdown=breakdown,
)
class RetrievalCache:
    """
    In-memory cache for retrieval results.

    Supports TTL-based expiration and LRU eviction with O(1) operations.
    Uses OrderedDict for efficient LRU tracking: the first item is always the
    least recently used one, the last item the most recently used one.
    """

    def __init__(
        self,
        max_entries: int = 1000,
        default_ttl_seconds: float = 300,
    ) -> None:
        """
        Initialize the cache.

        Args:
            max_entries: Maximum cache entries; a value of 0 or less means
                nothing is ever stored
            default_ttl_seconds: Default TTL for entries
        """
        # OrderedDict maintains insertion order; we use move_to_end for O(1) LRU
        self._cache: OrderedDict[str, CacheEntry] = OrderedDict()
        self._max_entries = max_entries
        self._default_ttl = default_ttl_seconds
        logger.info(
            f"Initialized RetrievalCache with max_entries={max_entries}, "
            f"ttl={default_ttl_seconds}s"
        )

    def get(self, query_key: str) -> list[ScoredResult] | None:
        """
        Get cached results for a query.

        An expired entry is removed as a side effect of the lookup.

        Args:
            query_key: Cache key for the query

        Returns:
            Cached results or None if not found/expired
        """
        entry = self._cache.get(query_key)
        if entry is None:
            return None
        if entry.is_expired():
            del self._cache[query_key]
            return None
        # Update access order (LRU) - O(1) with OrderedDict
        self._cache.move_to_end(query_key)
        logger.debug(f"Cache hit for {query_key}")
        return entry.results

    def put(
        self,
        query_key: str,
        results: list[ScoredResult],
        ttl_seconds: float | None = None,
    ) -> None:
        """
        Cache results for a query.

        Args:
            query_key: Cache key for the query
            results: Results to cache
            ttl_seconds: TTL for this entry; None selects the default TTL
        """
        if self._max_entries <= 0:
            # With no capacity there is nothing to evict and nowhere to store.
            return

        # Remove a previous entry for the same key first. Overwriting a key in
        # an OrderedDict keeps its old position, so without this a refreshed
        # entry would remain the next eviction candidate, and a refresh at
        # capacity would needlessly evict an unrelated entry.
        self._cache.pop(query_key, None)

        # Evict oldest entries if at capacity - O(1) with popitem(last=False)
        while len(self._cache) >= self._max_entries:
            self._cache.popitem(last=False)

        entry = CacheEntry(
            results=results,
            created_at=_utcnow(),
            # Compare against None so an explicit TTL of 0 is honored
            ttl_seconds=self._default_ttl if ttl_seconds is None else ttl_seconds,
            query_key=query_key,
        )
        self._cache[query_key] = entry
        logger.debug(f"Cached {len(results)} results for {query_key}")

    def invalidate(self, query_key: str) -> bool:
        """
        Invalidate a specific cache entry.

        Args:
            query_key: Cache key to invalidate

        Returns:
            True if entry was found and removed
        """
        return self._cache.pop(query_key, None) is not None

    def invalidate_by_memory(self, memory_id: UUID) -> int:
        """
        Invalidate all cache entries containing a specific memory.

        Args:
            memory_id: Memory ID to invalidate

        Returns:
            Number of entries invalidated
        """
        # Collect the keys first; the dict must not change size while it is
        # being iterated.
        keys_to_remove = [
            key
            for key, entry in self._cache.items()
            if any(r.memory_id == memory_id for r in entry.results)
        ]
        for key in keys_to_remove:
            self.invalidate(key)
        if keys_to_remove:
            logger.debug(
                f"Invalidated {len(keys_to_remove)} cache entries for {memory_id}"
            )
        return len(keys_to_remove)

    def clear(self) -> int:
        """
        Clear all cache entries.

        Returns:
            Number of entries cleared
        """
        count = len(self._cache)
        self._cache.clear()
        logger.info(f"Cleared {count} cache entries")
        return count

    def get_stats(self) -> dict[str, Any]:
        """
        Get cache statistics.

        Returns:
            Dict with the current entry count, the number of entries that are
            expired but not yet removed, the capacity and the default TTL.
        """
        expired_count = sum(1 for e in self._cache.values() if e.is_expired())
        return {
            "total_entries": len(self._cache),
            "expired_entries": expired_count,
            "max_entries": self._max_entries,
            "default_ttl_seconds": self._default_ttl,
        }
class RetrievalEngine:
    """
    Hybrid retrieval engine for memory search.

    Combines multiple index types for comprehensive retrieval:
    - Vector search for semantic similarity
    - Temporal index for time-based filtering
    - Entity index for entity-based lookups
    - Outcome index for success/failure filtering

    Results are scored and ranked using relevance scoring.
    """

    # Each index is asked for this many times the requested limit so that
    # enough candidates remain after merging, scoring and filtering.
    _OVERFETCH_FACTOR = 3

    def __init__(
        self,
        indexer: MemoryIndexer | None = None,
        scorer: RelevanceScorer | None = None,
        cache: RetrievalCache | None = None,
        enable_cache: bool = True,
    ) -> None:
        """
        Initialize the retrieval engine.

        Args:
            indexer: Memory indexer (defaults to singleton)
            scorer: Relevance scorer (defaults to new instance)
            cache: Retrieval cache (defaults to new instance)
            enable_cache: Whether to enable result caching; when False no
                cache is kept, even if one was passed in
        """
        self._indexer = indexer or get_memory_indexer()
        self._scorer = scorer if scorer is not None else RelevanceScorer()
        self._cache: RetrievalCache | None
        if not enable_cache:
            self._cache = None
        elif cache is not None:
            self._cache = cache
        else:
            self._cache = RetrievalCache()
        self._enable_cache = enable_cache
        logger.info(f"Initialized RetrievalEngine with cache={enable_cache}")

    @staticmethod
    def _claim_candidate(
        candidates: dict[UUID, dict[str, Any]],
        entry: Any,
        source: str,
        allowed_types: list[MemoryType] | None,
    ) -> dict[str, Any] | None:
        """Return the candidate record for an index hit, or None if its type is filtered out."""
        if allowed_types is not None and entry.memory_type not in allowed_types:
            return None
        record = candidates.setdefault(
            entry.memory_id,
            {"memory_type": entry.memory_type, "sources": []},
        )
        record["sources"].append(source)
        return record

    async def retrieve(
        self,
        query: RetrievalQuery,
        use_cache: bool = True,
    ) -> RetrievalResult[ScoredResult]:
        """
        Retrieve relevant memories using hybrid search.

        Every enabled index contributes candidates; the candidates are merged
        by memory ID, scored, filtered by ``query.min_relevance``, sorted by
        descending score and truncated to ``query.limit``.

        Args:
            query: Retrieval query parameters
            use_cache: Whether to use cached results

        Returns:
            Retrieval result with scored items
        """
        start_time = _utcnow()
        cache = self._cache if use_cache else None

        # Check cache
        cache_key = query.to_cache_key()
        if cache is not None:
            cached = cache.get(cache_key)
            if cached:
                latency = (_utcnow() - start_time).total_seconds() * 1000
                return RetrievalResult(
                    items=cached,
                    total_count=len(cached),
                    query=query.query_text or "",
                    retrieval_type="cached",
                    latency_ms=latency,
                    metadata={"cache_hit": True},
                )

        # The index search calls take a single memory type. With exactly one
        # requested type it is pushed down to the index; with several, the
        # indexes are queried unfiltered and hits are filtered here, so that
        # no requested type is silently dropped.
        # NOTE(review): assumes memory_type=None means "all types" for the
        # indexes, as it already does when no type filter is given.
        requested_types = query.memory_types or []
        if len(requested_types) == 1:
            index_type = requested_types[0]
            allowed_types = None
        else:
            index_type = None
            allowed_types = requested_types or None
        fetch_limit = query.limit * self._OVERFETCH_FACTOR

        # Collect candidates from each index
        candidates: dict[UUID, dict[str, Any]] = {}

        # Vector search
        if query.use_vector and query.query_embedding:
            vector_results = await self._indexer.vector_index.search(
                query=query.query_embedding,
                limit=fetch_limit,  # Get more for filtering
                min_similarity=query.min_relevance,
                memory_type=index_type,
            )
            for entry in vector_results:
                record = self._claim_candidate(
                    candidates, entry, "vector", allowed_types
                )
                if record is not None:
                    record["vector_similarity"] = entry.metadata.get(
                        "similarity", 0.5
                    )

        # Temporal search
        if query.use_temporal and (
            query.start_time or query.end_time or query.recent_seconds
        ):
            temporal_results = await self._indexer.temporal_index.search(
                query=None,
                limit=fetch_limit,
                start_time=query.start_time,
                end_time=query.end_time,
                recent_seconds=query.recent_seconds,
                memory_type=index_type,
            )
            for temporal_entry in temporal_results:
                record = self._claim_candidate(
                    candidates, temporal_entry, "temporal", allowed_types
                )
                if record is not None:
                    record["timestamp"] = temporal_entry.timestamp

        # Entity search
        entity_filter_active = bool(query.use_entity and query.entities)
        if entity_filter_active:
            entity_results = await self._indexer.entity_index.search(
                query=None,
                limit=fetch_limit,
                entities=query.entities,
                match_all=query.entity_match_all,
                memory_type=index_type,
            )
            for entity_entry in entity_results:
                record = self._claim_candidate(
                    candidates, entity_entry, "entity", allowed_types
                )
                if record is not None:
                    # Count entity matches: one per hit for this memory
                    record["entity_match_count"] = (
                        record.get("entity_match_count", 0) + 1
                    )

        # Outcome search
        if query.use_outcome and query.outcomes:
            outcome_results = await self._indexer.outcome_index.search(
                query=None,
                limit=fetch_limit,
                outcomes=query.outcomes,
                memory_type=index_type,
            )
            for outcome_entry in outcome_results:
                record = self._claim_candidate(
                    candidates, outcome_entry, "outcome", allowed_types
                )
                if record is not None:
                    record["outcome"] = outcome_entry.outcome

        # Score and rank candidates
        scored_results: list[ScoredResult] = []
        # 0 tells the scorer that no entity filter was applied, so candidates
        # are not penalized with an entity score of 0 for entities nobody
        # asked for.
        entity_total = len(query.entities) if entity_filter_active else 0
        for memory_id, data in candidates.items():
            scored = self._scorer.score(
                memory_id=memory_id,
                memory_type=data["memory_type"],
                vector_similarity=data.get("vector_similarity"),
                timestamp=data.get("timestamp"),
                entity_match_count=data.get("entity_match_count", 0),
                entity_total=entity_total,
                outcome=data.get("outcome"),
                preferred_outcomes=query.outcomes,
            )
            scored.metadata["sources"] = data.get("sources", [])
            # Filter by minimum relevance
            if scored.relevance_score >= query.min_relevance:
                scored_results.append(scored)

        # Sort by relevance score
        scored_results.sort(key=lambda x: x.relevance_score, reverse=True)

        # Apply limit
        final_results = scored_results[: query.limit]

        # Cache results (empty result sets are deliberately not cached)
        if cache is not None and final_results:
            cache.put(cache_key, final_results)

        latency = (_utcnow() - start_time).total_seconds() * 1000
        logger.info(
            f"Retrieved {len(final_results)} results from {len(candidates)} candidates "
            f"in {latency:.2f}ms"
        )
        return RetrievalResult(
            items=final_results,
            total_count=len(candidates),
            query=query.query_text or "",
            retrieval_type="hybrid",
            latency_ms=latency,
            metadata={
                "cache_hit": False,
                "candidates_count": len(candidates),
                "filtered_count": len(scored_results),
            },
        )

    async def retrieve_similar(
        self,
        embedding: list[float],
        limit: int = 10,
        min_similarity: float = 0.5,
        memory_types: list[MemoryType] | None = None,
    ) -> RetrievalResult[ScoredResult]:
        """
        Retrieve memories similar to a given embedding.

        Args:
            embedding: Query embedding
            limit: Maximum results
            min_similarity: Minimum similarity threshold; it is passed on as
                the query's ``min_relevance``, so it bounds both the vector
                search and the final combined relevance score
            memory_types: Filter by memory types

        Returns:
            Retrieval result with scored items
        """
        query = RetrievalQuery(
            query_embedding=embedding,
            limit=limit,
            min_relevance=min_similarity,
            memory_types=memory_types,
            use_temporal=False,
            use_entity=False,
            use_outcome=False,
        )
        return await self.retrieve(query)

    async def retrieve_recent(
        self,
        hours: float = 24,
        limit: int = 10,
        memory_types: list[MemoryType] | None = None,
    ) -> RetrievalResult[ScoredResult]:
        """
        Retrieve recent memories.

        Args:
            hours: Number of hours to look back
            limit: Maximum results
            memory_types: Filter by memory types

        Returns:
            Retrieval result with scored items
        """
        query = RetrievalQuery(
            recent_seconds=hours * 3600,
            limit=limit,
            memory_types=memory_types,
            use_vector=False,
            use_entity=False,
            use_outcome=False,
        )
        return await self.retrieve(query)

    async def retrieve_by_entity(
        self,
        entity_type: str,
        entity_value: str,
        limit: int = 10,
        memory_types: list[MemoryType] | None = None,
    ) -> RetrievalResult[ScoredResult]:
        """
        Retrieve memories by entity.

        Args:
            entity_type: Type of entity
            entity_value: Entity value
            limit: Maximum results
            memory_types: Filter by memory types

        Returns:
            Retrieval result with scored items
        """
        query = RetrievalQuery(
            entities=[(entity_type, entity_value)],
            limit=limit,
            memory_types=memory_types,
            use_vector=False,
            use_temporal=False,
            use_outcome=False,
        )
        return await self.retrieve(query)

    async def retrieve_successful(
        self,
        limit: int = 10,
        memory_types: list[MemoryType] | None = None,
    ) -> RetrievalResult[ScoredResult]:
        """
        Retrieve successful memories.

        Args:
            limit: Maximum results
            memory_types: Filter by memory types

        Returns:
            Retrieval result with scored items
        """
        query = RetrievalQuery(
            outcomes=[Outcome.SUCCESS],
            limit=limit,
            memory_types=memory_types,
            use_vector=False,
            use_temporal=False,
            use_entity=False,
        )
        return await self.retrieve(query)

    def invalidate_cache(self) -> int:
        """
        Invalidate all cached results.

        Returns:
            Number of entries invalidated (0 when caching is disabled)
        """
        if self._cache is None:
            return 0
        return self._cache.clear()

    def invalidate_cache_for_memory(self, memory_id: UUID) -> int:
        """
        Invalidate cache entries containing a specific memory.

        Args:
            memory_id: Memory ID to invalidate

        Returns:
            Number of entries invalidated (0 when caching is disabled)
        """
        if self._cache is None:
            return 0
        return self._cache.invalidate_by_memory(memory_id)

    def get_cache_stats(self) -> dict[str, Any]:
        """
        Get cache statistics.

        Returns:
            The cache's statistics, or ``{"enabled": False}`` when caching is
            disabled.
        """
        if self._cache is None:
            return {"enabled": False}
        return self._cache.get_stats()
# Module-wide engine instance; stays None until first requested
_engine: RetrievalEngine | None = None


def get_retrieval_engine() -> RetrievalEngine:
    """Return the shared retrieval engine, creating it on first use."""
    global _engine
    if _engine is not None:
        return _engine
    _engine = RetrievalEngine()
    return _engine