feat(memory): implement memory indexing and retrieval engine (#94)
Add comprehensive indexing and retrieval system for memory search:

- VectorIndex for semantic similarity search using cosine similarity
- TemporalIndex for time-based queries with range and recency support
- EntityIndex for entity-based lookups with multi-entity intersection
- OutcomeIndex for success/failure filtering on episodes
- MemoryIndexer as unified interface for all index types
- RetrievalEngine with hybrid search combining all indices
- RelevanceScorer for multi-signal relevance scoring
- RetrievalCache for LRU caching of search results

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
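Example usage (an illustrative sketch, to be run inside an async function; `embedding` is a placeholder for a vector from whichever embedding model the app uses):

    from app.services.memory.indexing.retrieval import (
        RetrievalQuery,
        get_retrieval_engine,
    )

    engine = get_retrieval_engine()
    result = await engine.retrieve(
        RetrievalQuery(
            query_text="deployment failures",
            query_embedding=embedding,      # placeholder vector
            recent_seconds=7 * 24 * 3600,   # look back one week
            limit=5,
        )
    )
    for item in result.items:
        print(item.memory_id, item.relevance_score, item.score_breakdown)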
backend/app/services/memory/indexing/retrieval.py (new file, 750 lines)

@@ -0,0 +1,750 @@
# app/services/memory/indexing/retrieval.py
"""
Memory Retrieval Engine.

Provides hybrid retrieval capabilities combining:
- Vector similarity search
- Temporal filtering
- Entity filtering
- Outcome filtering
- Relevance scoring
- Result caching
"""

import hashlib
import logging
from dataclasses import dataclass, field
from datetime import UTC, datetime
from typing import Any, TypeVar
from uuid import UUID

from app.services.memory.types import (
    Episode,
    Fact,
    MemoryType,
    Outcome,
    Procedure,
    RetrievalResult,
)

from .index import (
    MemoryIndexer,
    get_memory_indexer,
)

logger = logging.getLogger(__name__)

T = TypeVar("T", Episode, Fact, Procedure)


def _utcnow() -> datetime:
    """Get the current UTC time as a timezone-aware datetime."""
    return datetime.now(UTC)


@dataclass
class RetrievalQuery:
    """Query parameters for memory retrieval."""

    # Text/semantic query
    query_text: str | None = None
    query_embedding: list[float] | None = None

    # Temporal filters
    start_time: datetime | None = None
    end_time: datetime | None = None
    recent_seconds: float | None = None

    # Entity filters
    entities: list[tuple[str, str]] | None = None
    entity_match_all: bool = False

    # Outcome filters
    outcomes: list[Outcome] | None = None

    # Memory type filter
    memory_types: list[MemoryType] | None = None

    # Result options
    limit: int = 10
    min_relevance: float = 0.0

    # Retrieval mode
    use_vector: bool = True
    use_temporal: bool = True
    use_entity: bool = True
    use_outcome: bool = True

    def to_cache_key(self) -> str:
        """Generate a deterministic cache key for this query."""
        key_parts = [
            self.query_text or "",
            # Include the embedding: embedding-only queries (query_text=None)
            # must not collide with one another.
            str(self.query_embedding),
            str(self.start_time),
            str(self.end_time),
            str(self.recent_seconds),
            str(self.entities),
            str(self.outcomes),
            str(self.memory_types),
            str(self.limit),
            str(self.min_relevance),
            # Include the mode flags so a filtered-only query does not share
            # a key with a hybrid query that has the same filter values.
            str(
                (self.use_vector, self.use_temporal, self.use_entity, self.use_outcome)
            ),
        ]
        key_string = "|".join(key_parts)
        return hashlib.sha256(key_string.encode()).hexdigest()[:32]
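
# Illustrative usage (hypothetical values): identical parameters hash to the
# same 32-hex-character key, so repeated searches can be served from cache.
#
#   q1 = RetrievalQuery(query_text="deploy failure", limit=5)
#   q2 = RetrievalQuery(query_text="deploy failure", limit=5)
#   assert q1.to_cache_key() == q2.to_cache_key()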


@dataclass
class ScoredResult:
    """A retrieval result with relevance score."""

    memory_id: UUID
    memory_type: MemoryType
    relevance_score: float
    score_breakdown: dict[str, float] = field(default_factory=dict)
    metadata: dict[str, Any] = field(default_factory=dict)


@dataclass
class CacheEntry:
    """A cached retrieval result."""

    results: list[ScoredResult]
    created_at: datetime
    ttl_seconds: float
    query_key: str

    def is_expired(self) -> bool:
        """Check whether this cache entry has expired."""
        age = (_utcnow() - self.created_at).total_seconds()
        return age > self.ttl_seconds


class RelevanceScorer:
    """
    Calculates relevance scores for retrieved memories.

    Combines multiple signals:
    - Vector similarity (if available)
    - Temporal recency
    - Entity match count
    - Outcome preference
    - Importance/confidence
    """

    def __init__(
        self,
        vector_weight: float = 0.4,
        recency_weight: float = 0.2,
        entity_weight: float = 0.2,
        outcome_weight: float = 0.1,
        importance_weight: float = 0.1,
    ) -> None:
        """
        Initialize the relevance scorer.

        Args:
            vector_weight: Weight for vector similarity (0-1)
            recency_weight: Weight for temporal recency (0-1)
            entity_weight: Weight for entity matches (0-1)
            outcome_weight: Weight for outcome preference (0-1)
            importance_weight: Weight for importance score (0-1)
        """
        total = (
            vector_weight
            + recency_weight
            + entity_weight
            + outcome_weight
            + importance_weight
        )
        if total <= 0:
            msg = "At least one scoring weight must be positive"
            raise ValueError(msg)
        # Normalize weights so they always sum to 1
        self.vector_weight = vector_weight / total
        self.recency_weight = recency_weight / total
        self.entity_weight = entity_weight / total
        self.outcome_weight = outcome_weight / total
        self.importance_weight = importance_weight / total

    def score(
        self,
        memory_id: UUID,
        memory_type: MemoryType,
        vector_similarity: float | None = None,
        timestamp: datetime | None = None,
        entity_match_count: int = 0,
        entity_total: int = 1,
        outcome: Outcome | None = None,
        importance: float = 0.5,
        preferred_outcomes: list[Outcome] | None = None,
    ) -> ScoredResult:
        """
        Calculate a relevance score for a memory.

        Args:
            memory_id: ID of the memory
            memory_type: Type of memory
            vector_similarity: Similarity score from vector search (0-1)
            timestamp: Timestamp of the memory
            entity_match_count: Number of matching entities
            entity_total: Total entities in query
            outcome: Outcome of the memory
            importance: Importance score of the memory (0-1)
            preferred_outcomes: Outcomes to prefer

        Returns:
            Scored result with breakdown
        """
        breakdown: dict[str, float] = {}

        # Vector similarity score
        if vector_similarity is not None:
            breakdown["vector"] = vector_similarity
        else:
            breakdown["vector"] = 0.5  # Neutral if no vector

        # Recency score (exponential decay with a 24-hour half-life)
        if timestamp:
            age_hours = (_utcnow() - timestamp).total_seconds() / 3600
            breakdown["recency"] = 2 ** (-age_hours / 24)
        else:
            breakdown["recency"] = 0.5

        # Entity match score
        if entity_total > 0:
            breakdown["entity"] = entity_match_count / entity_total
        else:
            breakdown["entity"] = 1.0  # No entity filter = full score

        # Outcome score
        if preferred_outcomes and outcome:
            breakdown["outcome"] = 1.0 if outcome in preferred_outcomes else 0.0
        else:
            breakdown["outcome"] = 0.5  # Neutral if no preference

        # Importance score
        breakdown["importance"] = importance

        # Calculate weighted sum
        total_score = (
            breakdown["vector"] * self.vector_weight
            + breakdown["recency"] * self.recency_weight
            + breakdown["entity"] * self.entity_weight
            + breakdown["outcome"] * self.outcome_weight
            + breakdown["importance"] * self.importance_weight
        )

        return ScoredResult(
            memory_id=memory_id,
            memory_type=memory_type,
            relevance_score=total_score,
            score_breakdown=breakdown,
        )
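
# Example weight normalization (hypothetical values):
#
#   RelevanceScorer(vector_weight=2, recency_weight=1, entity_weight=1,
#                   outcome_weight=1, importance_weight=1)
#
# stores normalized weights 1/3, 1/6, 1/6, 1/6, 1/6, so the final
# relevance_score stays within [0, 1] as long as each signal is in [0, 1].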


class RetrievalCache:
    """
    In-memory cache for retrieval results.

    Supports TTL-based expiration and LRU eviction.
    """

    def __init__(
        self,
        max_entries: int = 1000,
        default_ttl_seconds: float = 300,
    ) -> None:
        """
        Initialize the cache.

        Args:
            max_entries: Maximum cache entries
            default_ttl_seconds: Default TTL for entries
        """
        self._cache: dict[str, CacheEntry] = {}
        self._max_entries = max_entries
        self._default_ttl = default_ttl_seconds
        self._access_order: list[str] = []
        logger.info(
            f"Initialized RetrievalCache with max_entries={max_entries}, "
            f"ttl={default_ttl_seconds}s"
        )

    def get(self, query_key: str) -> list[ScoredResult] | None:
        """
        Get cached results for a query.

        Args:
            query_key: Cache key for the query

        Returns:
            Cached results, or None if not found or expired
        """
        if query_key not in self._cache:
            return None

        entry = self._cache[query_key]
        if entry.is_expired():
            del self._cache[query_key]
            if query_key in self._access_order:
                self._access_order.remove(query_key)
            return None

        # Update access order (LRU)
        if query_key in self._access_order:
            self._access_order.remove(query_key)
        self._access_order.append(query_key)

        logger.debug(f"Cache hit for {query_key}")
        return entry.results

    def put(
        self,
        query_key: str,
        results: list[ScoredResult],
        ttl_seconds: float | None = None,
    ) -> None:
        """
        Cache results for a query.

        Args:
            query_key: Cache key for the query
            results: Results to cache
            ttl_seconds: TTL for this entry (or default)
        """
        # Evict if at capacity
        while len(self._cache) >= self._max_entries and self._access_order:
            oldest_key = self._access_order.pop(0)
            if oldest_key in self._cache:
                del self._cache[oldest_key]

        entry = CacheEntry(
            results=results,
            created_at=_utcnow(),
            ttl_seconds=ttl_seconds or self._default_ttl,
            query_key=query_key,
        )

        self._cache[query_key] = entry
        # Re-inserting an existing key must not leave a duplicate in the
        # access order, or the eviction bookkeeping drifts.
        if query_key in self._access_order:
            self._access_order.remove(query_key)
        self._access_order.append(query_key)
        logger.debug(f"Cached {len(results)} results for {query_key}")

    def invalidate(self, query_key: str) -> bool:
        """
        Invalidate a specific cache entry.

        Args:
            query_key: Cache key to invalidate

        Returns:
            True if the entry was found and removed
        """
        if query_key in self._cache:
            del self._cache[query_key]
            if query_key in self._access_order:
                self._access_order.remove(query_key)
            return True
        return False

    def invalidate_by_memory(self, memory_id: UUID) -> int:
        """
        Invalidate all cache entries containing a specific memory.

        Args:
            memory_id: Memory ID to invalidate

        Returns:
            Number of entries invalidated
        """
        keys_to_remove = [
            key
            for key, entry in self._cache.items()
            if any(r.memory_id == memory_id for r in entry.results)
        ]

        for key in keys_to_remove:
            self.invalidate(key)

        if keys_to_remove:
            logger.debug(
                f"Invalidated {len(keys_to_remove)} cache entries for {memory_id}"
            )
        return len(keys_to_remove)

    def clear(self) -> int:
        """
        Clear all cache entries.

        Returns:
            Number of entries cleared
        """
        count = len(self._cache)
        self._cache.clear()
        self._access_order.clear()
        logger.info(f"Cleared {count} cache entries")
        return count

    def get_stats(self) -> dict[str, Any]:
        """Get cache statistics."""
        expired_count = sum(1 for e in self._cache.values() if e.is_expired())
        return {
            "total_entries": len(self._cache),
            "expired_entries": expired_count,
            "max_entries": self._max_entries,
            "default_ttl_seconds": self._default_ttl,
        }
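
# Illustrative eviction behaviour (hypothetical keys): with max_entries=2,
#
#   cache.put("a", ...); cache.put("b", ...); cache.get("a")
#   cache.put("c", ...)   # evicts "b": "a" was touched more recently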


class RetrievalEngine:
    """
    Hybrid retrieval engine for memory search.

    Combines multiple index types for comprehensive retrieval:
    - Vector search for semantic similarity
    - Temporal index for time-based filtering
    - Entity index for entity-based lookups
    - Outcome index for success/failure filtering

    Results are scored and ranked using relevance scoring.
    """

    def __init__(
        self,
        indexer: MemoryIndexer | None = None,
        scorer: RelevanceScorer | None = None,
        cache: RetrievalCache | None = None,
        enable_cache: bool = True,
    ) -> None:
        """
        Initialize the retrieval engine.

        Args:
            indexer: Memory indexer (defaults to singleton)
            scorer: Relevance scorer (defaults to new instance)
            cache: Retrieval cache (defaults to new instance)
            enable_cache: Whether to enable result caching
        """
        self._indexer = indexer or get_memory_indexer()
        self._scorer = scorer or RelevanceScorer()
        # Parenthesized for clarity: enable_cache=False disables caching even
        # when a cache instance is passed in.
        self._cache = (cache or RetrievalCache()) if enable_cache else None
        self._enable_cache = enable_cache
        logger.info(f"Initialized RetrievalEngine with cache={enable_cache}")

    async def retrieve(
        self,
        query: RetrievalQuery,
        use_cache: bool = True,
    ) -> RetrievalResult[ScoredResult]:
        """
        Retrieve relevant memories using hybrid search.

        Args:
            query: Retrieval query parameters
            use_cache: Whether to use cached results

        Returns:
            Retrieval result with scored items
        """
        start_time = _utcnow()

        # Check cache
        cache_key = query.to_cache_key()
        if use_cache and self._cache:
            cached = self._cache.get(cache_key)
            if cached:
                latency = (_utcnow() - start_time).total_seconds() * 1000
                return RetrievalResult(
                    items=cached,
                    total_count=len(cached),
                    query=query.query_text or "",
                    retrieval_type="cached",
                    latency_ms=latency,
                    metadata={"cache_hit": True},
                )

        # Collect candidates from each index. Note: each index call only
        # pushes down the first requested memory type; any additional types
        # in query.memory_types are not filtered here.
        candidates: dict[UUID, dict[str, Any]] = {}

        # Vector search
        if query.use_vector and query.query_embedding:
            vector_results = await self._indexer.vector_index.search(
                query=query.query_embedding,
                limit=query.limit * 3,  # Over-fetch to survive filtering
                min_similarity=query.min_relevance,
                memory_type=query.memory_types[0] if query.memory_types else None,
            )
            for entry in vector_results:
                if entry.memory_id not in candidates:
                    candidates[entry.memory_id] = {
                        "memory_type": entry.memory_type,
                        "sources": [],
                    }
                candidates[entry.memory_id]["vector_similarity"] = entry.metadata.get(
                    "similarity", 0.5
                )
                candidates[entry.memory_id]["sources"].append("vector")

        # Temporal search
        if query.use_temporal and (
            query.start_time or query.end_time or query.recent_seconds
        ):
            temporal_results = await self._indexer.temporal_index.search(
                query=None,
                limit=query.limit * 3,
                start_time=query.start_time,
                end_time=query.end_time,
                recent_seconds=query.recent_seconds,
                memory_type=query.memory_types[0] if query.memory_types else None,
            )
            for temporal_entry in temporal_results:
                if temporal_entry.memory_id not in candidates:
                    candidates[temporal_entry.memory_id] = {
                        "memory_type": temporal_entry.memory_type,
                        "sources": [],
                    }
                candidates[temporal_entry.memory_id]["timestamp"] = (
                    temporal_entry.timestamp
                )
                candidates[temporal_entry.memory_id]["sources"].append("temporal")

        # Entity search
        if query.use_entity and query.entities:
            entity_results = await self._indexer.entity_index.search(
                query=None,
                limit=query.limit * 3,
                entities=query.entities,
                match_all=query.entity_match_all,
                memory_type=query.memory_types[0] if query.memory_types else None,
            )
            for entity_entry in entity_results:
                if entity_entry.memory_id not in candidates:
                    candidates[entity_entry.memory_id] = {
                        "memory_type": entity_entry.memory_type,
                        "sources": [],
                    }
                # Count entity matches
                entity_count = candidates[entity_entry.memory_id].get(
                    "entity_match_count", 0
                )
                candidates[entity_entry.memory_id]["entity_match_count"] = (
                    entity_count + 1
                )
                candidates[entity_entry.memory_id]["sources"].append("entity")

        # Outcome search
        if query.use_outcome and query.outcomes:
            outcome_results = await self._indexer.outcome_index.search(
                query=None,
                limit=query.limit * 3,
                outcomes=query.outcomes,
                memory_type=query.memory_types[0] if query.memory_types else None,
            )
            for outcome_entry in outcome_results:
                if outcome_entry.memory_id not in candidates:
                    candidates[outcome_entry.memory_id] = {
                        "memory_type": outcome_entry.memory_type,
                        "sources": [],
                    }
                candidates[outcome_entry.memory_id]["outcome"] = outcome_entry.outcome
                candidates[outcome_entry.memory_id]["sources"].append("outcome")
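
        # At this point `candidates` is the union across all active indices;
        # a memory found by several indices carries more scoring signals
        # (similarity, timestamp, entity count, outcome) and so tends to
        # rank higher below.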

        # Score and rank candidates
        scored_results: list[ScoredResult] = []
        entity_total = len(query.entities) if query.entities else 1

        for memory_id, data in candidates.items():
            scored = self._scorer.score(
                memory_id=memory_id,
                memory_type=data["memory_type"],
                vector_similarity=data.get("vector_similarity"),
                timestamp=data.get("timestamp"),
                entity_match_count=data.get("entity_match_count", 0),
                entity_total=entity_total,
                outcome=data.get("outcome"),
                preferred_outcomes=query.outcomes,
            )
            scored.metadata["sources"] = data.get("sources", [])

            # Filter by minimum relevance
            if scored.relevance_score >= query.min_relevance:
                scored_results.append(scored)

        # Sort by relevance score, descending
        scored_results.sort(key=lambda x: x.relevance_score, reverse=True)

        # Apply limit
        final_results = scored_results[: query.limit]

        # Cache results
        if use_cache and self._cache and final_results:
            self._cache.put(cache_key, final_results)

        latency = (_utcnow() - start_time).total_seconds() * 1000

        logger.info(
            f"Retrieved {len(final_results)} results from {len(candidates)} "
            f"candidates in {latency:.2f}ms"
        )

        return RetrievalResult(
            items=final_results,
            total_count=len(candidates),
            query=query.query_text or "",
            retrieval_type="hybrid",
            latency_ms=latency,
            metadata={
                "cache_hit": False,
                "candidates_count": len(candidates),
                "filtered_count": len(scored_results),
            },
        )

    async def retrieve_similar(
        self,
        embedding: list[float],
        limit: int = 10,
        min_similarity: float = 0.5,
        memory_types: list[MemoryType] | None = None,
    ) -> RetrievalResult[ScoredResult]:
        """
        Retrieve memories similar to a given embedding.

        Args:
            embedding: Query embedding
            limit: Maximum results
            min_similarity: Minimum similarity threshold
            memory_types: Filter by memory types

        Returns:
            Retrieval result with scored items
        """
        query = RetrievalQuery(
            query_embedding=embedding,
            limit=limit,
            min_relevance=min_similarity,
            memory_types=memory_types,
            use_temporal=False,
            use_entity=False,
            use_outcome=False,
        )
        return await self.retrieve(query)

    async def retrieve_recent(
        self,
        hours: float = 24,
        limit: int = 10,
        memory_types: list[MemoryType] | None = None,
    ) -> RetrievalResult[ScoredResult]:
        """
        Retrieve recent memories.

        Args:
            hours: Number of hours to look back
            limit: Maximum results
            memory_types: Filter by memory types

        Returns:
            Retrieval result with scored items
        """
        query = RetrievalQuery(
            recent_seconds=hours * 3600,
            limit=limit,
            memory_types=memory_types,
            use_vector=False,
            use_entity=False,
            use_outcome=False,
        )
        return await self.retrieve(query)

    async def retrieve_by_entity(
        self,
        entity_type: str,
        entity_value: str,
        limit: int = 10,
        memory_types: list[MemoryType] | None = None,
    ) -> RetrievalResult[ScoredResult]:
        """
        Retrieve memories by entity.

        Args:
            entity_type: Type of entity
            entity_value: Entity value
            limit: Maximum results
            memory_types: Filter by memory types

        Returns:
            Retrieval result with scored items
        """
        query = RetrievalQuery(
            entities=[(entity_type, entity_value)],
            limit=limit,
            memory_types=memory_types,
            use_vector=False,
            use_temporal=False,
            use_outcome=False,
        )
        return await self.retrieve(query)

    async def retrieve_successful(
        self,
        limit: int = 10,
        memory_types: list[MemoryType] | None = None,
    ) -> RetrievalResult[ScoredResult]:
        """
        Retrieve memories with successful outcomes.

        Args:
            limit: Maximum results
            memory_types: Filter by memory types

        Returns:
            Retrieval result with scored items
        """
        query = RetrievalQuery(
            outcomes=[Outcome.SUCCESS],
            limit=limit,
            memory_types=memory_types,
            use_vector=False,
            use_temporal=False,
            use_entity=False,
        )
        return await self.retrieve(query)

    def invalidate_cache(self) -> int:
        """
        Invalidate all cached results.

        Returns:
            Number of entries invalidated
        """
        if self._cache:
            return self._cache.clear()
        return 0

    def invalidate_cache_for_memory(self, memory_id: UUID) -> int:
        """
        Invalidate cache entries containing a specific memory.

        Args:
            memory_id: Memory ID to invalidate

        Returns:
            Number of entries invalidated
        """
        if self._cache:
            return self._cache.invalidate_by_memory(memory_id)
        return 0

    def get_cache_stats(self) -> dict[str, Any]:
        """Get cache statistics."""
        if self._cache:
            return self._cache.get_stats()
        return {"enabled": False}


# Singleton retrieval engine instance
_engine: RetrievalEngine | None = None


def get_retrieval_engine() -> RetrievalEngine:
    """Get the singleton retrieval engine instance."""
    global _engine
    if _engine is None:
        _engine = RetrievalEngine()
    return _engine
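
# Note: the lazy initializer above is not lock-guarded; if first calls race
# across threads, two engines could be constructed and the last assignment
# wins. That is harmless for single-threaded asyncio use; wrap the check in a
# threading.Lock if engines may be created from worker threads.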