feat(memory): implement memory indexing and retrieval engine (#94)

Add a comprehensive indexing and retrieval system for memory search:

- VectorIndex for semantic similarity search using cosine similarity
- TemporalIndex for time-based queries with range and recency support
- EntityIndex for entity-based lookups with multi-entity intersection
- OutcomeIndex for success/failure filtering on episodes
- MemoryIndexer as a unified interface for all index types
- RetrievalEngine with hybrid search combining all indices
- RelevanceScorer for multi-signal relevance scoring
- RetrievalCache for LRU caching of search results

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
backend/app/services/memory/indexing/__init__.py
@@ -1,7 +1,56 @@
 # app/services/memory/indexing/__init__.py
-"""
-Memory Indexing
-
-Vector embeddings and retrieval engine for memory search.
-"""
-# Will be populated in #94
+"""
+Memory Indexing & Retrieval.
+
+Provides vector embeddings and multiple index types for efficient memory search:
+- Vector index for semantic similarity
+- Temporal index for time-based queries
+- Entity index for entity lookups
+- Outcome index for success/failure filtering
+"""
+
+from .index import (
+    EntityIndex,
+    EntityIndexEntry,
+    IndexEntry,
+    MemoryIndex,
+    MemoryIndexer,
+    OutcomeIndex,
+    OutcomeIndexEntry,
+    TemporalIndex,
+    TemporalIndexEntry,
+    VectorIndex,
+    VectorIndexEntry,
+    get_memory_indexer,
+)
+from .retrieval import (
+    CacheEntry,
+    RelevanceScorer,
+    RetrievalCache,
+    RetrievalEngine,
+    RetrievalQuery,
+    ScoredResult,
+    get_retrieval_engine,
+)
+
+__all__ = [
+    "CacheEntry",
+    "EntityIndex",
+    "EntityIndexEntry",
+    "IndexEntry",
+    "MemoryIndex",
+    "MemoryIndexer",
+    "OutcomeIndex",
+    "OutcomeIndexEntry",
+    "RelevanceScorer",
+    "RetrievalCache",
+    "RetrievalEngine",
+    "RetrievalQuery",
+    "ScoredResult",
+    "TemporalIndex",
+    "TemporalIndexEntry",
+    "VectorIndex",
+    "VectorIndexEntry",
+    "get_memory_indexer",
+    "get_retrieval_engine",
+]
851 backend/app/services/memory/indexing/index.py (new file)
@@ -0,0 +1,851 @@
# app/services/memory/indexing/index.py
"""
Memory Indexing.

Provides multiple indexing strategies for efficient memory retrieval:
- Vector embeddings for semantic search
- Temporal index for time-based queries
- Entity index for entity-based lookups
- Outcome index for success/failure filtering
"""

import logging
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from datetime import UTC, datetime, timedelta
from typing import Any, TypeVar
from uuid import UUID

from app.services.memory.types import Episode, Fact, MemoryType, Outcome, Procedure

logger = logging.getLogger(__name__)

T = TypeVar("T", Episode, Fact, Procedure)


def _utcnow() -> datetime:
    """Get current UTC time as timezone-aware datetime."""
    return datetime.now(UTC)


@dataclass
class IndexEntry:
    """A single entry in an index."""

    memory_id: UUID
    memory_type: MemoryType
    indexed_at: datetime = field(default_factory=_utcnow)
    metadata: dict[str, Any] = field(default_factory=dict)


@dataclass
class VectorIndexEntry(IndexEntry):
    """An entry with vector embedding."""

    embedding: list[float] = field(default_factory=list)
    dimension: int = 0

    def __post_init__(self) -> None:
        """Set dimension from embedding."""
        if self.embedding:
            self.dimension = len(self.embedding)


@dataclass
class TemporalIndexEntry(IndexEntry):
    """An entry indexed by time."""

    timestamp: datetime = field(default_factory=_utcnow)


@dataclass
class EntityIndexEntry(IndexEntry):
    """An entry indexed by entity."""

    entity_type: str = ""
    entity_value: str = ""


@dataclass
class OutcomeIndexEntry(IndexEntry):
    """An entry indexed by outcome."""

    outcome: Outcome = Outcome.SUCCESS


class MemoryIndex[T](ABC):
    """Abstract base class for memory indices."""

    @abstractmethod
    async def add(self, item: T) -> IndexEntry:
        """Add an item to the index."""
        ...

    @abstractmethod
    async def remove(self, memory_id: UUID) -> bool:
        """Remove an item from the index."""
        ...

    @abstractmethod
    async def search(
        self,
        query: Any,
        limit: int = 10,
        **kwargs: Any,
    ) -> list[IndexEntry]:
        """Search the index."""
        ...

    @abstractmethod
    async def clear(self) -> int:
        """Clear all entries from the index."""
        ...

    @abstractmethod
    async def count(self) -> int:
        """Get the number of entries in the index."""
        ...


class VectorIndex(MemoryIndex[T]):
    """
    Vector-based index using embeddings for semantic similarity search.

    Uses cosine similarity for matching.
    """

    def __init__(self, dimension: int = 1536) -> None:
        """
        Initialize the vector index.

        Args:
            dimension: Embedding dimension (default 1536 for OpenAI)
        """
        self._dimension = dimension
        self._entries: dict[UUID, VectorIndexEntry] = {}
        logger.info(f"Initialized VectorIndex with dimension={dimension}")

    async def add(self, item: T) -> VectorIndexEntry:
        """
        Add an item to the vector index.

        Args:
            item: Memory item with embedding

        Returns:
            The created index entry
        """
        embedding = getattr(item, "embedding", None) or []

        entry = VectorIndexEntry(
            memory_id=item.id,
            memory_type=self._get_memory_type(item),
            embedding=embedding,
            dimension=len(embedding),
        )

        self._entries[item.id] = entry
        logger.debug(f"Added {item.id} to vector index")
        return entry

    async def remove(self, memory_id: UUID) -> bool:
        """Remove an item from the vector index."""
        if memory_id in self._entries:
            del self._entries[memory_id]
            logger.debug(f"Removed {memory_id} from vector index")
            return True
        return False

    async def search(  # type: ignore[override]
        self,
        query: Any,
        limit: int = 10,
        min_similarity: float = 0.0,
        **kwargs: Any,
    ) -> list[VectorIndexEntry]:
        """
        Search for similar items using vector similarity.

        Args:
            query: Query embedding vector
            limit: Maximum results to return
            min_similarity: Minimum similarity threshold (0-1)
            **kwargs: Additional filter parameters

        Returns:
            List of matching entries sorted by similarity
        """
        if not isinstance(query, list) or not query:
            return []

        results: list[tuple[float, VectorIndexEntry]] = []

        for entry in self._entries.values():
            if not entry.embedding:
                continue

            similarity = self._cosine_similarity(query, entry.embedding)
            if similarity >= min_similarity:
                results.append((similarity, entry))

        # Sort by similarity descending
        results.sort(key=lambda x: x[0], reverse=True)

        # Apply memory type filter if provided
        memory_type = kwargs.get("memory_type")
        if memory_type:
            results = [(s, e) for s, e in results if e.memory_type == memory_type]

        # Store similarity in metadata for the returned entries
        output = []
        for similarity, entry in results[:limit]:
            entry.metadata["similarity"] = similarity
            output.append(entry)

        logger.debug(f"Vector search returned {len(output)} results")
        return output

    async def clear(self) -> int:
        """Clear all entries from the index."""
        count = len(self._entries)
        self._entries.clear()
        logger.info(f"Cleared {count} entries from vector index")
        return count

    async def count(self) -> int:
        """Get the number of entries in the index."""
        return len(self._entries)

    def _cosine_similarity(self, a: list[float], b: list[float]) -> float:
        """Calculate cosine similarity between two vectors."""
        if len(a) != len(b) or len(a) == 0:
            return 0.0

        dot_product = sum(x * y for x, y in zip(a, b, strict=True))
        norm_a = sum(x * x for x in a) ** 0.5
        norm_b = sum(x * x for x in b) ** 0.5

        if norm_a == 0 or norm_b == 0:
            return 0.0

        return dot_product / (norm_a * norm_b)

    def _get_memory_type(self, item: T) -> MemoryType:
        """Get the memory type for an item."""
        if isinstance(item, Episode):
            return MemoryType.EPISODIC
        elif isinstance(item, Fact):
            return MemoryType.SEMANTIC
        elif isinstance(item, Procedure):
            return MemoryType.PROCEDURAL
        return MemoryType.WORKING
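

# Usage sketch (illustrative only, not part of the committed module):
# exercising the cosine-similarity search above. The SimpleNamespace stubs
# stand in for real memory items, which normally arrive as Episode / Fact /
# Procedure instances carrying an `id` and an `embedding`.
import asyncio
from types import SimpleNamespace
from uuid import uuid4


async def _demo_vector_search() -> None:
    index = VectorIndex(dimension=3)
    await index.add(SimpleNamespace(id=uuid4(), embedding=[1.0, 0.0, 0.0]))
    await index.add(SimpleNamespace(id=uuid4(), embedding=[0.0, 1.0, 0.0]))
    # A query near the first vector ranks it first; each returned entry
    # carries its cosine score in entry.metadata["similarity"].
    hits = await index.search(query=[0.9, 0.1, 0.0], limit=2, min_similarity=0.1)
    assert hits[0].metadata["similarity"] > hits[1].metadata["similarity"]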


class TemporalIndex(MemoryIndex[T]):
    """
    Time-based index for efficient temporal queries.

    Supports:
    - Range queries (between timestamps)
    - Recent items (within last N seconds/hours/days)
    - Oldest/newest sorting
    """

    def __init__(self) -> None:
        """Initialize the temporal index."""
        self._entries: dict[UUID, TemporalIndexEntry] = {}
        # Sorted list for efficient range queries
        self._sorted_entries: list[tuple[datetime, UUID]] = []
        logger.info("Initialized TemporalIndex")

    async def add(self, item: T) -> TemporalIndexEntry:
        """
        Add an item to the temporal index.

        Args:
            item: Memory item with timestamp

        Returns:
            The created index entry
        """
        # Get timestamp from various possible fields
        timestamp = self._get_timestamp(item)

        entry = TemporalIndexEntry(
            memory_id=item.id,
            memory_type=self._get_memory_type(item),
            timestamp=timestamp,
        )

        self._entries[item.id] = entry
        self._insert_sorted(timestamp, item.id)

        logger.debug(f"Added {item.id} to temporal index at {timestamp}")
        return entry

    async def remove(self, memory_id: UUID) -> bool:
        """Remove an item from the temporal index."""
        if memory_id not in self._entries:
            return False

        self._entries.pop(memory_id)
        self._sorted_entries = [
            (ts, mid) for ts, mid in self._sorted_entries if mid != memory_id
        ]

        logger.debug(f"Removed {memory_id} from temporal index")
        return True

    async def search(  # type: ignore[override]
        self,
        query: Any,
        limit: int = 10,
        start_time: datetime | None = None,
        end_time: datetime | None = None,
        recent_seconds: float | None = None,
        order: str = "desc",
        **kwargs: Any,
    ) -> list[TemporalIndexEntry]:
        """
        Search for items by time.

        Args:
            query: Ignored for temporal search
            limit: Maximum results to return
            start_time: Start of time range
            end_time: End of time range
            recent_seconds: Get items from last N seconds
            order: Sort order ("asc" or "desc")
            **kwargs: Additional filter parameters

        Returns:
            List of matching entries sorted by time
        """
        if recent_seconds is not None:
            start_time = _utcnow() - timedelta(seconds=recent_seconds)
            end_time = _utcnow()

        # Filter by time range
        results: list[TemporalIndexEntry] = []
        for entry in self._entries.values():
            if start_time and entry.timestamp < start_time:
                continue
            if end_time and entry.timestamp > end_time:
                continue
            results.append(entry)

        # Apply memory type filter if provided
        memory_type = kwargs.get("memory_type")
        if memory_type:
            results = [e for e in results if e.memory_type == memory_type]

        # Sort by timestamp
        results.sort(key=lambda e: e.timestamp, reverse=(order == "desc"))

        logger.debug(f"Temporal search returned {min(len(results), limit)} results")
        return results[:limit]

    async def clear(self) -> int:
        """Clear all entries from the index."""
        count = len(self._entries)
        self._entries.clear()
        self._sorted_entries.clear()
        logger.info(f"Cleared {count} entries from temporal index")
        return count

    async def count(self) -> int:
        """Get the number of entries in the index."""
        return len(self._entries)

    def _insert_sorted(self, timestamp: datetime, memory_id: UUID) -> None:
        """Insert entry maintaining sorted order."""
        # Binary search insert for efficiency
        low, high = 0, len(self._sorted_entries)
        while low < high:
            mid = (low + high) // 2
            if self._sorted_entries[mid][0] < timestamp:
                low = mid + 1
            else:
                high = mid
        self._sorted_entries.insert(low, (timestamp, memory_id))

    def _get_timestamp(self, item: T) -> datetime:
        """Get the relevant timestamp for an item."""
        if hasattr(item, "occurred_at"):
            return item.occurred_at
        if hasattr(item, "first_learned"):
            return item.first_learned
        if hasattr(item, "last_used") and item.last_used:
            return item.last_used
        if hasattr(item, "created_at"):
            return item.created_at
        return _utcnow()

    def _get_memory_type(self, item: T) -> MemoryType:
        """Get the memory type for an item."""
        if isinstance(item, Episode):
            return MemoryType.EPISODIC
        elif isinstance(item, Fact):
            return MemoryType.SEMANTIC
        elif isinstance(item, Procedure):
            return MemoryType.PROCEDURAL
        return MemoryType.WORKING
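

# Note on _insert_sorted: the hand-rolled binary search above is equivalent
# to the standard library's bisect module; a minimal sketch of the same
# insertion (a possible simplification, not what this module uses). Observe
# that search() filters by linear scan, so _sorted_entries currently only
# maintains ordering, e.g. for future range-query optimizations.
import bisect


def _insort_example(
    sorted_entries: list[tuple[datetime, UUID]],
    timestamp: datetime,
    memory_id: UUID,
) -> None:
    # insort keeps the list ordered by the tuple's first element (timestamp),
    # falling back to UUID comparison on exact timestamp ties.
    bisect.insort(sorted_entries, (timestamp, memory_id))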


class EntityIndex(MemoryIndex[T]):
    """
    Entity-based index for lookups by entities mentioned in memories.

    Supports:
    - Single entity lookup
    - Multi-entity intersection
    - Entity type filtering
    """

    def __init__(self) -> None:
        """Initialize the entity index."""
        # Main storage
        self._entries: dict[UUID, EntityIndexEntry] = {}
        # Inverted index: entity -> set of memory IDs
        self._entity_to_memories: dict[str, set[UUID]] = {}
        # Memory to entities mapping
        self._memory_to_entities: dict[UUID, set[str]] = {}
        logger.info("Initialized EntityIndex")

    async def add(self, item: T) -> EntityIndexEntry:
        """
        Add an item to the entity index.

        Args:
            item: Memory item with entity information

        Returns:
            The created index entry
        """
        entities = self._extract_entities(item)

        # Create entry for the primary entity (or first one)
        primary_entity = entities[0] if entities else ("unknown", "unknown")

        entry = EntityIndexEntry(
            memory_id=item.id,
            memory_type=self._get_memory_type(item),
            entity_type=primary_entity[0],
            entity_value=primary_entity[1],
        )

        self._entries[item.id] = entry

        # Update inverted indices
        entity_keys = {f"{etype}:{evalue}" for etype, evalue in entities}
        self._memory_to_entities[item.id] = entity_keys

        for entity_key in entity_keys:
            if entity_key not in self._entity_to_memories:
                self._entity_to_memories[entity_key] = set()
            self._entity_to_memories[entity_key].add(item.id)

        logger.debug(f"Added {item.id} to entity index with {len(entities)} entities")
        return entry

    async def remove(self, memory_id: UUID) -> bool:
        """Remove an item from the entity index."""
        if memory_id not in self._entries:
            return False

        # Remove from inverted index
        if memory_id in self._memory_to_entities:
            for entity_key in self._memory_to_entities[memory_id]:
                if entity_key in self._entity_to_memories:
                    self._entity_to_memories[entity_key].discard(memory_id)
                    if not self._entity_to_memories[entity_key]:
                        del self._entity_to_memories[entity_key]
            del self._memory_to_entities[memory_id]

        del self._entries[memory_id]
        logger.debug(f"Removed {memory_id} from entity index")
        return True

    async def search(  # type: ignore[override]
        self,
        query: Any,
        limit: int = 10,
        entity_type: str | None = None,
        entity_value: str | None = None,
        entities: list[tuple[str, str]] | None = None,
        match_all: bool = False,
        **kwargs: Any,
    ) -> list[EntityIndexEntry]:
        """
        Search for items by entity.

        Args:
            query: Entity value to search (if entity_type not specified)
            limit: Maximum results to return
            entity_type: Type of entity to filter
            entity_value: Specific entity value
            entities: List of (type, value) tuples to match
            match_all: If True, require all entities to match
            **kwargs: Additional filter parameters

        Returns:
            List of matching entries
        """
        matching_ids: set[UUID] | None = None

        # Handle single entity query
        if entity_type and entity_value:
            entities = [(entity_type, entity_value)]
        elif entity_value is None and isinstance(query, str):
            # Search across all entity types
            entity_value = query

        if entities:
            for etype, evalue in entities:
                entity_key = f"{etype}:{evalue}"
                if entity_key in self._entity_to_memories:
                    ids = self._entity_to_memories[entity_key]
                    if matching_ids is None:
                        matching_ids = ids.copy()
                    elif match_all:
                        matching_ids &= ids
                    else:
                        matching_ids |= ids
                elif match_all:
                    # Required entity not found
                    matching_ids = set()
                    break
        elif entity_value:
            # Search for value across all types
            matching_ids = set()
            for entity_key, ids in self._entity_to_memories.items():
                if entity_value.lower() in entity_key.lower():
                    matching_ids |= ids

        if matching_ids is None:
            matching_ids = set(self._entries.keys())

        # Apply memory type filter if provided
        memory_type = kwargs.get("memory_type")
        results = []
        for mid in matching_ids:
            if mid in self._entries:
                entry = self._entries[mid]
                if memory_type and entry.memory_type != memory_type:
                    continue
                results.append(entry)

        logger.debug(f"Entity search returned {min(len(results), limit)} results")
        return results[:limit]

    async def clear(self) -> int:
        """Clear all entries from the index."""
        count = len(self._entries)
        self._entries.clear()
        self._entity_to_memories.clear()
        self._memory_to_entities.clear()
        logger.info(f"Cleared {count} entries from entity index")
        return count

    async def count(self) -> int:
        """Get the number of entries in the index."""
        return len(self._entries)

    async def get_entities(self, memory_id: UUID) -> list[tuple[str, str]]:
        """Get all entities for a memory item."""
        if memory_id not in self._memory_to_entities:
            return []

        entities = []
        for entity_key in self._memory_to_entities[memory_id]:
            if ":" in entity_key:
                etype, evalue = entity_key.split(":", 1)
                entities.append((etype, evalue))
        return entities

    def _extract_entities(self, item: T) -> list[tuple[str, str]]:
        """Extract entities from a memory item."""
        entities: list[tuple[str, str]] = []

        if isinstance(item, Episode):
            # Extract from task type and context
            entities.append(("task_type", item.task_type))
            if item.project_id:
                entities.append(("project", str(item.project_id)))
            if item.agent_instance_id:
                entities.append(("agent_instance", str(item.agent_instance_id)))
            if item.agent_type_id:
                entities.append(("agent_type", str(item.agent_type_id)))

        elif isinstance(item, Fact):
            # Subject and object are entities
            entities.append(("subject", item.subject))
            entities.append(("object", item.object))
            if item.project_id:
                entities.append(("project", str(item.project_id)))

        elif isinstance(item, Procedure):
            entities.append(("procedure", item.name))
            if item.project_id:
                entities.append(("project", str(item.project_id)))
            if item.agent_type_id:
                entities.append(("agent_type", str(item.agent_type_id)))

        return entities

    def _get_memory_type(self, item: T) -> MemoryType:
        """Get the memory type for an item."""
        if isinstance(item, Episode):
            return MemoryType.EPISODIC
        elif isinstance(item, Fact):
            return MemoryType.SEMANTIC
        elif isinstance(item, Procedure):
            return MemoryType.PROCEDURAL
        return MemoryType.WORKING
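

# Usage sketch (illustrative only, not part of the committed module): the
# inverted index turns multi-entity AND-queries into set intersections. The
# entity names below are hypothetical placeholders.
async def _demo_entity_intersection(
    entity_index: EntityIndex[Any],
) -> list[EntityIndexEntry]:
    # With match_all=True, only memories tagged with BOTH entities survive
    # the intersection of the two posting sets.
    return await entity_index.search(
        query=None,
        entities=[("subject", "auth-service"), ("project", "platform")],
        match_all=True,
        limit=5,
    )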


class OutcomeIndex(MemoryIndex[T]):
    """
    Outcome-based index for filtering by success/failure.

    Primarily used for episodes and procedures.
    """

    def __init__(self) -> None:
        """Initialize the outcome index."""
        self._entries: dict[UUID, OutcomeIndexEntry] = {}
        # Inverted index by outcome
        self._outcome_to_memories: dict[Outcome, set[UUID]] = {
            Outcome.SUCCESS: set(),
            Outcome.FAILURE: set(),
            Outcome.PARTIAL: set(),
        }
        logger.info("Initialized OutcomeIndex")

    async def add(self, item: T) -> OutcomeIndexEntry:
        """
        Add an item to the outcome index.

        Args:
            item: Memory item with outcome information

        Returns:
            The created index entry
        """
        outcome = self._get_outcome(item)

        entry = OutcomeIndexEntry(
            memory_id=item.id,
            memory_type=self._get_memory_type(item),
            outcome=outcome,
        )

        self._entries[item.id] = entry
        self._outcome_to_memories[outcome].add(item.id)

        logger.debug(f"Added {item.id} to outcome index with {outcome.value}")
        return entry

    async def remove(self, memory_id: UUID) -> bool:
        """Remove an item from the outcome index."""
        if memory_id not in self._entries:
            return False

        entry = self._entries.pop(memory_id)
        self._outcome_to_memories[entry.outcome].discard(memory_id)

        logger.debug(f"Removed {memory_id} from outcome index")
        return True

    async def search(  # type: ignore[override]
        self,
        query: Any,
        limit: int = 10,
        outcome: Outcome | None = None,
        outcomes: list[Outcome] | None = None,
        **kwargs: Any,
    ) -> list[OutcomeIndexEntry]:
        """
        Search for items by outcome.

        Args:
            query: Ignored for outcome search
            limit: Maximum results to return
            outcome: Single outcome to filter
            outcomes: Multiple outcomes to filter (OR)
            **kwargs: Additional filter parameters

        Returns:
            List of matching entries
        """
        if outcome:
            outcomes = [outcome]

        if outcomes:
            matching_ids: set[UUID] = set()
            for o in outcomes:
                matching_ids |= self._outcome_to_memories.get(o, set())
        else:
            matching_ids = set(self._entries.keys())

        # Apply memory type filter if provided
        memory_type = kwargs.get("memory_type")
        results = []
        for mid in matching_ids:
            if mid in self._entries:
                entry = self._entries[mid]
                if memory_type and entry.memory_type != memory_type:
                    continue
                results.append(entry)

        logger.debug(f"Outcome search returned {min(len(results), limit)} results")
        return results[:limit]

    async def clear(self) -> int:
        """Clear all entries from the index."""
        count = len(self._entries)
        self._entries.clear()
        for outcome in self._outcome_to_memories:
            self._outcome_to_memories[outcome].clear()
        logger.info(f"Cleared {count} entries from outcome index")
        return count

    async def count(self) -> int:
        """Get the number of entries in the index."""
        return len(self._entries)

    async def get_outcome_stats(self) -> dict[Outcome, int]:
        """Get statistics on outcomes."""
        return {outcome: len(ids) for outcome, ids in self._outcome_to_memories.items()}

    def _get_outcome(self, item: T) -> Outcome:
        """Get the outcome for an item."""
        if isinstance(item, Episode):
            return item.outcome
        elif isinstance(item, Procedure):
            # Derive from success rate
            if item.success_rate >= 0.8:
                return Outcome.SUCCESS
            elif item.success_rate <= 0.2:
                return Outcome.FAILURE
            return Outcome.PARTIAL
        return Outcome.SUCCESS

    def _get_memory_type(self, item: T) -> MemoryType:
        """Get the memory type for an item."""
        if isinstance(item, Episode):
            return MemoryType.EPISODIC
        elif isinstance(item, Fact):
            return MemoryType.SEMANTIC
        elif isinstance(item, Procedure):
            return MemoryType.PROCEDURAL
        return MemoryType.WORKING


@dataclass
class MemoryIndexer:
    """
    Unified indexer that manages all index types.

    Provides a single interface for indexing and searching across
    multiple index types.
    """

    vector_index: VectorIndex[Any] = field(default_factory=VectorIndex)
    temporal_index: TemporalIndex[Any] = field(default_factory=TemporalIndex)
    entity_index: EntityIndex[Any] = field(default_factory=EntityIndex)
    outcome_index: OutcomeIndex[Any] = field(default_factory=OutcomeIndex)

    async def index(self, item: Episode | Fact | Procedure) -> dict[str, IndexEntry]:
        """
        Index an item across all applicable indices.

        Args:
            item: Memory item to index

        Returns:
            Dictionary of index type to entry
        """
        results: dict[str, IndexEntry] = {}

        # Vector index (if embedding present)
        if getattr(item, "embedding", None):
            results["vector"] = await self.vector_index.add(item)

        # Temporal index
        results["temporal"] = await self.temporal_index.add(item)

        # Entity index
        results["entity"] = await self.entity_index.add(item)

        # Outcome index (for episodes and procedures)
        if isinstance(item, (Episode, Procedure)):
            results["outcome"] = await self.outcome_index.add(item)

        logger.info(
            f"Indexed {item.id} across {len(results)} indices: {list(results.keys())}"
        )
        return results

    async def remove(self, memory_id: UUID) -> dict[str, bool]:
        """
        Remove an item from all indices.

        Args:
            memory_id: ID of the memory to remove

        Returns:
            Dictionary of index type to removal success
        """
        results = {
            "vector": await self.vector_index.remove(memory_id),
            "temporal": await self.temporal_index.remove(memory_id),
            "entity": await self.entity_index.remove(memory_id),
            "outcome": await self.outcome_index.remove(memory_id),
        }

        removed_from = [k for k, v in results.items() if v]
        if removed_from:
            logger.info(f"Removed {memory_id} from indices: {removed_from}")

        return results

    async def clear_all(self) -> dict[str, int]:
        """
        Clear all indices.

        Returns:
            Dictionary of index type to count cleared
        """
        return {
            "vector": await self.vector_index.clear(),
            "temporal": await self.temporal_index.clear(),
            "entity": await self.entity_index.clear(),
            "outcome": await self.outcome_index.clear(),
        }

    async def get_stats(self) -> dict[str, int]:
        """
        Get statistics for all indices.

        Returns:
            Dictionary of index type to entry count
        """
        return {
            "vector": await self.vector_index.count(),
            "temporal": await self.temporal_index.count(),
            "entity": await self.entity_index.count(),
            "outcome": await self.outcome_index.count(),
        }
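

# Usage sketch (illustrative only, not part of the committed module): one
# call fans an item out to every applicable index. `episode` is assumed to
# be an Episode from app.services.memory.types; its construction is not
# shown in this diff.
async def _demo_indexing(indexer: MemoryIndexer, episode: Episode) -> None:
    entries = await indexer.index(episode)
    # An episode with an embedding lands in all four indices.
    assert set(entries) <= {"vector", "temporal", "entity", "outcome"}
    stats = await indexer.get_stats()      # per-index entry counts
    await indexer.remove(episode.id)       # drop it from every index
    assert (await indexer.get_stats())["temporal"] == stats["temporal"] - 1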


# Singleton indexer instance
_indexer: MemoryIndexer | None = None


def get_memory_indexer() -> MemoryIndexer:
    """Get the singleton memory indexer instance."""
    global _indexer
    if _indexer is None:
        _indexer = MemoryIndexer()
    return _indexer
750 backend/app/services/memory/indexing/retrieval.py (new file)
@@ -0,0 +1,750 @@
# app/services/memory/indexing/retrieval.py
"""
Memory Retrieval Engine.

Provides hybrid retrieval capabilities combining:
- Vector similarity search
- Temporal filtering
- Entity filtering
- Outcome filtering
- Relevance scoring
- Result caching
"""

import hashlib
import logging
from dataclasses import dataclass, field
from datetime import UTC, datetime
from typing import Any, TypeVar
from uuid import UUID

from app.services.memory.types import (
    Episode,
    Fact,
    MemoryType,
    Outcome,
    Procedure,
    RetrievalResult,
)

from .index import (
    MemoryIndexer,
    get_memory_indexer,
)

logger = logging.getLogger(__name__)

T = TypeVar("T", Episode, Fact, Procedure)


def _utcnow() -> datetime:
    """Get current UTC time as timezone-aware datetime."""
    return datetime.now(UTC)


@dataclass
class RetrievalQuery:
    """Query parameters for memory retrieval."""

    # Text/semantic query
    query_text: str | None = None
    query_embedding: list[float] | None = None

    # Temporal filters
    start_time: datetime | None = None
    end_time: datetime | None = None
    recent_seconds: float | None = None

    # Entity filters
    entities: list[tuple[str, str]] | None = None
    entity_match_all: bool = False

    # Outcome filters
    outcomes: list[Outcome] | None = None

    # Memory type filter
    memory_types: list[MemoryType] | None = None

    # Result options
    limit: int = 10
    min_relevance: float = 0.0

    # Retrieval mode
    use_vector: bool = True
    use_temporal: bool = True
    use_entity: bool = True
    use_outcome: bool = True

    def to_cache_key(self) -> str:
        """Generate a cache key for this query."""
        key_parts = [
            self.query_text or "",
            str(self.start_time),
            str(self.end_time),
            str(self.recent_seconds),
            str(self.entities),
            str(self.outcomes),
            str(self.memory_types),
            str(self.limit),
            str(self.min_relevance),
        ]
        key_string = "|".join(key_parts)
        return hashlib.sha256(key_string.encode()).hexdigest()[:32]
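

# Usage sketch (illustrative only, not part of the committed module): cache
# keys are a stable 32-hex-character SHA-256 prefix over the filter fields.
# Note that query_embedding and the use_* flags do not appear in key_parts,
# so callers are implicitly assumed to pair a given query_text with a single
# embedding. The query text below is a placeholder.
def _demo_cache_key() -> str:
    q = RetrievalQuery(query_text="deploy failures", limit=5)
    # Identical filter fields always hash to the same key.
    assert q.to_cache_key() == RetrievalQuery(query_text="deploy failures", limit=5).to_cache_key()
    return q.to_cache_key()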


@dataclass
class ScoredResult:
    """A retrieval result with relevance score."""

    memory_id: UUID
    memory_type: MemoryType
    relevance_score: float
    score_breakdown: dict[str, float] = field(default_factory=dict)
    metadata: dict[str, Any] = field(default_factory=dict)


@dataclass
class CacheEntry:
    """A cached retrieval result."""

    results: list[ScoredResult]
    created_at: datetime
    ttl_seconds: float
    query_key: str

    def is_expired(self) -> bool:
        """Check if this cache entry has expired."""
        age = (_utcnow() - self.created_at).total_seconds()
        return age > self.ttl_seconds


class RelevanceScorer:
    """
    Calculates relevance scores for retrieved memories.

    Combines multiple signals:
    - Vector similarity (if available)
    - Temporal recency
    - Entity match count
    - Outcome preference
    - Importance/confidence
    """

    def __init__(
        self,
        vector_weight: float = 0.4,
        recency_weight: float = 0.2,
        entity_weight: float = 0.2,
        outcome_weight: float = 0.1,
        importance_weight: float = 0.1,
    ) -> None:
        """
        Initialize the relevance scorer.

        Args:
            vector_weight: Weight for vector similarity (0-1)
            recency_weight: Weight for temporal recency (0-1)
            entity_weight: Weight for entity matches (0-1)
            outcome_weight: Weight for outcome preference (0-1)
            importance_weight: Weight for importance score (0-1)
        """
        total = (
            vector_weight
            + recency_weight
            + entity_weight
            + outcome_weight
            + importance_weight
        )
        # Normalize weights
        self.vector_weight = vector_weight / total
        self.recency_weight = recency_weight / total
        self.entity_weight = entity_weight / total
        self.outcome_weight = outcome_weight / total
        self.importance_weight = importance_weight / total

    def score(
        self,
        memory_id: UUID,
        memory_type: MemoryType,
        vector_similarity: float | None = None,
        timestamp: datetime | None = None,
        entity_match_count: int = 0,
        entity_total: int = 1,
        outcome: Outcome | None = None,
        importance: float = 0.5,
        preferred_outcomes: list[Outcome] | None = None,
    ) -> ScoredResult:
        """
        Calculate a relevance score for a memory.

        Args:
            memory_id: ID of the memory
            memory_type: Type of memory
            vector_similarity: Similarity score from vector search (0-1)
            timestamp: Timestamp of the memory
            entity_match_count: Number of matching entities
            entity_total: Total entities in query
            outcome: Outcome of the memory
            importance: Importance score of the memory (0-1)
            preferred_outcomes: Outcomes to prefer

        Returns:
            Scored result with breakdown
        """
        breakdown: dict[str, float] = {}

        # Vector similarity score
        if vector_similarity is not None:
            breakdown["vector"] = vector_similarity
        else:
            breakdown["vector"] = 0.5  # Neutral if no vector

        # Recency score (exponential decay)
        if timestamp:
            age_hours = (_utcnow() - timestamp).total_seconds() / 3600
            # Decay with half-life of 24 hours
            breakdown["recency"] = 2 ** (-age_hours / 24)
        else:
            breakdown["recency"] = 0.5

        # Entity match score
        if entity_total > 0:
            breakdown["entity"] = entity_match_count / entity_total
        else:
            breakdown["entity"] = 1.0  # No entity filter = full score

        # Outcome score
        if preferred_outcomes and outcome:
            breakdown["outcome"] = 1.0 if outcome in preferred_outcomes else 0.0
        else:
            breakdown["outcome"] = 0.5  # Neutral if no preference

        # Importance score
        breakdown["importance"] = importance

        # Calculate weighted sum
        total_score = (
            breakdown["vector"] * self.vector_weight
            + breakdown["recency"] * self.recency_weight
            + breakdown["entity"] * self.entity_weight
            + breakdown["outcome"] * self.outcome_weight
            + breakdown["importance"] * self.importance_weight
        )

        return ScoredResult(
            memory_id=memory_id,
            memory_type=memory_type,
            relevance_score=total_score,
            score_breakdown=breakdown,
        )
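

# Worked example (illustrative only, not part of the committed module): with
# the default weights (0.4/0.2/0.2/0.1/0.1, already summing to 1.0), a memory
# with vector similarity 0.9, an age of one half-life (24 h, so recency =
# 2**-1 = 0.5), 1 of 2 entities matched, a preferred SUCCESS outcome, and
# default importance 0.5 scores
# 0.4*0.9 + 0.2*0.5 + 0.2*0.5 + 0.1*1.0 + 0.1*0.5, which is about 0.71.
from datetime import timedelta
from uuid import uuid4


def _demo_scoring() -> float:
    scorer = RelevanceScorer()
    result = scorer.score(
        memory_id=uuid4(),
        memory_type=MemoryType.EPISODIC,
        vector_similarity=0.9,
        timestamp=_utcnow() - timedelta(hours=24),
        entity_match_count=1,
        entity_total=2,
        outcome=Outcome.SUCCESS,
        preferred_outcomes=[Outcome.SUCCESS],
    )
    return result.relevance_score  # ~0.71; result.score_breakdown itemizes it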


class RetrievalCache:
    """
    In-memory cache for retrieval results.

    Supports TTL-based expiration and LRU eviction.
    """

    def __init__(
        self,
        max_entries: int = 1000,
        default_ttl_seconds: float = 300,
    ) -> None:
        """
        Initialize the cache.

        Args:
            max_entries: Maximum cache entries
            default_ttl_seconds: Default TTL for entries
        """
        self._cache: dict[str, CacheEntry] = {}
        self._max_entries = max_entries
        self._default_ttl = default_ttl_seconds
        self._access_order: list[str] = []
        logger.info(
            f"Initialized RetrievalCache with max_entries={max_entries}, "
            f"ttl={default_ttl_seconds}s"
        )

    def get(self, query_key: str) -> list[ScoredResult] | None:
        """
        Get cached results for a query.

        Args:
            query_key: Cache key for the query

        Returns:
            Cached results or None if not found/expired
        """
        if query_key not in self._cache:
            return None

        entry = self._cache[query_key]
        if entry.is_expired():
            del self._cache[query_key]
            if query_key in self._access_order:
                self._access_order.remove(query_key)
            return None

        # Update access order (LRU)
        if query_key in self._access_order:
            self._access_order.remove(query_key)
        self._access_order.append(query_key)

        logger.debug(f"Cache hit for {query_key}")
        return entry.results

    def put(
        self,
        query_key: str,
        results: list[ScoredResult],
        ttl_seconds: float | None = None,
    ) -> None:
        """
        Cache results for a query.

        Args:
            query_key: Cache key for the query
            results: Results to cache
            ttl_seconds: TTL for this entry (or default)
        """
        # Evict if at capacity
        while len(self._cache) >= self._max_entries and self._access_order:
            oldest_key = self._access_order.pop(0)
            if oldest_key in self._cache:
                del self._cache[oldest_key]

        entry = CacheEntry(
            results=results,
            created_at=_utcnow(),
            ttl_seconds=ttl_seconds or self._default_ttl,
            query_key=query_key,
        )

        self._cache[query_key] = entry
        self._access_order.append(query_key)
        logger.debug(f"Cached {len(results)} results for {query_key}")

    def invalidate(self, query_key: str) -> bool:
        """
        Invalidate a specific cache entry.

        Args:
            query_key: Cache key to invalidate

        Returns:
            True if entry was found and removed
        """
        if query_key in self._cache:
            del self._cache[query_key]
            if query_key in self._access_order:
                self._access_order.remove(query_key)
            return True
        return False

    def invalidate_by_memory(self, memory_id: UUID) -> int:
        """
        Invalidate all cache entries containing a specific memory.

        Args:
            memory_id: Memory ID to invalidate

        Returns:
            Number of entries invalidated
        """
        keys_to_remove = []
        for key, entry in self._cache.items():
            if any(r.memory_id == memory_id for r in entry.results):
                keys_to_remove.append(key)

        for key in keys_to_remove:
            self.invalidate(key)

        if keys_to_remove:
            logger.debug(
                f"Invalidated {len(keys_to_remove)} cache entries for {memory_id}"
            )
        return len(keys_to_remove)

    def clear(self) -> int:
        """
        Clear all cache entries.

        Returns:
            Number of entries cleared
        """
        count = len(self._cache)
        self._cache.clear()
        self._access_order.clear()
        logger.info(f"Cleared {count} cache entries")
        return count

    def get_stats(self) -> dict[str, Any]:
        """Get cache statistics."""
        expired_count = sum(1 for e in self._cache.values() if e.is_expired())
        return {
            "total_entries": len(self._cache),
            "expired_entries": expired_count,
            "max_entries": self._max_entries,
            "default_ttl_seconds": self._default_ttl,
        }
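

# Usage sketch (illustrative only, not part of the committed module): TTL
# plus LRU in action. The keys and empty result lists are placeholders.
def _demo_cache_eviction() -> None:
    cache = RetrievalCache(max_entries=2, default_ttl_seconds=60)
    cache.put("key-a", [])
    cache.put("key-b", [])
    cache.get("key-a")      # refreshes key-a's position in the LRU order
    cache.put("key-c", [])  # at capacity: evicts key-b, the least recently used
    assert cache.get("key-b") is None
    assert cache.get("key-a") is not None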


class RetrievalEngine:
    """
    Hybrid retrieval engine for memory search.

    Combines multiple index types for comprehensive retrieval:
    - Vector search for semantic similarity
    - Temporal index for time-based filtering
    - Entity index for entity-based lookups
    - Outcome index for success/failure filtering

    Results are scored and ranked using relevance scoring.
    """

    def __init__(
        self,
        indexer: MemoryIndexer | None = None,
        scorer: RelevanceScorer | None = None,
        cache: RetrievalCache | None = None,
        enable_cache: bool = True,
    ) -> None:
        """
        Initialize the retrieval engine.

        Args:
            indexer: Memory indexer (defaults to singleton)
            scorer: Relevance scorer (defaults to new instance)
            cache: Retrieval cache (defaults to new instance)
            enable_cache: Whether to enable result caching
        """
        self._indexer = indexer or get_memory_indexer()
        self._scorer = scorer or RelevanceScorer()
        self._cache = cache or RetrievalCache() if enable_cache else None
        self._enable_cache = enable_cache
        logger.info(f"Initialized RetrievalEngine with cache={enable_cache}")

    async def retrieve(
        self,
        query: RetrievalQuery,
        use_cache: bool = True,
    ) -> RetrievalResult[ScoredResult]:
        """
        Retrieve relevant memories using hybrid search.

        Args:
            query: Retrieval query parameters
            use_cache: Whether to use cached results

        Returns:
            Retrieval result with scored items
        """
        start_time = _utcnow()

        # Check cache
        cache_key = query.to_cache_key()
        if use_cache and self._cache:
            cached = self._cache.get(cache_key)
            if cached:
                latency = (_utcnow() - start_time).total_seconds() * 1000
                return RetrievalResult(
                    items=cached,
                    total_count=len(cached),
                    query=query.query_text or "",
                    retrieval_type="cached",
                    latency_ms=latency,
                    metadata={"cache_hit": True},
                )

        # Collect candidates from each index
        candidates: dict[UUID, dict[str, Any]] = {}

        # Vector search
        if query.use_vector and query.query_embedding:
            vector_results = await self._indexer.vector_index.search(
                query=query.query_embedding,
                limit=query.limit * 3,  # Get more for filtering
                min_similarity=query.min_relevance,
                memory_type=query.memory_types[0] if query.memory_types else None,
            )
            for entry in vector_results:
                if entry.memory_id not in candidates:
                    candidates[entry.memory_id] = {
                        "memory_type": entry.memory_type,
                        "sources": [],
                    }
                candidates[entry.memory_id]["vector_similarity"] = entry.metadata.get(
                    "similarity", 0.5
                )
                candidates[entry.memory_id]["sources"].append("vector")

        # Temporal search
        if query.use_temporal and (
            query.start_time or query.end_time or query.recent_seconds
        ):
            temporal_results = await self._indexer.temporal_index.search(
                query=None,
                limit=query.limit * 3,
                start_time=query.start_time,
                end_time=query.end_time,
                recent_seconds=query.recent_seconds,
                memory_type=query.memory_types[0] if query.memory_types else None,
            )
            for temporal_entry in temporal_results:
                if temporal_entry.memory_id not in candidates:
                    candidates[temporal_entry.memory_id] = {
                        "memory_type": temporal_entry.memory_type,
                        "sources": [],
                    }
                candidates[temporal_entry.memory_id]["timestamp"] = (
                    temporal_entry.timestamp
                )
                candidates[temporal_entry.memory_id]["sources"].append("temporal")

        # Entity search
        if query.use_entity and query.entities:
            entity_results = await self._indexer.entity_index.search(
                query=None,
                limit=query.limit * 3,
                entities=query.entities,
                match_all=query.entity_match_all,
                memory_type=query.memory_types[0] if query.memory_types else None,
            )
            for entity_entry in entity_results:
                if entity_entry.memory_id not in candidates:
                    candidates[entity_entry.memory_id] = {
                        "memory_type": entity_entry.memory_type,
                        "sources": [],
                    }
                # Count entity matches
                entity_count = candidates[entity_entry.memory_id].get(
                    "entity_match_count", 0
                )
                candidates[entity_entry.memory_id]["entity_match_count"] = (
                    entity_count + 1
                )
                candidates[entity_entry.memory_id]["sources"].append("entity")

        # Outcome search
        if query.use_outcome and query.outcomes:
            outcome_results = await self._indexer.outcome_index.search(
                query=None,
                limit=query.limit * 3,
                outcomes=query.outcomes,
                memory_type=query.memory_types[0] if query.memory_types else None,
            )
            for outcome_entry in outcome_results:
                if outcome_entry.memory_id not in candidates:
                    candidates[outcome_entry.memory_id] = {
                        "memory_type": outcome_entry.memory_type,
                        "sources": [],
                    }
                candidates[outcome_entry.memory_id]["outcome"] = outcome_entry.outcome
                candidates[outcome_entry.memory_id]["sources"].append("outcome")

        # Score and rank candidates
        scored_results: list[ScoredResult] = []
        entity_total = len(query.entities) if query.entities else 1

        for memory_id, data in candidates.items():
            scored = self._scorer.score(
                memory_id=memory_id,
                memory_type=data["memory_type"],
                vector_similarity=data.get("vector_similarity"),
                timestamp=data.get("timestamp"),
                entity_match_count=data.get("entity_match_count", 0),
                entity_total=entity_total,
                outcome=data.get("outcome"),
                preferred_outcomes=query.outcomes,
            )
            scored.metadata["sources"] = data.get("sources", [])

            # Filter by minimum relevance
            if scored.relevance_score >= query.min_relevance:
                scored_results.append(scored)

        # Sort by relevance score
        scored_results.sort(key=lambda x: x.relevance_score, reverse=True)

        # Apply limit
        final_results = scored_results[: query.limit]

        # Cache results
        if use_cache and self._cache and final_results:
            self._cache.put(cache_key, final_results)

        latency = (_utcnow() - start_time).total_seconds() * 1000

        logger.info(
            f"Retrieved {len(final_results)} results from {len(candidates)} candidates "
            f"in {latency:.2f}ms"
        )

        return RetrievalResult(
            items=final_results,
            total_count=len(candidates),
            query=query.query_text or "",
            retrieval_type="hybrid",
            latency_ms=latency,
            metadata={
                "cache_hit": False,
                "candidates_count": len(candidates),
                "filtered_count": len(scored_results),
            },
        )

    async def retrieve_similar(
        self,
        embedding: list[float],
        limit: int = 10,
        min_similarity: float = 0.5,
        memory_types: list[MemoryType] | None = None,
    ) -> RetrievalResult[ScoredResult]:
        """
        Retrieve memories similar to a given embedding.

        Args:
            embedding: Query embedding
            limit: Maximum results
            min_similarity: Minimum similarity threshold
            memory_types: Filter by memory types

        Returns:
            Retrieval result with scored items
        """
        query = RetrievalQuery(
            query_embedding=embedding,
            limit=limit,
            min_relevance=min_similarity,
            memory_types=memory_types,
            use_temporal=False,
            use_entity=False,
            use_outcome=False,
        )
        return await self.retrieve(query)

    async def retrieve_recent(
        self,
        hours: float = 24,
        limit: int = 10,
        memory_types: list[MemoryType] | None = None,
    ) -> RetrievalResult[ScoredResult]:
        """
        Retrieve recent memories.

        Args:
            hours: Number of hours to look back
            limit: Maximum results
            memory_types: Filter by memory types

        Returns:
            Retrieval result with scored items
        """
        query = RetrievalQuery(
            recent_seconds=hours * 3600,
            limit=limit,
            memory_types=memory_types,
            use_vector=False,
            use_entity=False,
            use_outcome=False,
        )
        return await self.retrieve(query)

    async def retrieve_by_entity(
        self,
        entity_type: str,
        entity_value: str,
        limit: int = 10,
        memory_types: list[MemoryType] | None = None,
    ) -> RetrievalResult[ScoredResult]:
        """
        Retrieve memories by entity.

        Args:
            entity_type: Type of entity
            entity_value: Entity value
            limit: Maximum results
            memory_types: Filter by memory types

        Returns:
            Retrieval result with scored items
        """
        query = RetrievalQuery(
            entities=[(entity_type, entity_value)],
            limit=limit,
            memory_types=memory_types,
            use_vector=False,
            use_temporal=False,
            use_outcome=False,
        )
        return await self.retrieve(query)

    async def retrieve_successful(
        self,
        limit: int = 10,
        memory_types: list[MemoryType] | None = None,
    ) -> RetrievalResult[ScoredResult]:
        """
        Retrieve successful memories.

        Args:
            limit: Maximum results
            memory_types: Filter by memory types

        Returns:
            Retrieval result with scored items
        """
        query = RetrievalQuery(
            outcomes=[Outcome.SUCCESS],
            limit=limit,
            memory_types=memory_types,
            use_vector=False,
            use_temporal=False,
            use_entity=False,
        )
        return await self.retrieve(query)

    def invalidate_cache(self) -> int:
        """
        Invalidate all cached results.

        Returns:
            Number of entries invalidated
        """
        if self._cache:
            return self._cache.clear()
        return 0

    def invalidate_cache_for_memory(self, memory_id: UUID) -> int:
        """
        Invalidate cache entries containing a specific memory.

        Args:
            memory_id: Memory ID to invalidate

        Returns:
            Number of entries invalidated
        """
        if self._cache:
            return self._cache.invalidate_by_memory(memory_id)
        return 0

    def get_cache_stats(self) -> dict[str, Any]:
        """Get cache statistics."""
        if self._cache:
            return self._cache.get_stats()
        return {"enabled": False}
# Singleton retrieval engine instance
|
||||||
|
_engine: RetrievalEngine | None = None
|
||||||
|
|
||||||
|
|
||||||
|
def get_retrieval_engine() -> RetrievalEngine:
|
||||||
|
"""Get the singleton retrieval engine instance."""
|
||||||
|
global _engine
|
||||||
|
if _engine is None:
|
||||||
|
_engine = RetrievalEngine()
|
||||||
|
return _engine
|
||||||
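
The convenience methods above are thin wrappers that disable the other indices and delegate to retrieve(), and get_retrieval_engine() exposes a process-wide singleton. A minimal usage sketch of this API, wired the same way as the test fixtures further down; the four-dimensional embedding and the entity values are placeholder assumptions, not real model output:

# Usage sketch (illustrative; embedding and entity values are placeholders).
import asyncio

from app.services.memory.indexing.index import MemoryIndexer
from app.services.memory.indexing.retrieval import RetrievalEngine


async def main() -> None:
    indexer = MemoryIndexer()
    engine = RetrievalEngine(indexer=indexer, enable_cache=True)

    # Memories are assumed to have been indexed via `await indexer.index(...)`.
    similar = await engine.retrieve_similar(embedding=[1.0, 0.0, 0.0, 0.0], limit=5)
    recent = await engine.retrieve_recent(hours=24)
    deploys = await engine.retrieve_by_entity("task_type", "deploy")
    wins = await engine.retrieve_successful(limit=5)

    for result in (similar, recent, deploys, wins):
        print(result.retrieval_type, len(result.items), f"{result.latency_ms:.2f} ms")


asyncio.run(main())
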
2 backend/tests/unit/services/memory/indexing/__init__.py Normal file
@@ -0,0 +1,2 @@
# tests/unit/services/memory/indexing/__init__.py
"""Unit tests for memory indexing."""
497 backend/tests/unit/services/memory/indexing/test_index.py Normal file
@@ -0,0 +1,497 @@
# tests/unit/services/memory/indexing/test_index.py
"""Unit tests for memory indexing."""

from datetime import UTC, datetime, timedelta
from uuid import uuid4

import pytest

from app.services.memory.indexing.index import (
    EntityIndex,
    MemoryIndexer,
    OutcomeIndex,
    TemporalIndex,
    VectorIndex,
    get_memory_indexer,
)
from app.services.memory.types import Episode, Fact, MemoryType, Outcome, Procedure


def _utcnow() -> datetime:
    """Get current UTC time."""
    return datetime.now(UTC)


def make_episode(
    embedding: list[float] | None = None,
    outcome: Outcome = Outcome.SUCCESS,
    occurred_at: datetime | None = None,
) -> Episode:
    """Create a test episode."""
    return Episode(
        id=uuid4(),
        project_id=uuid4(),
        agent_instance_id=uuid4(),
        agent_type_id=uuid4(),
        session_id="test-session",
        task_type="test_task",
        task_description="Test task description",
        actions=[{"action": "test"}],
        context_summary="Test context",
        outcome=outcome,
        outcome_details="Test outcome",
        duration_seconds=10.0,
        tokens_used=100,
        lessons_learned=["lesson1"],
        importance_score=0.8,
        embedding=embedding,
        occurred_at=occurred_at or _utcnow(),
        created_at=_utcnow(),
        updated_at=_utcnow(),
    )


def make_fact(
    embedding: list[float] | None = None,
    subject: str = "test_subject",
    predicate: str = "has_property",
    obj: str = "test_value",
) -> Fact:
    """Create a test fact."""
    return Fact(
        id=uuid4(),
        project_id=uuid4(),
        subject=subject,
        predicate=predicate,
        object=obj,
        confidence=0.9,
        source_episode_ids=[uuid4()],
        first_learned=_utcnow(),
        last_reinforced=_utcnow(),
        reinforcement_count=1,
        embedding=embedding,
        created_at=_utcnow(),
        updated_at=_utcnow(),
    )


def make_procedure(
    embedding: list[float] | None = None,
    success_count: int = 8,
    failure_count: int = 2,
) -> Procedure:
    """Create a test procedure."""
    return Procedure(
        id=uuid4(),
        project_id=uuid4(),
        agent_type_id=uuid4(),
        name="test_procedure",
        trigger_pattern="test.*",
        steps=[{"step": 1, "action": "test"}],
        success_count=success_count,
        failure_count=failure_count,
        last_used=_utcnow(),
        embedding=embedding,
        created_at=_utcnow(),
        updated_at=_utcnow(),
    )


class TestVectorIndex:
    """Tests for VectorIndex."""

    @pytest.fixture
    def index(self) -> VectorIndex[Episode]:
        """Create a vector index."""
        return VectorIndex[Episode](dimension=4)

    @pytest.mark.asyncio
    async def test_add_item(self, index: VectorIndex[Episode]) -> None:
        """Test adding an item to the index."""
        episode = make_episode(embedding=[1.0, 0.0, 0.0, 0.0])

        entry = await index.add(episode)

        assert entry.memory_id == episode.id
        assert entry.memory_type == MemoryType.EPISODIC
        assert entry.dimension == 4
        assert await index.count() == 1

    @pytest.mark.asyncio
    async def test_remove_item(self, index: VectorIndex[Episode]) -> None:
        """Test removing an item from the index."""
        episode = make_episode(embedding=[1.0, 0.0, 0.0, 0.0])
        await index.add(episode)

        result = await index.remove(episode.id)

        assert result is True
        assert await index.count() == 0

    @pytest.mark.asyncio
    async def test_remove_nonexistent(self, index: VectorIndex[Episode]) -> None:
        """Test removing a nonexistent item."""
        result = await index.remove(uuid4())
        assert result is False

    @pytest.mark.asyncio
    async def test_search_similar(self, index: VectorIndex[Episode]) -> None:
        """Test searching for similar items."""
        # Add items with different embeddings
        e1 = make_episode(embedding=[1.0, 0.0, 0.0, 0.0])
        e2 = make_episode(embedding=[0.9, 0.1, 0.0, 0.0])
        e3 = make_episode(embedding=[0.0, 1.0, 0.0, 0.0])

        await index.add(e1)
        await index.add(e2)
        await index.add(e3)

        # Search for similar to first
        results = await index.search([1.0, 0.0, 0.0, 0.0], limit=2)

        assert len(results) == 2
        # First result should be most similar
        assert results[0].memory_id == e1.id

    @pytest.mark.asyncio
    async def test_search_min_similarity(self, index: VectorIndex[Episode]) -> None:
        """Test minimum similarity threshold."""
        e1 = make_episode(embedding=[1.0, 0.0, 0.0, 0.0])
        e2 = make_episode(embedding=[0.0, 1.0, 0.0, 0.0])  # Orthogonal

        await index.add(e1)
        await index.add(e2)

        # Search with high threshold
        results = await index.search([1.0, 0.0, 0.0, 0.0], min_similarity=0.9)

        assert len(results) == 1
        assert results[0].memory_id == e1.id

    @pytest.mark.asyncio
    async def test_search_empty_query(self, index: VectorIndex[Episode]) -> None:
        """Test search with empty query."""
        e1 = make_episode(embedding=[1.0, 0.0, 0.0, 0.0])
        await index.add(e1)

        results = await index.search([], limit=10)
        assert len(results) == 0

    @pytest.mark.asyncio
    async def test_clear(self, index: VectorIndex[Episode]) -> None:
        """Test clearing the index."""
        await index.add(make_episode(embedding=[1.0, 0.0, 0.0, 0.0]))
        await index.add(make_episode(embedding=[0.0, 1.0, 0.0, 0.0]))

        count = await index.clear()

        assert count == 2
        assert await index.count() == 0

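The ordering and threshold assertions in TestVectorIndex above follow directly from cosine similarity, assuming that is the metric VectorIndex computes. A standalone sketch of the arithmetic, independent of the project code:

# Cosine similarity by hand, matching the test vectors above.
import math


def cosine(a: list[float], b: list[float]) -> float:
    """Cosine similarity of two equal-length vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    return dot / norm if norm else 0.0


q = [1.0, 0.0, 0.0, 0.0]
print(cosine(q, [1.0, 0.0, 0.0, 0.0]))  # 1.0    -> ranked first
print(cosine(q, [0.9, 0.1, 0.0, 0.0]))  # ~0.994 -> ranked second
print(cosine(q, [0.0, 1.0, 0.0, 0.0]))  # 0.0    -> excluded by min_similarity=0.9
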
class TestTemporalIndex:
    """Tests for TemporalIndex."""

    @pytest.fixture
    def index(self) -> TemporalIndex[Episode]:
        """Create a temporal index."""
        return TemporalIndex[Episode]()

    @pytest.mark.asyncio
    async def test_add_item(self, index: TemporalIndex[Episode]) -> None:
        """Test adding an item."""
        episode = make_episode()
        entry = await index.add(episode)

        assert entry.memory_id == episode.id
        assert await index.count() == 1

    @pytest.mark.asyncio
    async def test_search_by_time_range(self, index: TemporalIndex[Episode]) -> None:
        """Test searching by time range."""
        now = _utcnow()
        old = make_episode(occurred_at=now - timedelta(hours=2))
        recent = make_episode(occurred_at=now - timedelta(hours=1))
        newest = make_episode(occurred_at=now)

        await index.add(old)
        await index.add(recent)
        await index.add(newest)

        # Search the last hour and a half
        results = await index.search(
            query=None,
            start_time=now - timedelta(hours=1, minutes=30),
            end_time=now,
        )

        assert len(results) == 2

    @pytest.mark.asyncio
    async def test_search_recent(self, index: TemporalIndex[Episode]) -> None:
        """Test searching for recent items."""
        now = _utcnow()
        old = make_episode(occurred_at=now - timedelta(hours=2))
        recent = make_episode(occurred_at=now - timedelta(minutes=30))

        await index.add(old)
        await index.add(recent)

        # Search last hour (3600 seconds)
        results = await index.search(query=None, recent_seconds=3600)

        assert len(results) == 1
        assert results[0].memory_id == recent.id

    @pytest.mark.asyncio
    async def test_search_order(self, index: TemporalIndex[Episode]) -> None:
        """Test result ordering."""
        now = _utcnow()
        e1 = make_episode(occurred_at=now - timedelta(hours=2))
        e2 = make_episode(occurred_at=now - timedelta(hours=1))
        e3 = make_episode(occurred_at=now)

        await index.add(e1)
        await index.add(e2)
        await index.add(e3)

        # Descending order (newest first)
        results_desc = await index.search(query=None, order="desc", limit=10)
        assert results_desc[0].memory_id == e3.id

        # Ascending order (oldest first)
        results_asc = await index.search(query=None, order="asc", limit=10)
        assert results_asc[0].memory_id == e1.id


class TestEntityIndex:
    """Tests for EntityIndex."""

    @pytest.fixture
    def index(self) -> EntityIndex[Fact]:
        """Create an entity index."""
        return EntityIndex[Fact]()

    @pytest.mark.asyncio
    async def test_add_item(self, index: EntityIndex[Fact]) -> None:
        """Test adding an item."""
        fact = make_fact(subject="user", obj="admin")
        entry = await index.add(fact)

        assert entry.memory_id == fact.id
        assert await index.count() == 1

    @pytest.mark.asyncio
    async def test_search_by_entity(self, index: EntityIndex[Fact]) -> None:
        """Test searching by entity."""
        f1 = make_fact(subject="user", obj="admin")
        f2 = make_fact(subject="system", obj="config")

        await index.add(f1)
        await index.add(f2)

        results = await index.search(
            query=None,
            entity_type="subject",
            entity_value="user",
        )

        assert len(results) == 1
        assert results[0].memory_id == f1.id

    @pytest.mark.asyncio
    async def test_search_multiple_entities(self, index: EntityIndex[Fact]) -> None:
        """Test searching with multiple entities."""
        f1 = make_fact(subject="user", obj="admin")
        f2 = make_fact(subject="user", obj="guest")

        await index.add(f1)
        await index.add(f2)

        # Search for facts about "user" subject
        results = await index.search(
            query=None,
            entities=[("subject", "user")],
        )

        assert len(results) == 2

    @pytest.mark.asyncio
    async def test_search_match_all(self, index: EntityIndex[Fact]) -> None:
        """Test matching all entities."""
        f1 = make_fact(subject="user", obj="admin")
        f2 = make_fact(subject="user", obj="guest")

        await index.add(f1)
        await index.add(f2)

        # Search for user+admin (match all)
        results = await index.search(
            query=None,
            entities=[("subject", "user"), ("object", "admin")],
            match_all=True,
        )

        assert len(results) == 1
        assert results[0].memory_id == f1.id

    @pytest.mark.asyncio
    async def test_get_entities(self, index: EntityIndex[Fact]) -> None:
        """Test getting entities for a memory."""
        fact = make_fact(subject="user", obj="admin")
        await index.add(fact)

        entities = await index.get_entities(fact.id)

        assert ("subject", "user") in entities
        assert ("object", "admin") in entities


class TestOutcomeIndex:
    """Tests for OutcomeIndex."""

    @pytest.fixture
    def index(self) -> OutcomeIndex[Episode]:
        """Create an outcome index."""
        return OutcomeIndex[Episode]()

    @pytest.mark.asyncio
    async def test_add_item(self, index: OutcomeIndex[Episode]) -> None:
        """Test adding an item."""
        episode = make_episode(outcome=Outcome.SUCCESS)
        entry = await index.add(episode)

        assert entry.memory_id == episode.id
        assert entry.outcome == Outcome.SUCCESS
        assert await index.count() == 1

    @pytest.mark.asyncio
    async def test_search_by_outcome(self, index: OutcomeIndex[Episode]) -> None:
        """Test searching by outcome."""
        success = make_episode(outcome=Outcome.SUCCESS)
        failure = make_episode(outcome=Outcome.FAILURE)

        await index.add(success)
        await index.add(failure)

        results = await index.search(query=None, outcome=Outcome.SUCCESS)

        assert len(results) == 1
        assert results[0].memory_id == success.id

    @pytest.mark.asyncio
    async def test_search_multiple_outcomes(self, index: OutcomeIndex[Episode]) -> None:
        """Test searching with multiple outcomes."""
        success = make_episode(outcome=Outcome.SUCCESS)
        partial = make_episode(outcome=Outcome.PARTIAL)
        failure = make_episode(outcome=Outcome.FAILURE)

        await index.add(success)
        await index.add(partial)
        await index.add(failure)

        results = await index.search(
            query=None,
            outcomes=[Outcome.SUCCESS, Outcome.PARTIAL],
        )

        assert len(results) == 2

    @pytest.mark.asyncio
    async def test_get_outcome_stats(self, index: OutcomeIndex[Episode]) -> None:
        """Test getting outcome statistics."""
        await index.add(make_episode(outcome=Outcome.SUCCESS))
        await index.add(make_episode(outcome=Outcome.SUCCESS))
        await index.add(make_episode(outcome=Outcome.FAILURE))

        stats = await index.get_outcome_stats()

        assert stats[Outcome.SUCCESS] == 2
        assert stats[Outcome.FAILURE] == 1
        assert stats[Outcome.PARTIAL] == 0


class TestMemoryIndexer:
    """Tests for MemoryIndexer."""

    @pytest.fixture
    def indexer(self) -> MemoryIndexer:
        """Create a memory indexer."""
        return MemoryIndexer()

    @pytest.mark.asyncio
    async def test_index_episode(self, indexer: MemoryIndexer) -> None:
        """Test indexing an episode."""
        episode = make_episode(embedding=[1.0, 0.0, 0.0, 0.0])

        results = await indexer.index(episode)

        assert "vector" in results
        assert "temporal" in results
        assert "entity" in results
        assert "outcome" in results

    @pytest.mark.asyncio
    async def test_index_fact(self, indexer: MemoryIndexer) -> None:
        """Test indexing a fact."""
        fact = make_fact(embedding=[1.0, 0.0, 0.0, 0.0])

        results = await indexer.index(fact)

        # Facts don't have outcomes
        assert "vector" in results
        assert "temporal" in results
        assert "entity" in results
        assert "outcome" not in results

    @pytest.mark.asyncio
    async def test_remove_from_all(self, indexer: MemoryIndexer) -> None:
        """Test removing from all indices."""
        episode = make_episode(embedding=[1.0, 0.0, 0.0, 0.0])
        await indexer.index(episode)

        results = await indexer.remove(episode.id)

        assert results["vector"] is True
        assert results["temporal"] is True
        assert results["entity"] is True
        assert results["outcome"] is True

    @pytest.mark.asyncio
    async def test_clear_all(self, indexer: MemoryIndexer) -> None:
        """Test clearing all indices."""
        await indexer.index(make_episode(embedding=[1.0, 0.0, 0.0, 0.0]))
        await indexer.index(make_episode(embedding=[0.0, 1.0, 0.0, 0.0]))

        counts = await indexer.clear_all()

        assert counts["vector"] == 2
        assert counts["temporal"] == 2

    @pytest.mark.asyncio
    async def test_get_stats(self, indexer: MemoryIndexer) -> None:
        """Test getting index statistics."""
        await indexer.index(make_episode(embedding=[1.0, 0.0, 0.0, 0.0]))

        stats = await indexer.get_stats()

        assert stats["vector"] == 1
        assert stats["temporal"] == 1
        assert stats["entity"] == 1
        assert stats["outcome"] == 1


class TestGetMemoryIndexer:
    """Tests for singleton getter."""

    def test_returns_instance(self) -> None:
        """Test that getter returns instance."""
        indexer = get_memory_indexer()
        assert indexer is not None
        assert isinstance(indexer, MemoryIndexer)

    def test_returns_same_instance(self) -> None:
        """Test that getter returns same instance."""
        indexer1 = get_memory_indexer()
        indexer2 = get_memory_indexer()
        assert indexer1 is indexer2
450 backend/tests/unit/services/memory/indexing/test_retrieval.py Normal file
@@ -0,0 +1,450 @@
# tests/unit/services/memory/indexing/test_retrieval.py
"""Unit tests for memory retrieval."""

from datetime import UTC, datetime, timedelta
from uuid import uuid4

import pytest

from app.services.memory.indexing.index import MemoryIndexer
from app.services.memory.indexing.retrieval import (
    RelevanceScorer,
    RetrievalCache,
    RetrievalEngine,
    RetrievalQuery,
    ScoredResult,
    get_retrieval_engine,
)
from app.services.memory.types import Episode, MemoryType, Outcome


def _utcnow() -> datetime:
    """Get current UTC time."""
    return datetime.now(UTC)


def make_episode(
    embedding: list[float] | None = None,
    outcome: Outcome = Outcome.SUCCESS,
    occurred_at: datetime | None = None,
    task_type: str = "test_task",
) -> Episode:
    """Create a test episode."""
    return Episode(
        id=uuid4(),
        project_id=uuid4(),
        agent_instance_id=uuid4(),
        agent_type_id=uuid4(),
        session_id="test-session",
        task_type=task_type,
        task_description="Test task description",
        actions=[{"action": "test"}],
        context_summary="Test context",
        outcome=outcome,
        outcome_details="Test outcome",
        duration_seconds=10.0,
        tokens_used=100,
        lessons_learned=["lesson1"],
        importance_score=0.8,
        embedding=embedding,
        occurred_at=occurred_at or _utcnow(),
        created_at=_utcnow(),
        updated_at=_utcnow(),
    )


class TestRetrievalQuery:
    """Tests for RetrievalQuery."""

    def test_default_values(self) -> None:
        """Test default query values."""
        query = RetrievalQuery()

        assert query.query_text is None
        assert query.limit == 10
        assert query.min_relevance == 0.0
        assert query.use_vector is True
        assert query.use_temporal is True

    def test_cache_key_generation(self) -> None:
        """Test cache key generation."""
        query1 = RetrievalQuery(query_text="test", limit=10)
        query2 = RetrievalQuery(query_text="test", limit=10)
        query3 = RetrievalQuery(query_text="different", limit=10)

        # Same queries should have same key
        assert query1.to_cache_key() == query2.to_cache_key()
        # Different queries should have different keys
        assert query1.to_cache_key() != query3.to_cache_key()


class TestScoredResult:
    """Tests for ScoredResult."""

    def test_creation(self) -> None:
        """Test creating a scored result."""
        result = ScoredResult(
            memory_id=uuid4(),
            memory_type=MemoryType.EPISODIC,
            relevance_score=0.85,
            score_breakdown={"vector": 0.9, "recency": 0.8},
        )

        assert result.relevance_score == 0.85
        assert result.score_breakdown["vector"] == 0.9


class TestRelevanceScorer:
    """Tests for RelevanceScorer."""

    @pytest.fixture
    def scorer(self) -> RelevanceScorer:
        """Create a relevance scorer."""
        return RelevanceScorer()

    def test_score_with_vector(self, scorer: RelevanceScorer) -> None:
        """Test scoring with vector similarity."""
        result = scorer.score(
            memory_id=uuid4(),
            memory_type=MemoryType.EPISODIC,
            vector_similarity=0.9,
        )

        assert result.relevance_score > 0
        assert result.score_breakdown["vector"] == 0.9

    def test_score_with_recency(self, scorer: RelevanceScorer) -> None:
        """Test scoring with recency."""
        recent_result = scorer.score(
            memory_id=uuid4(),
            memory_type=MemoryType.EPISODIC,
            timestamp=_utcnow(),
        )

        old_result = scorer.score(
            memory_id=uuid4(),
            memory_type=MemoryType.EPISODIC,
            timestamp=_utcnow() - timedelta(days=7),
        )

        # Recent should have higher recency score
        assert (
            recent_result.score_breakdown["recency"]
            > old_result.score_breakdown["recency"]
        )

    def test_score_with_outcome_preference(self, scorer: RelevanceScorer) -> None:
        """Test scoring with outcome preference."""
        success_result = scorer.score(
            memory_id=uuid4(),
            memory_type=MemoryType.EPISODIC,
            outcome=Outcome.SUCCESS,
            preferred_outcomes=[Outcome.SUCCESS],
        )

        failure_result = scorer.score(
            memory_id=uuid4(),
            memory_type=MemoryType.EPISODIC,
            outcome=Outcome.FAILURE,
            preferred_outcomes=[Outcome.SUCCESS],
        )

        assert success_result.score_breakdown["outcome"] == 1.0
        assert failure_result.score_breakdown["outcome"] == 0.0

    def test_score_with_entity_match(self, scorer: RelevanceScorer) -> None:
        """Test scoring with entity matches."""
        full_match = scorer.score(
            memory_id=uuid4(),
            memory_type=MemoryType.EPISODIC,
            entity_match_count=3,
            entity_total=3,
        )

        partial_match = scorer.score(
            memory_id=uuid4(),
            memory_type=MemoryType.EPISODIC,
            entity_match_count=1,
            entity_total=3,
        )

        assert (
            full_match.score_breakdown["entity"]
            > partial_match.score_breakdown["entity"]
        )

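The scorer tests above pin down only relative behavior: newer timestamps earn a higher recency component, a preferred outcome scores 1.0 against 0.0, and fuller entity matches score higher. One way to combine such signals is a weighted sum; the sketch below is consistent with those assertions, but the weights and the 24-hour decay constant are illustrative assumptions, not the project's actual values:

# Hypothetical multi-signal relevance score consistent with the tests above.
# The weights and the 24-hour decay constant are assumptions for illustration.
import math

WEIGHTS = {"vector": 0.4, "recency": 0.3, "outcome": 0.2, "entity": 0.1}


def relevance(
    vector: float = 0.0,
    age_hours: float = 0.0,
    outcome_match: bool = False,
    entity_matches: int = 0,
    entity_total: int = 0,
) -> float:
    breakdown = {
        "vector": vector,
        "recency": math.exp(-age_hours / 24.0),  # decays toward 0 as memories age
        "outcome": 1.0 if outcome_match else 0.0,
        "entity": entity_matches / entity_total if entity_total else 0.0,
    }
    return sum(WEIGHTS[k] * v for k, v in breakdown.items())
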
class TestRetrievalCache:
    """Tests for RetrievalCache."""

    @pytest.fixture
    def cache(self) -> RetrievalCache:
        """Create a retrieval cache."""
        return RetrievalCache(max_entries=10, default_ttl_seconds=60)

    def test_put_and_get(self, cache: RetrievalCache) -> None:
        """Test putting and getting from cache."""
        results = [
            ScoredResult(
                memory_id=uuid4(),
                memory_type=MemoryType.EPISODIC,
                relevance_score=0.8,
            )
        ]

        cache.put("test_key", results)
        cached = cache.get("test_key")

        assert cached is not None
        assert len(cached) == 1

    def test_get_nonexistent(self, cache: RetrievalCache) -> None:
        """Test getting nonexistent entry."""
        result = cache.get("nonexistent")
        assert result is None

    def test_lru_eviction(self) -> None:
        """Test LRU eviction when at capacity."""
        cache = RetrievalCache(max_entries=2, default_ttl_seconds=60)

        results = [
            ScoredResult(
                memory_id=uuid4(),
                memory_type=MemoryType.EPISODIC,
                relevance_score=0.8,
            )
        ]

        cache.put("key1", results)
        cache.put("key2", results)
        cache.put("key3", results)  # Should evict key1

        assert cache.get("key1") is None
        assert cache.get("key2") is not None
        assert cache.get("key3") is not None

    def test_invalidate(self, cache: RetrievalCache) -> None:
        """Test invalidating a cache entry."""
        results = [
            ScoredResult(
                memory_id=uuid4(),
                memory_type=MemoryType.EPISODIC,
                relevance_score=0.8,
            )
        ]

        cache.put("test_key", results)
        removed = cache.invalidate("test_key")

        assert removed is True
        assert cache.get("test_key") is None

    def test_invalidate_by_memory(self, cache: RetrievalCache) -> None:
        """Test invalidating by memory ID."""
        memory_id = uuid4()
        results = [
            ScoredResult(
                memory_id=memory_id,
                memory_type=MemoryType.EPISODIC,
                relevance_score=0.8,
            )
        ]

        cache.put("key1", results)
        cache.put("key2", results)

        count = cache.invalidate_by_memory(memory_id)

        assert count == 2
        assert cache.get("key1") is None
        assert cache.get("key2") is None

    def test_clear(self, cache: RetrievalCache) -> None:
        """Test clearing the cache."""
        results = [
            ScoredResult(
                memory_id=uuid4(),
                memory_type=MemoryType.EPISODIC,
                relevance_score=0.8,
            )
        ]

        cache.put("key1", results)
        cache.put("key2", results)

        count = cache.clear()

        assert count == 2
        assert cache.get("key1") is None

    def test_get_stats(self, cache: RetrievalCache) -> None:
        """Test getting cache statistics."""
        stats = cache.get_stats()

        assert "total_entries" in stats
        assert "max_entries" in stats
        assert stats["max_entries"] == 10

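test_lru_eviction above encodes classic least-recently-used semantics: at capacity, inserting a new key evicts the key touched longest ago. A minimal standalone sketch of that policy with OrderedDict; this is not the project's implementation, which also layers TTL expiry on top (per default_ttl_seconds):

# Minimal LRU sketch matching the eviction semantics tested above.
# Illustrative only; RetrievalCache additionally expires entries by TTL.
from collections import OrderedDict


class LRUCache:
    def __init__(self, max_entries: int) -> None:
        self.max_entries = max_entries
        self._data: OrderedDict[str, object] = OrderedDict()

    def put(self, key: str, value: object) -> None:
        if key in self._data:
            self._data.move_to_end(key)  # overwriting refreshes recency
        self._data[key] = value
        if len(self._data) > self.max_entries:
            self._data.popitem(last=False)  # evict least recently used

    def get(self, key: str) -> object | None:
        if key not in self._data:
            return None
        self._data.move_to_end(key)  # reads also refresh recency
        return self._data[key]


cache = LRUCache(max_entries=2)
cache.put("key1", [])
cache.put("key2", [])
cache.put("key3", [])  # evicts "key1"
assert cache.get("key1") is None
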
class TestRetrievalEngine:
    """Tests for RetrievalEngine."""

    @pytest.fixture
    def indexer(self) -> MemoryIndexer:
        """Create a memory indexer."""
        return MemoryIndexer()

    @pytest.fixture
    def engine(self, indexer: MemoryIndexer) -> RetrievalEngine:
        """Create a retrieval engine."""
        return RetrievalEngine(indexer=indexer, enable_cache=True)

    @pytest.mark.asyncio
    async def test_retrieve_by_vector(
        self, engine: RetrievalEngine, indexer: MemoryIndexer
    ) -> None:
        """Test retrieval by vector similarity."""
        e1 = make_episode(embedding=[1.0, 0.0, 0.0, 0.0])
        e2 = make_episode(embedding=[0.9, 0.1, 0.0, 0.0])
        e3 = make_episode(embedding=[0.0, 1.0, 0.0, 0.0])

        await indexer.index(e1)
        await indexer.index(e2)
        await indexer.index(e3)

        query = RetrievalQuery(
            query_embedding=[1.0, 0.0, 0.0, 0.0],
            limit=2,
            use_temporal=False,
            use_entity=False,
            use_outcome=False,
        )

        result = await engine.retrieve(query)

        assert len(result.items) > 0
        assert result.retrieval_type == "hybrid"

    @pytest.mark.asyncio
    async def test_retrieve_recent(
        self, engine: RetrievalEngine, indexer: MemoryIndexer
    ) -> None:
        """Test retrieval of recent items."""
        now = _utcnow()
        old = make_episode(occurred_at=now - timedelta(hours=2))
        recent = make_episode(occurred_at=now - timedelta(minutes=30))

        await indexer.index(old)
        await indexer.index(recent)

        result = await engine.retrieve_recent(hours=1)

        assert len(result.items) == 1

    @pytest.mark.asyncio
    async def test_retrieve_by_entity(
        self, engine: RetrievalEngine, indexer: MemoryIndexer
    ) -> None:
        """Test retrieval by entity."""
        e1 = make_episode(task_type="deploy")
        e2 = make_episode(task_type="test")

        await indexer.index(e1)
        await indexer.index(e2)

        result = await engine.retrieve_by_entity("task_type", "deploy")

        assert len(result.items) == 1

    @pytest.mark.asyncio
    async def test_retrieve_successful(
        self, engine: RetrievalEngine, indexer: MemoryIndexer
    ) -> None:
        """Test retrieval of successful items."""
        success = make_episode(outcome=Outcome.SUCCESS)
        failure = make_episode(outcome=Outcome.FAILURE)

        await indexer.index(success)
        await indexer.index(failure)

        result = await engine.retrieve_successful()

        assert len(result.items) == 1
        # Check outcome index was used
        assert result.items[0].memory_id == success.id

    @pytest.mark.asyncio
    async def test_retrieve_with_cache(
        self, engine: RetrievalEngine, indexer: MemoryIndexer
    ) -> None:
        """Test that retrieval uses cache."""
        episode = make_episode(embedding=[1.0, 0.0, 0.0, 0.0])
        await indexer.index(episode)

        query = RetrievalQuery(
            query_embedding=[1.0, 0.0, 0.0, 0.0],
            limit=10,
        )

        # First retrieval
        result1 = await engine.retrieve(query)
        assert result1.metadata.get("cache_hit") is False

        # Second retrieval should be cached
        result2 = await engine.retrieve(query)
        assert result2.metadata.get("cache_hit") is True

    @pytest.mark.asyncio
    async def test_invalidate_cache(
        self, engine: RetrievalEngine, indexer: MemoryIndexer
    ) -> None:
        """Test cache invalidation."""
        episode = make_episode(embedding=[1.0, 0.0, 0.0, 0.0])
        await indexer.index(episode)

        query = RetrievalQuery(query_embedding=[1.0, 0.0, 0.0, 0.0])
        await engine.retrieve(query)

        count = engine.invalidate_cache()

        assert count > 0

    @pytest.mark.asyncio
    async def test_retrieve_similar(
        self, engine: RetrievalEngine, indexer: MemoryIndexer
    ) -> None:
        """Test retrieve_similar convenience method."""
        e1 = make_episode(embedding=[1.0, 0.0, 0.0, 0.0])
        e2 = make_episode(embedding=[0.0, 1.0, 0.0, 0.0])

        await indexer.index(e1)
        await indexer.index(e2)

        result = await engine.retrieve_similar(
            embedding=[1.0, 0.0, 0.0, 0.0],
            limit=1,
        )

        assert len(result.items) == 1

    def test_get_cache_stats(self, engine: RetrievalEngine) -> None:
        """Test getting cache statistics."""
        stats = engine.get_cache_stats()

        assert "total_entries" in stats


class TestGetRetrievalEngine:
    """Tests for singleton getter."""

    def test_returns_instance(self) -> None:
        """Test that getter returns instance."""
        engine = get_retrieval_engine()
        assert engine is not None
        assert isinstance(engine, RetrievalEngine)

    def test_returns_same_instance(self) -> None:
        """Test that getter returns same instance."""
        engine1 = get_retrieval_engine()
        engine2 = get_retrieval_engine()
        assert engine1 is engine2