forked from cardosofelipe/fast-next-template
feat(memory): implement memory indexing and retrieval engine (#94)
Add comprehensive indexing and retrieval system for memory search: - VectorIndex for semantic similarity search using cosine similarity - TemporalIndex for time-based queries with range and recency support - EntityIndex for entity-based lookups with multi-entity intersection - OutcomeIndex for success/failure filtering on episodes - MemoryIndexer as unified interface for all index types - RetrievalEngine with hybrid search combining all indices - RelevanceScorer for multi-signal relevance scoring - RetrievalCache for LRU caching of search results 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
@@ -1,7 +1,56 @@
|
||||
# app/services/memory/indexing/__init__.py
|
||||
"""
|
||||
Memory Indexing
|
||||
Memory Indexing & Retrieval.
|
||||
|
||||
Vector embeddings and retrieval engine for memory search.
|
||||
Provides vector embeddings and multiple index types for efficient memory search:
|
||||
- Vector index for semantic similarity
|
||||
- Temporal index for time-based queries
|
||||
- Entity index for entity lookups
|
||||
- Outcome index for success/failure filtering
|
||||
"""
|
||||
|
||||
# Will be populated in #94
|
||||
from .index import (
|
||||
EntityIndex,
|
||||
EntityIndexEntry,
|
||||
IndexEntry,
|
||||
MemoryIndex,
|
||||
MemoryIndexer,
|
||||
OutcomeIndex,
|
||||
OutcomeIndexEntry,
|
||||
TemporalIndex,
|
||||
TemporalIndexEntry,
|
||||
VectorIndex,
|
||||
VectorIndexEntry,
|
||||
get_memory_indexer,
|
||||
)
|
||||
from .retrieval import (
|
||||
CacheEntry,
|
||||
RelevanceScorer,
|
||||
RetrievalCache,
|
||||
RetrievalEngine,
|
||||
RetrievalQuery,
|
||||
ScoredResult,
|
||||
get_retrieval_engine,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"CacheEntry",
|
||||
"EntityIndex",
|
||||
"EntityIndexEntry",
|
||||
"IndexEntry",
|
||||
"MemoryIndex",
|
||||
"MemoryIndexer",
|
||||
"OutcomeIndex",
|
||||
"OutcomeIndexEntry",
|
||||
"RelevanceScorer",
|
||||
"RetrievalCache",
|
||||
"RetrievalEngine",
|
||||
"RetrievalQuery",
|
||||
"ScoredResult",
|
||||
"TemporalIndex",
|
||||
"TemporalIndexEntry",
|
||||
"VectorIndex",
|
||||
"VectorIndexEntry",
|
||||
"get_memory_indexer",
|
||||
"get_retrieval_engine",
|
||||
]
|
||||
|
||||
851
backend/app/services/memory/indexing/index.py
Normal file
851
backend/app/services/memory/indexing/index.py
Normal file
@@ -0,0 +1,851 @@
|
||||
# app/services/memory/indexing/index.py
|
||||
"""
|
||||
Memory Indexing.
|
||||
|
||||
Provides multiple indexing strategies for efficient memory retrieval:
|
||||
- Vector embeddings for semantic search
|
||||
- Temporal index for time-based queries
|
||||
- Entity index for entity-based lookups
|
||||
- Outcome index for success/failure filtering
|
||||
"""
|
||||
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from typing import Any, TypeVar
|
||||
from uuid import UUID
|
||||
|
||||
from app.services.memory.types import Episode, Fact, MemoryType, Outcome, Procedure
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
T = TypeVar("T", Episode, Fact, Procedure)
|
||||
|
||||
|
||||
def _utcnow() -> datetime:
|
||||
"""Get current UTC time as timezone-aware datetime."""
|
||||
return datetime.now(UTC)
|
||||
|
||||
|
||||
@dataclass
|
||||
class IndexEntry:
|
||||
"""A single entry in an index."""
|
||||
|
||||
memory_id: UUID
|
||||
memory_type: MemoryType
|
||||
indexed_at: datetime = field(default_factory=_utcnow)
|
||||
metadata: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
|
||||
@dataclass
|
||||
class VectorIndexEntry(IndexEntry):
|
||||
"""An entry with vector embedding."""
|
||||
|
||||
embedding: list[float] = field(default_factory=list)
|
||||
dimension: int = 0
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
"""Set dimension from embedding."""
|
||||
if self.embedding:
|
||||
self.dimension = len(self.embedding)
|
||||
|
||||
|
||||
@dataclass
|
||||
class TemporalIndexEntry(IndexEntry):
|
||||
"""An entry indexed by time."""
|
||||
|
||||
timestamp: datetime = field(default_factory=_utcnow)
|
||||
|
||||
|
||||
@dataclass
|
||||
class EntityIndexEntry(IndexEntry):
|
||||
"""An entry indexed by entity."""
|
||||
|
||||
entity_type: str = ""
|
||||
entity_value: str = ""
|
||||
|
||||
|
||||
@dataclass
|
||||
class OutcomeIndexEntry(IndexEntry):
|
||||
"""An entry indexed by outcome."""
|
||||
|
||||
outcome: Outcome = Outcome.SUCCESS
|
||||
|
||||
|
||||
class MemoryIndex[T](ABC):
|
||||
"""Abstract base class for memory indices."""
|
||||
|
||||
@abstractmethod
|
||||
async def add(self, item: T) -> IndexEntry:
|
||||
"""Add an item to the index."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def remove(self, memory_id: UUID) -> bool:
|
||||
"""Remove an item from the index."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def search(
|
||||
self,
|
||||
query: Any,
|
||||
limit: int = 10,
|
||||
**kwargs: Any,
|
||||
) -> list[IndexEntry]:
|
||||
"""Search the index."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def clear(self) -> int:
|
||||
"""Clear all entries from the index."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def count(self) -> int:
|
||||
"""Get the number of entries in the index."""
|
||||
...
|
||||
|
||||
|
||||
class VectorIndex(MemoryIndex[T]):
|
||||
"""
|
||||
Vector-based index using embeddings for semantic similarity search.
|
||||
|
||||
Uses cosine similarity for matching.
|
||||
"""
|
||||
|
||||
def __init__(self, dimension: int = 1536) -> None:
|
||||
"""
|
||||
Initialize the vector index.
|
||||
|
||||
Args:
|
||||
dimension: Embedding dimension (default 1536 for OpenAI)
|
||||
"""
|
||||
self._dimension = dimension
|
||||
self._entries: dict[UUID, VectorIndexEntry] = {}
|
||||
logger.info(f"Initialized VectorIndex with dimension={dimension}")
|
||||
|
||||
async def add(self, item: T) -> VectorIndexEntry:
|
||||
"""
|
||||
Add an item to the vector index.
|
||||
|
||||
Args:
|
||||
item: Memory item with embedding
|
||||
|
||||
Returns:
|
||||
The created index entry
|
||||
"""
|
||||
embedding = getattr(item, "embedding", None) or []
|
||||
|
||||
entry = VectorIndexEntry(
|
||||
memory_id=item.id,
|
||||
memory_type=self._get_memory_type(item),
|
||||
embedding=embedding,
|
||||
dimension=len(embedding),
|
||||
)
|
||||
|
||||
self._entries[item.id] = entry
|
||||
logger.debug(f"Added {item.id} to vector index")
|
||||
return entry
|
||||
|
||||
async def remove(self, memory_id: UUID) -> bool:
|
||||
"""Remove an item from the vector index."""
|
||||
if memory_id in self._entries:
|
||||
del self._entries[memory_id]
|
||||
logger.debug(f"Removed {memory_id} from vector index")
|
||||
return True
|
||||
return False
|
||||
|
||||
async def search( # type: ignore[override]
|
||||
self,
|
||||
query: Any,
|
||||
limit: int = 10,
|
||||
min_similarity: float = 0.0,
|
||||
**kwargs: Any,
|
||||
) -> list[VectorIndexEntry]:
|
||||
"""
|
||||
Search for similar items using vector similarity.
|
||||
|
||||
Args:
|
||||
query: Query embedding vector
|
||||
limit: Maximum results to return
|
||||
min_similarity: Minimum similarity threshold (0-1)
|
||||
**kwargs: Additional filter parameters
|
||||
|
||||
Returns:
|
||||
List of matching entries sorted by similarity
|
||||
"""
|
||||
if not isinstance(query, list) or not query:
|
||||
return []
|
||||
|
||||
results: list[tuple[float, VectorIndexEntry]] = []
|
||||
|
||||
for entry in self._entries.values():
|
||||
if not entry.embedding:
|
||||
continue
|
||||
|
||||
similarity = self._cosine_similarity(query, entry.embedding)
|
||||
if similarity >= min_similarity:
|
||||
results.append((similarity, entry))
|
||||
|
||||
# Sort by similarity descending
|
||||
results.sort(key=lambda x: x[0], reverse=True)
|
||||
|
||||
# Apply memory type filter if provided
|
||||
memory_type = kwargs.get("memory_type")
|
||||
if memory_type:
|
||||
results = [(s, e) for s, e in results if e.memory_type == memory_type]
|
||||
|
||||
# Store similarity in metadata for the returned entries
|
||||
output = []
|
||||
for similarity, entry in results[:limit]:
|
||||
entry.metadata["similarity"] = similarity
|
||||
output.append(entry)
|
||||
|
||||
logger.debug(f"Vector search returned {len(output)} results")
|
||||
return output
|
||||
|
||||
async def clear(self) -> int:
|
||||
"""Clear all entries from the index."""
|
||||
count = len(self._entries)
|
||||
self._entries.clear()
|
||||
logger.info(f"Cleared {count} entries from vector index")
|
||||
return count
|
||||
|
||||
async def count(self) -> int:
|
||||
"""Get the number of entries in the index."""
|
||||
return len(self._entries)
|
||||
|
||||
def _cosine_similarity(self, a: list[float], b: list[float]) -> float:
|
||||
"""Calculate cosine similarity between two vectors."""
|
||||
if len(a) != len(b) or len(a) == 0:
|
||||
return 0.0
|
||||
|
||||
dot_product = sum(x * y for x, y in zip(a, b, strict=True))
|
||||
norm_a = sum(x * x for x in a) ** 0.5
|
||||
norm_b = sum(x * x for x in b) ** 0.5
|
||||
|
||||
if norm_a == 0 or norm_b == 0:
|
||||
return 0.0
|
||||
|
||||
return dot_product / (norm_a * norm_b)
|
||||
|
||||
def _get_memory_type(self, item: T) -> MemoryType:
|
||||
"""Get the memory type for an item."""
|
||||
if isinstance(item, Episode):
|
||||
return MemoryType.EPISODIC
|
||||
elif isinstance(item, Fact):
|
||||
return MemoryType.SEMANTIC
|
||||
elif isinstance(item, Procedure):
|
||||
return MemoryType.PROCEDURAL
|
||||
return MemoryType.WORKING
|
||||
|
||||
|
||||
class TemporalIndex(MemoryIndex[T]):
|
||||
"""
|
||||
Time-based index for efficient temporal queries.
|
||||
|
||||
Supports:
|
||||
- Range queries (between timestamps)
|
||||
- Recent items (within last N seconds/hours/days)
|
||||
- Oldest/newest sorting
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""Initialize the temporal index."""
|
||||
self._entries: dict[UUID, TemporalIndexEntry] = {}
|
||||
# Sorted list for efficient range queries
|
||||
self._sorted_entries: list[tuple[datetime, UUID]] = []
|
||||
logger.info("Initialized TemporalIndex")
|
||||
|
||||
async def add(self, item: T) -> TemporalIndexEntry:
|
||||
"""
|
||||
Add an item to the temporal index.
|
||||
|
||||
Args:
|
||||
item: Memory item with timestamp
|
||||
|
||||
Returns:
|
||||
The created index entry
|
||||
"""
|
||||
# Get timestamp from various possible fields
|
||||
timestamp = self._get_timestamp(item)
|
||||
|
||||
entry = TemporalIndexEntry(
|
||||
memory_id=item.id,
|
||||
memory_type=self._get_memory_type(item),
|
||||
timestamp=timestamp,
|
||||
)
|
||||
|
||||
self._entries[item.id] = entry
|
||||
self._insert_sorted(timestamp, item.id)
|
||||
|
||||
logger.debug(f"Added {item.id} to temporal index at {timestamp}")
|
||||
return entry
|
||||
|
||||
async def remove(self, memory_id: UUID) -> bool:
|
||||
"""Remove an item from the temporal index."""
|
||||
if memory_id not in self._entries:
|
||||
return False
|
||||
|
||||
self._entries.pop(memory_id)
|
||||
self._sorted_entries = [
|
||||
(ts, mid) for ts, mid in self._sorted_entries if mid != memory_id
|
||||
]
|
||||
|
||||
logger.debug(f"Removed {memory_id} from temporal index")
|
||||
return True
|
||||
|
||||
async def search( # type: ignore[override]
|
||||
self,
|
||||
query: Any,
|
||||
limit: int = 10,
|
||||
start_time: datetime | None = None,
|
||||
end_time: datetime | None = None,
|
||||
recent_seconds: float | None = None,
|
||||
order: str = "desc",
|
||||
**kwargs: Any,
|
||||
) -> list[TemporalIndexEntry]:
|
||||
"""
|
||||
Search for items by time.
|
||||
|
||||
Args:
|
||||
query: Ignored for temporal search
|
||||
limit: Maximum results to return
|
||||
start_time: Start of time range
|
||||
end_time: End of time range
|
||||
recent_seconds: Get items from last N seconds
|
||||
order: Sort order ("asc" or "desc")
|
||||
**kwargs: Additional filter parameters
|
||||
|
||||
Returns:
|
||||
List of matching entries sorted by time
|
||||
"""
|
||||
if recent_seconds is not None:
|
||||
start_time = _utcnow() - timedelta(seconds=recent_seconds)
|
||||
end_time = _utcnow()
|
||||
|
||||
# Filter by time range
|
||||
results: list[TemporalIndexEntry] = []
|
||||
for entry in self._entries.values():
|
||||
if start_time and entry.timestamp < start_time:
|
||||
continue
|
||||
if end_time and entry.timestamp > end_time:
|
||||
continue
|
||||
results.append(entry)
|
||||
|
||||
# Apply memory type filter if provided
|
||||
memory_type = kwargs.get("memory_type")
|
||||
if memory_type:
|
||||
results = [e for e in results if e.memory_type == memory_type]
|
||||
|
||||
# Sort by timestamp
|
||||
results.sort(key=lambda e: e.timestamp, reverse=(order == "desc"))
|
||||
|
||||
logger.debug(f"Temporal search returned {min(len(results), limit)} results")
|
||||
return results[:limit]
|
||||
|
||||
async def clear(self) -> int:
|
||||
"""Clear all entries from the index."""
|
||||
count = len(self._entries)
|
||||
self._entries.clear()
|
||||
self._sorted_entries.clear()
|
||||
logger.info(f"Cleared {count} entries from temporal index")
|
||||
return count
|
||||
|
||||
async def count(self) -> int:
|
||||
"""Get the number of entries in the index."""
|
||||
return len(self._entries)
|
||||
|
||||
def _insert_sorted(self, timestamp: datetime, memory_id: UUID) -> None:
|
||||
"""Insert entry maintaining sorted order."""
|
||||
# Binary search insert for efficiency
|
||||
low, high = 0, len(self._sorted_entries)
|
||||
while low < high:
|
||||
mid = (low + high) // 2
|
||||
if self._sorted_entries[mid][0] < timestamp:
|
||||
low = mid + 1
|
||||
else:
|
||||
high = mid
|
||||
self._sorted_entries.insert(low, (timestamp, memory_id))
|
||||
|
||||
def _get_timestamp(self, item: T) -> datetime:
|
||||
"""Get the relevant timestamp for an item."""
|
||||
if hasattr(item, "occurred_at"):
|
||||
return item.occurred_at
|
||||
if hasattr(item, "first_learned"):
|
||||
return item.first_learned
|
||||
if hasattr(item, "last_used") and item.last_used:
|
||||
return item.last_used
|
||||
if hasattr(item, "created_at"):
|
||||
return item.created_at
|
||||
return _utcnow()
|
||||
|
||||
def _get_memory_type(self, item: T) -> MemoryType:
|
||||
"""Get the memory type for an item."""
|
||||
if isinstance(item, Episode):
|
||||
return MemoryType.EPISODIC
|
||||
elif isinstance(item, Fact):
|
||||
return MemoryType.SEMANTIC
|
||||
elif isinstance(item, Procedure):
|
||||
return MemoryType.PROCEDURAL
|
||||
return MemoryType.WORKING
|
||||
|
||||
|
||||
class EntityIndex(MemoryIndex[T]):
|
||||
"""
|
||||
Entity-based index for lookups by entities mentioned in memories.
|
||||
|
||||
Supports:
|
||||
- Single entity lookup
|
||||
- Multi-entity intersection
|
||||
- Entity type filtering
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""Initialize the entity index."""
|
||||
# Main storage
|
||||
self._entries: dict[UUID, EntityIndexEntry] = {}
|
||||
# Inverted index: entity -> set of memory IDs
|
||||
self._entity_to_memories: dict[str, set[UUID]] = {}
|
||||
# Memory to entities mapping
|
||||
self._memory_to_entities: dict[UUID, set[str]] = {}
|
||||
logger.info("Initialized EntityIndex")
|
||||
|
||||
async def add(self, item: T) -> EntityIndexEntry:
|
||||
"""
|
||||
Add an item to the entity index.
|
||||
|
||||
Args:
|
||||
item: Memory item with entity information
|
||||
|
||||
Returns:
|
||||
The created index entry
|
||||
"""
|
||||
entities = self._extract_entities(item)
|
||||
|
||||
# Create entry for the primary entity (or first one)
|
||||
primary_entity = entities[0] if entities else ("unknown", "unknown")
|
||||
|
||||
entry = EntityIndexEntry(
|
||||
memory_id=item.id,
|
||||
memory_type=self._get_memory_type(item),
|
||||
entity_type=primary_entity[0],
|
||||
entity_value=primary_entity[1],
|
||||
)
|
||||
|
||||
self._entries[item.id] = entry
|
||||
|
||||
# Update inverted indices
|
||||
entity_keys = {f"{etype}:{evalue}" for etype, evalue in entities}
|
||||
self._memory_to_entities[item.id] = entity_keys
|
||||
|
||||
for entity_key in entity_keys:
|
||||
if entity_key not in self._entity_to_memories:
|
||||
self._entity_to_memories[entity_key] = set()
|
||||
self._entity_to_memories[entity_key].add(item.id)
|
||||
|
||||
logger.debug(f"Added {item.id} to entity index with {len(entities)} entities")
|
||||
return entry
|
||||
|
||||
async def remove(self, memory_id: UUID) -> bool:
|
||||
"""Remove an item from the entity index."""
|
||||
if memory_id not in self._entries:
|
||||
return False
|
||||
|
||||
# Remove from inverted index
|
||||
if memory_id in self._memory_to_entities:
|
||||
for entity_key in self._memory_to_entities[memory_id]:
|
||||
if entity_key in self._entity_to_memories:
|
||||
self._entity_to_memories[entity_key].discard(memory_id)
|
||||
if not self._entity_to_memories[entity_key]:
|
||||
del self._entity_to_memories[entity_key]
|
||||
del self._memory_to_entities[memory_id]
|
||||
|
||||
del self._entries[memory_id]
|
||||
logger.debug(f"Removed {memory_id} from entity index")
|
||||
return True
|
||||
|
||||
async def search( # type: ignore[override]
|
||||
self,
|
||||
query: Any,
|
||||
limit: int = 10,
|
||||
entity_type: str | None = None,
|
||||
entity_value: str | None = None,
|
||||
entities: list[tuple[str, str]] | None = None,
|
||||
match_all: bool = False,
|
||||
**kwargs: Any,
|
||||
) -> list[EntityIndexEntry]:
|
||||
"""
|
||||
Search for items by entity.
|
||||
|
||||
Args:
|
||||
query: Entity value to search (if entity_type not specified)
|
||||
limit: Maximum results to return
|
||||
entity_type: Type of entity to filter
|
||||
entity_value: Specific entity value
|
||||
entities: List of (type, value) tuples to match
|
||||
match_all: If True, require all entities to match
|
||||
**kwargs: Additional filter parameters
|
||||
|
||||
Returns:
|
||||
List of matching entries
|
||||
"""
|
||||
matching_ids: set[UUID] | None = None
|
||||
|
||||
# Handle single entity query
|
||||
if entity_type and entity_value:
|
||||
entities = [(entity_type, entity_value)]
|
||||
elif entity_value is None and isinstance(query, str):
|
||||
# Search across all entity types
|
||||
entity_value = query
|
||||
|
||||
if entities:
|
||||
for etype, evalue in entities:
|
||||
entity_key = f"{etype}:{evalue}"
|
||||
if entity_key in self._entity_to_memories:
|
||||
ids = self._entity_to_memories[entity_key]
|
||||
if matching_ids is None:
|
||||
matching_ids = ids.copy()
|
||||
elif match_all:
|
||||
matching_ids &= ids
|
||||
else:
|
||||
matching_ids |= ids
|
||||
elif match_all:
|
||||
# Required entity not found
|
||||
matching_ids = set()
|
||||
break
|
||||
elif entity_value:
|
||||
# Search for value across all types
|
||||
matching_ids = set()
|
||||
for entity_key, ids in self._entity_to_memories.items():
|
||||
if entity_value.lower() in entity_key.lower():
|
||||
matching_ids |= ids
|
||||
|
||||
if matching_ids is None:
|
||||
matching_ids = set(self._entries.keys())
|
||||
|
||||
# Apply memory type filter if provided
|
||||
memory_type = kwargs.get("memory_type")
|
||||
results = []
|
||||
for mid in matching_ids:
|
||||
if mid in self._entries:
|
||||
entry = self._entries[mid]
|
||||
if memory_type and entry.memory_type != memory_type:
|
||||
continue
|
||||
results.append(entry)
|
||||
|
||||
logger.debug(f"Entity search returned {min(len(results), limit)} results")
|
||||
return results[:limit]
|
||||
|
||||
async def clear(self) -> int:
|
||||
"""Clear all entries from the index."""
|
||||
count = len(self._entries)
|
||||
self._entries.clear()
|
||||
self._entity_to_memories.clear()
|
||||
self._memory_to_entities.clear()
|
||||
logger.info(f"Cleared {count} entries from entity index")
|
||||
return count
|
||||
|
||||
async def count(self) -> int:
|
||||
"""Get the number of entries in the index."""
|
||||
return len(self._entries)
|
||||
|
||||
async def get_entities(self, memory_id: UUID) -> list[tuple[str, str]]:
|
||||
"""Get all entities for a memory item."""
|
||||
if memory_id not in self._memory_to_entities:
|
||||
return []
|
||||
|
||||
entities = []
|
||||
for entity_key in self._memory_to_entities[memory_id]:
|
||||
if ":" in entity_key:
|
||||
etype, evalue = entity_key.split(":", 1)
|
||||
entities.append((etype, evalue))
|
||||
return entities
|
||||
|
||||
def _extract_entities(self, item: T) -> list[tuple[str, str]]:
|
||||
"""Extract entities from a memory item."""
|
||||
entities: list[tuple[str, str]] = []
|
||||
|
||||
if isinstance(item, Episode):
|
||||
# Extract from task type and context
|
||||
entities.append(("task_type", item.task_type))
|
||||
if item.project_id:
|
||||
entities.append(("project", str(item.project_id)))
|
||||
if item.agent_instance_id:
|
||||
entities.append(("agent_instance", str(item.agent_instance_id)))
|
||||
if item.agent_type_id:
|
||||
entities.append(("agent_type", str(item.agent_type_id)))
|
||||
|
||||
elif isinstance(item, Fact):
|
||||
# Subject and object are entities
|
||||
entities.append(("subject", item.subject))
|
||||
entities.append(("object", item.object))
|
||||
if item.project_id:
|
||||
entities.append(("project", str(item.project_id)))
|
||||
|
||||
elif isinstance(item, Procedure):
|
||||
entities.append(("procedure", item.name))
|
||||
if item.project_id:
|
||||
entities.append(("project", str(item.project_id)))
|
||||
if item.agent_type_id:
|
||||
entities.append(("agent_type", str(item.agent_type_id)))
|
||||
|
||||
return entities
|
||||
|
||||
def _get_memory_type(self, item: T) -> MemoryType:
|
||||
"""Get the memory type for an item."""
|
||||
if isinstance(item, Episode):
|
||||
return MemoryType.EPISODIC
|
||||
elif isinstance(item, Fact):
|
||||
return MemoryType.SEMANTIC
|
||||
elif isinstance(item, Procedure):
|
||||
return MemoryType.PROCEDURAL
|
||||
return MemoryType.WORKING
|
||||
|
||||
|
||||
class OutcomeIndex(MemoryIndex[T]):
|
||||
"""
|
||||
Outcome-based index for filtering by success/failure.
|
||||
|
||||
Primarily used for episodes and procedures.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""Initialize the outcome index."""
|
||||
self._entries: dict[UUID, OutcomeIndexEntry] = {}
|
||||
# Inverted index by outcome
|
||||
self._outcome_to_memories: dict[Outcome, set[UUID]] = {
|
||||
Outcome.SUCCESS: set(),
|
||||
Outcome.FAILURE: set(),
|
||||
Outcome.PARTIAL: set(),
|
||||
}
|
||||
logger.info("Initialized OutcomeIndex")
|
||||
|
||||
async def add(self, item: T) -> OutcomeIndexEntry:
|
||||
"""
|
||||
Add an item to the outcome index.
|
||||
|
||||
Args:
|
||||
item: Memory item with outcome information
|
||||
|
||||
Returns:
|
||||
The created index entry
|
||||
"""
|
||||
outcome = self._get_outcome(item)
|
||||
|
||||
entry = OutcomeIndexEntry(
|
||||
memory_id=item.id,
|
||||
memory_type=self._get_memory_type(item),
|
||||
outcome=outcome,
|
||||
)
|
||||
|
||||
self._entries[item.id] = entry
|
||||
self._outcome_to_memories[outcome].add(item.id)
|
||||
|
||||
logger.debug(f"Added {item.id} to outcome index with {outcome.value}")
|
||||
return entry
|
||||
|
||||
async def remove(self, memory_id: UUID) -> bool:
|
||||
"""Remove an item from the outcome index."""
|
||||
if memory_id not in self._entries:
|
||||
return False
|
||||
|
||||
entry = self._entries.pop(memory_id)
|
||||
self._outcome_to_memories[entry.outcome].discard(memory_id)
|
||||
|
||||
logger.debug(f"Removed {memory_id} from outcome index")
|
||||
return True
|
||||
|
||||
async def search( # type: ignore[override]
|
||||
self,
|
||||
query: Any,
|
||||
limit: int = 10,
|
||||
outcome: Outcome | None = None,
|
||||
outcomes: list[Outcome] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> list[OutcomeIndexEntry]:
|
||||
"""
|
||||
Search for items by outcome.
|
||||
|
||||
Args:
|
||||
query: Ignored for outcome search
|
||||
limit: Maximum results to return
|
||||
outcome: Single outcome to filter
|
||||
outcomes: Multiple outcomes to filter (OR)
|
||||
**kwargs: Additional filter parameters
|
||||
|
||||
Returns:
|
||||
List of matching entries
|
||||
"""
|
||||
if outcome:
|
||||
outcomes = [outcome]
|
||||
|
||||
if outcomes:
|
||||
matching_ids: set[UUID] = set()
|
||||
for o in outcomes:
|
||||
matching_ids |= self._outcome_to_memories.get(o, set())
|
||||
else:
|
||||
matching_ids = set(self._entries.keys())
|
||||
|
||||
# Apply memory type filter if provided
|
||||
memory_type = kwargs.get("memory_type")
|
||||
results = []
|
||||
for mid in matching_ids:
|
||||
if mid in self._entries:
|
||||
entry = self._entries[mid]
|
||||
if memory_type and entry.memory_type != memory_type:
|
||||
continue
|
||||
results.append(entry)
|
||||
|
||||
logger.debug(f"Outcome search returned {min(len(results), limit)} results")
|
||||
return results[:limit]
|
||||
|
||||
async def clear(self) -> int:
|
||||
"""Clear all entries from the index."""
|
||||
count = len(self._entries)
|
||||
self._entries.clear()
|
||||
for outcome in self._outcome_to_memories:
|
||||
self._outcome_to_memories[outcome].clear()
|
||||
logger.info(f"Cleared {count} entries from outcome index")
|
||||
return count
|
||||
|
||||
async def count(self) -> int:
|
||||
"""Get the number of entries in the index."""
|
||||
return len(self._entries)
|
||||
|
||||
async def get_outcome_stats(self) -> dict[Outcome, int]:
|
||||
"""Get statistics on outcomes."""
|
||||
return {outcome: len(ids) for outcome, ids in self._outcome_to_memories.items()}
|
||||
|
||||
def _get_outcome(self, item: T) -> Outcome:
|
||||
"""Get the outcome for an item."""
|
||||
if isinstance(item, Episode):
|
||||
return item.outcome
|
||||
elif isinstance(item, Procedure):
|
||||
# Derive from success rate
|
||||
if item.success_rate >= 0.8:
|
||||
return Outcome.SUCCESS
|
||||
elif item.success_rate <= 0.2:
|
||||
return Outcome.FAILURE
|
||||
return Outcome.PARTIAL
|
||||
return Outcome.SUCCESS
|
||||
|
||||
def _get_memory_type(self, item: T) -> MemoryType:
|
||||
"""Get the memory type for an item."""
|
||||
if isinstance(item, Episode):
|
||||
return MemoryType.EPISODIC
|
||||
elif isinstance(item, Fact):
|
||||
return MemoryType.SEMANTIC
|
||||
elif isinstance(item, Procedure):
|
||||
return MemoryType.PROCEDURAL
|
||||
return MemoryType.WORKING
|
||||
|
||||
|
||||
@dataclass
|
||||
class MemoryIndexer:
|
||||
"""
|
||||
Unified indexer that manages all index types.
|
||||
|
||||
Provides a single interface for indexing and searching across
|
||||
multiple index types.
|
||||
"""
|
||||
|
||||
vector_index: VectorIndex[Any] = field(default_factory=VectorIndex)
|
||||
temporal_index: TemporalIndex[Any] = field(default_factory=TemporalIndex)
|
||||
entity_index: EntityIndex[Any] = field(default_factory=EntityIndex)
|
||||
outcome_index: OutcomeIndex[Any] = field(default_factory=OutcomeIndex)
|
||||
|
||||
async def index(self, item: Episode | Fact | Procedure) -> dict[str, IndexEntry]:
|
||||
"""
|
||||
Index an item across all applicable indices.
|
||||
|
||||
Args:
|
||||
item: Memory item to index
|
||||
|
||||
Returns:
|
||||
Dictionary of index type to entry
|
||||
"""
|
||||
results: dict[str, IndexEntry] = {}
|
||||
|
||||
# Vector index (if embedding present)
|
||||
if getattr(item, "embedding", None):
|
||||
results["vector"] = await self.vector_index.add(item)
|
||||
|
||||
# Temporal index
|
||||
results["temporal"] = await self.temporal_index.add(item)
|
||||
|
||||
# Entity index
|
||||
results["entity"] = await self.entity_index.add(item)
|
||||
|
||||
# Outcome index (for episodes and procedures)
|
||||
if isinstance(item, (Episode, Procedure)):
|
||||
results["outcome"] = await self.outcome_index.add(item)
|
||||
|
||||
logger.info(
|
||||
f"Indexed {item.id} across {len(results)} indices: {list(results.keys())}"
|
||||
)
|
||||
return results
|
||||
|
||||
async def remove(self, memory_id: UUID) -> dict[str, bool]:
|
||||
"""
|
||||
Remove an item from all indices.
|
||||
|
||||
Args:
|
||||
memory_id: ID of the memory to remove
|
||||
|
||||
Returns:
|
||||
Dictionary of index type to removal success
|
||||
"""
|
||||
results = {
|
||||
"vector": await self.vector_index.remove(memory_id),
|
||||
"temporal": await self.temporal_index.remove(memory_id),
|
||||
"entity": await self.entity_index.remove(memory_id),
|
||||
"outcome": await self.outcome_index.remove(memory_id),
|
||||
}
|
||||
|
||||
removed_from = [k for k, v in results.items() if v]
|
||||
if removed_from:
|
||||
logger.info(f"Removed {memory_id} from indices: {removed_from}")
|
||||
|
||||
return results
|
||||
|
||||
async def clear_all(self) -> dict[str, int]:
|
||||
"""
|
||||
Clear all indices.
|
||||
|
||||
Returns:
|
||||
Dictionary of index type to count cleared
|
||||
"""
|
||||
return {
|
||||
"vector": await self.vector_index.clear(),
|
||||
"temporal": await self.temporal_index.clear(),
|
||||
"entity": await self.entity_index.clear(),
|
||||
"outcome": await self.outcome_index.clear(),
|
||||
}
|
||||
|
||||
async def get_stats(self) -> dict[str, int]:
|
||||
"""
|
||||
Get statistics for all indices.
|
||||
|
||||
Returns:
|
||||
Dictionary of index type to entry count
|
||||
"""
|
||||
return {
|
||||
"vector": await self.vector_index.count(),
|
||||
"temporal": await self.temporal_index.count(),
|
||||
"entity": await self.entity_index.count(),
|
||||
"outcome": await self.outcome_index.count(),
|
||||
}
|
||||
|
||||
|
||||
# Singleton indexer instance
|
||||
_indexer: MemoryIndexer | None = None
|
||||
|
||||
|
||||
def get_memory_indexer() -> MemoryIndexer:
|
||||
"""Get the singleton memory indexer instance."""
|
||||
global _indexer
|
||||
if _indexer is None:
|
||||
_indexer = MemoryIndexer()
|
||||
return _indexer
|
||||
750
backend/app/services/memory/indexing/retrieval.py
Normal file
750
backend/app/services/memory/indexing/retrieval.py
Normal file
@@ -0,0 +1,750 @@
|
||||
# app/services/memory/indexing/retrieval.py
|
||||
"""
|
||||
Memory Retrieval Engine.
|
||||
|
||||
Provides hybrid retrieval capabilities combining:
|
||||
- Vector similarity search
|
||||
- Temporal filtering
|
||||
- Entity filtering
|
||||
- Outcome filtering
|
||||
- Relevance scoring
|
||||
- Result caching
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
import logging
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import UTC, datetime
|
||||
from typing import Any, TypeVar
|
||||
from uuid import UUID
|
||||
|
||||
from app.services.memory.types import (
|
||||
Episode,
|
||||
Fact,
|
||||
MemoryType,
|
||||
Outcome,
|
||||
Procedure,
|
||||
RetrievalResult,
|
||||
)
|
||||
|
||||
from .index import (
|
||||
MemoryIndexer,
|
||||
get_memory_indexer,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
T = TypeVar("T", Episode, Fact, Procedure)
|
||||
|
||||
|
||||
def _utcnow() -> datetime:
|
||||
"""Get current UTC time as timezone-aware datetime."""
|
||||
return datetime.now(UTC)
|
||||
|
||||
|
||||
@dataclass
|
||||
class RetrievalQuery:
|
||||
"""Query parameters for memory retrieval."""
|
||||
|
||||
# Text/semantic query
|
||||
query_text: str | None = None
|
||||
query_embedding: list[float] | None = None
|
||||
|
||||
# Temporal filters
|
||||
start_time: datetime | None = None
|
||||
end_time: datetime | None = None
|
||||
recent_seconds: float | None = None
|
||||
|
||||
# Entity filters
|
||||
entities: list[tuple[str, str]] | None = None
|
||||
entity_match_all: bool = False
|
||||
|
||||
# Outcome filters
|
||||
outcomes: list[Outcome] | None = None
|
||||
|
||||
# Memory type filter
|
||||
memory_types: list[MemoryType] | None = None
|
||||
|
||||
# Result options
|
||||
limit: int = 10
|
||||
min_relevance: float = 0.0
|
||||
|
||||
# Retrieval mode
|
||||
use_vector: bool = True
|
||||
use_temporal: bool = True
|
||||
use_entity: bool = True
|
||||
use_outcome: bool = True
|
||||
|
||||
def to_cache_key(self) -> str:
|
||||
"""Generate a cache key for this query."""
|
||||
key_parts = [
|
||||
self.query_text or "",
|
||||
str(self.start_time),
|
||||
str(self.end_time),
|
||||
str(self.recent_seconds),
|
||||
str(self.entities),
|
||||
str(self.outcomes),
|
||||
str(self.memory_types),
|
||||
str(self.limit),
|
||||
str(self.min_relevance),
|
||||
]
|
||||
key_string = "|".join(key_parts)
|
||||
return hashlib.sha256(key_string.encode()).hexdigest()[:32]
|
||||
|
||||
|
||||
@dataclass
|
||||
class ScoredResult:
|
||||
"""A retrieval result with relevance score."""
|
||||
|
||||
memory_id: UUID
|
||||
memory_type: MemoryType
|
||||
relevance_score: float
|
||||
score_breakdown: dict[str, float] = field(default_factory=dict)
|
||||
metadata: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
|
||||
@dataclass
|
||||
class CacheEntry:
|
||||
"""A cached retrieval result."""
|
||||
|
||||
results: list[ScoredResult]
|
||||
created_at: datetime
|
||||
ttl_seconds: float
|
||||
query_key: str
|
||||
|
||||
def is_expired(self) -> bool:
|
||||
"""Check if this cache entry has expired."""
|
||||
age = (_utcnow() - self.created_at).total_seconds()
|
||||
return age > self.ttl_seconds
|
||||
|
||||
|
||||
class RelevanceScorer:
|
||||
"""
|
||||
Calculates relevance scores for retrieved memories.
|
||||
|
||||
Combines multiple signals:
|
||||
- Vector similarity (if available)
|
||||
- Temporal recency
|
||||
- Entity match count
|
||||
- Outcome preference
|
||||
- Importance/confidence
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vector_weight: float = 0.4,
|
||||
recency_weight: float = 0.2,
|
||||
entity_weight: float = 0.2,
|
||||
outcome_weight: float = 0.1,
|
||||
importance_weight: float = 0.1,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize the relevance scorer.
|
||||
|
||||
Args:
|
||||
vector_weight: Weight for vector similarity (0-1)
|
||||
recency_weight: Weight for temporal recency (0-1)
|
||||
entity_weight: Weight for entity matches (0-1)
|
||||
outcome_weight: Weight for outcome preference (0-1)
|
||||
importance_weight: Weight for importance score (0-1)
|
||||
"""
|
||||
total = (
|
||||
vector_weight
|
||||
+ recency_weight
|
||||
+ entity_weight
|
||||
+ outcome_weight
|
||||
+ importance_weight
|
||||
)
|
||||
# Normalize weights
|
||||
self.vector_weight = vector_weight / total
|
||||
self.recency_weight = recency_weight / total
|
||||
self.entity_weight = entity_weight / total
|
||||
self.outcome_weight = outcome_weight / total
|
||||
self.importance_weight = importance_weight / total
|
||||
|
||||
def score(
|
||||
self,
|
||||
memory_id: UUID,
|
||||
memory_type: MemoryType,
|
||||
vector_similarity: float | None = None,
|
||||
timestamp: datetime | None = None,
|
||||
entity_match_count: int = 0,
|
||||
entity_total: int = 1,
|
||||
outcome: Outcome | None = None,
|
||||
importance: float = 0.5,
|
||||
preferred_outcomes: list[Outcome] | None = None,
|
||||
) -> ScoredResult:
|
||||
"""
|
||||
Calculate a relevance score for a memory.
|
||||
|
||||
Args:
|
||||
memory_id: ID of the memory
|
||||
memory_type: Type of memory
|
||||
vector_similarity: Similarity score from vector search (0-1)
|
||||
timestamp: Timestamp of the memory
|
||||
entity_match_count: Number of matching entities
|
||||
entity_total: Total entities in query
|
||||
outcome: Outcome of the memory
|
||||
importance: Importance score of the memory (0-1)
|
||||
preferred_outcomes: Outcomes to prefer
|
||||
|
||||
Returns:
|
||||
Scored result with breakdown
|
||||
"""
|
||||
breakdown: dict[str, float] = {}
|
||||
|
||||
# Vector similarity score
|
||||
if vector_similarity is not None:
|
||||
breakdown["vector"] = vector_similarity
|
||||
else:
|
||||
breakdown["vector"] = 0.5 # Neutral if no vector
|
||||
|
||||
# Recency score (exponential decay)
|
||||
if timestamp:
|
||||
age_hours = (_utcnow() - timestamp).total_seconds() / 3600
|
||||
# Decay with half-life of 24 hours
|
||||
breakdown["recency"] = 2 ** (-age_hours / 24)
|
||||
else:
|
||||
breakdown["recency"] = 0.5
|
||||
|
||||
# Entity match score
|
||||
if entity_total > 0:
|
||||
breakdown["entity"] = entity_match_count / entity_total
|
||||
else:
|
||||
breakdown["entity"] = 1.0 # No entity filter = full score
|
||||
|
||||
# Outcome score
|
||||
if preferred_outcomes and outcome:
|
||||
breakdown["outcome"] = 1.0 if outcome in preferred_outcomes else 0.0
|
||||
else:
|
||||
breakdown["outcome"] = 0.5 # Neutral if no preference
|
||||
|
||||
# Importance score
|
||||
breakdown["importance"] = importance
|
||||
|
||||
# Calculate weighted sum
|
||||
total_score = (
|
||||
breakdown["vector"] * self.vector_weight
|
||||
+ breakdown["recency"] * self.recency_weight
|
||||
+ breakdown["entity"] * self.entity_weight
|
||||
+ breakdown["outcome"] * self.outcome_weight
|
||||
+ breakdown["importance"] * self.importance_weight
|
||||
)
|
||||
|
||||
return ScoredResult(
|
||||
memory_id=memory_id,
|
||||
memory_type=memory_type,
|
||||
relevance_score=total_score,
|
||||
score_breakdown=breakdown,
|
||||
)
|
||||
|
||||
|
||||
class RetrievalCache:
|
||||
"""
|
||||
In-memory cache for retrieval results.
|
||||
|
||||
Supports TTL-based expiration and LRU eviction.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
max_entries: int = 1000,
|
||||
default_ttl_seconds: float = 300,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize the cache.
|
||||
|
||||
Args:
|
||||
max_entries: Maximum cache entries
|
||||
default_ttl_seconds: Default TTL for entries
|
||||
"""
|
||||
self._cache: dict[str, CacheEntry] = {}
|
||||
self._max_entries = max_entries
|
||||
self._default_ttl = default_ttl_seconds
|
||||
self._access_order: list[str] = []
|
||||
logger.info(
|
||||
f"Initialized RetrievalCache with max_entries={max_entries}, "
|
||||
f"ttl={default_ttl_seconds}s"
|
||||
)
|
||||
|
||||
def get(self, query_key: str) -> list[ScoredResult] | None:
|
||||
"""
|
||||
Get cached results for a query.
|
||||
|
||||
Args:
|
||||
query_key: Cache key for the query
|
||||
|
||||
Returns:
|
||||
Cached results or None if not found/expired
|
||||
"""
|
||||
if query_key not in self._cache:
|
||||
return None
|
||||
|
||||
entry = self._cache[query_key]
|
||||
if entry.is_expired():
|
||||
del self._cache[query_key]
|
||||
if query_key in self._access_order:
|
||||
self._access_order.remove(query_key)
|
||||
return None
|
||||
|
||||
# Update access order (LRU)
|
||||
if query_key in self._access_order:
|
||||
self._access_order.remove(query_key)
|
||||
self._access_order.append(query_key)
|
||||
|
||||
logger.debug(f"Cache hit for {query_key}")
|
||||
return entry.results
|
||||
|
||||
def put(
|
||||
self,
|
||||
query_key: str,
|
||||
results: list[ScoredResult],
|
||||
ttl_seconds: float | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Cache results for a query.
|
||||
|
||||
Args:
|
||||
query_key: Cache key for the query
|
||||
results: Results to cache
|
||||
ttl_seconds: TTL for this entry (or default)
|
||||
"""
|
||||
# Evict if at capacity
|
||||
while len(self._cache) >= self._max_entries and self._access_order:
|
||||
oldest_key = self._access_order.pop(0)
|
||||
if oldest_key in self._cache:
|
||||
del self._cache[oldest_key]
|
||||
|
||||
entry = CacheEntry(
|
||||
results=results,
|
||||
created_at=_utcnow(),
|
||||
ttl_seconds=ttl_seconds or self._default_ttl,
|
||||
query_key=query_key,
|
||||
)
|
||||
|
||||
self._cache[query_key] = entry
|
||||
self._access_order.append(query_key)
|
||||
logger.debug(f"Cached {len(results)} results for {query_key}")
|
||||
|
||||
def invalidate(self, query_key: str) -> bool:
|
||||
"""
|
||||
Invalidate a specific cache entry.
|
||||
|
||||
Args:
|
||||
query_key: Cache key to invalidate
|
||||
|
||||
Returns:
|
||||
True if entry was found and removed
|
||||
"""
|
||||
if query_key in self._cache:
|
||||
del self._cache[query_key]
|
||||
if query_key in self._access_order:
|
||||
self._access_order.remove(query_key)
|
||||
return True
|
||||
return False
|
||||
|
||||
def invalidate_by_memory(self, memory_id: UUID) -> int:
|
||||
"""
|
||||
Invalidate all cache entries containing a specific memory.
|
||||
|
||||
Args:
|
||||
memory_id: Memory ID to invalidate
|
||||
|
||||
Returns:
|
||||
Number of entries invalidated
|
||||
"""
|
||||
keys_to_remove = []
|
||||
for key, entry in self._cache.items():
|
||||
if any(r.memory_id == memory_id for r in entry.results):
|
||||
keys_to_remove.append(key)
|
||||
|
||||
for key in keys_to_remove:
|
||||
self.invalidate(key)
|
||||
|
||||
if keys_to_remove:
|
||||
logger.debug(
|
||||
f"Invalidated {len(keys_to_remove)} cache entries for {memory_id}"
|
||||
)
|
||||
return len(keys_to_remove)
|
||||
|
||||
def clear(self) -> int:
|
||||
"""
|
||||
Clear all cache entries.
|
||||
|
||||
Returns:
|
||||
Number of entries cleared
|
||||
"""
|
||||
count = len(self._cache)
|
||||
self._cache.clear()
|
||||
self._access_order.clear()
|
||||
logger.info(f"Cleared {count} cache entries")
|
||||
return count
|
||||
|
||||
def get_stats(self) -> dict[str, Any]:
|
||||
"""Get cache statistics."""
|
||||
expired_count = sum(1 for e in self._cache.values() if e.is_expired())
|
||||
return {
|
||||
"total_entries": len(self._cache),
|
||||
"expired_entries": expired_count,
|
||||
"max_entries": self._max_entries,
|
||||
"default_ttl_seconds": self._default_ttl,
|
||||
}
|
||||
|
||||
|
||||
class RetrievalEngine:
|
||||
"""
|
||||
Hybrid retrieval engine for memory search.
|
||||
|
||||
Combines multiple index types for comprehensive retrieval:
|
||||
- Vector search for semantic similarity
|
||||
- Temporal index for time-based filtering
|
||||
- Entity index for entity-based lookups
|
||||
- Outcome index for success/failure filtering
|
||||
|
||||
Results are scored and ranked using relevance scoring.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
indexer: MemoryIndexer | None = None,
|
||||
scorer: RelevanceScorer | None = None,
|
||||
cache: RetrievalCache | None = None,
|
||||
enable_cache: bool = True,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize the retrieval engine.
|
||||
|
||||
Args:
|
||||
indexer: Memory indexer (defaults to singleton)
|
||||
scorer: Relevance scorer (defaults to new instance)
|
||||
cache: Retrieval cache (defaults to new instance)
|
||||
enable_cache: Whether to enable result caching
|
||||
"""
|
||||
self._indexer = indexer or get_memory_indexer()
|
||||
self._scorer = scorer or RelevanceScorer()
|
||||
self._cache = cache or RetrievalCache() if enable_cache else None
|
||||
self._enable_cache = enable_cache
|
||||
logger.info(f"Initialized RetrievalEngine with cache={enable_cache}")
|
||||
|
||||
async def retrieve(
|
||||
self,
|
||||
query: RetrievalQuery,
|
||||
use_cache: bool = True,
|
||||
) -> RetrievalResult[ScoredResult]:
|
||||
"""
|
||||
Retrieve relevant memories using hybrid search.
|
||||
|
||||
Args:
|
||||
query: Retrieval query parameters
|
||||
use_cache: Whether to use cached results
|
||||
|
||||
Returns:
|
||||
Retrieval result with scored items
|
||||
"""
|
||||
start_time = _utcnow()
|
||||
|
||||
# Check cache
|
||||
cache_key = query.to_cache_key()
|
||||
if use_cache and self._cache:
|
||||
cached = self._cache.get(cache_key)
|
||||
if cached:
|
||||
latency = (_utcnow() - start_time).total_seconds() * 1000
|
||||
return RetrievalResult(
|
||||
items=cached,
|
||||
total_count=len(cached),
|
||||
query=query.query_text or "",
|
||||
retrieval_type="cached",
|
||||
latency_ms=latency,
|
||||
metadata={"cache_hit": True},
|
||||
)
|
||||
|
||||
# Collect candidates from each index
|
||||
candidates: dict[UUID, dict[str, Any]] = {}
|
||||
|
||||
# Vector search
|
||||
if query.use_vector and query.query_embedding:
|
||||
vector_results = await self._indexer.vector_index.search(
|
||||
query=query.query_embedding,
|
||||
limit=query.limit * 3, # Get more for filtering
|
||||
min_similarity=query.min_relevance,
|
||||
memory_type=query.memory_types[0] if query.memory_types else None,
|
||||
)
|
||||
for entry in vector_results:
|
||||
if entry.memory_id not in candidates:
|
||||
candidates[entry.memory_id] = {
|
||||
"memory_type": entry.memory_type,
|
||||
"sources": [],
|
||||
}
|
||||
candidates[entry.memory_id]["vector_similarity"] = entry.metadata.get(
|
||||
"similarity", 0.5
|
||||
)
|
||||
candidates[entry.memory_id]["sources"].append("vector")
|
||||
|
||||
# Temporal search
|
||||
if query.use_temporal and (
|
||||
query.start_time or query.end_time or query.recent_seconds
|
||||
):
|
||||
temporal_results = await self._indexer.temporal_index.search(
|
||||
query=None,
|
||||
limit=query.limit * 3,
|
||||
start_time=query.start_time,
|
||||
end_time=query.end_time,
|
||||
recent_seconds=query.recent_seconds,
|
||||
memory_type=query.memory_types[0] if query.memory_types else None,
|
||||
)
|
||||
for temporal_entry in temporal_results:
|
||||
if temporal_entry.memory_id not in candidates:
|
||||
candidates[temporal_entry.memory_id] = {
|
||||
"memory_type": temporal_entry.memory_type,
|
||||
"sources": [],
|
||||
}
|
||||
candidates[temporal_entry.memory_id]["timestamp"] = (
|
||||
temporal_entry.timestamp
|
||||
)
|
||||
candidates[temporal_entry.memory_id]["sources"].append("temporal")
|
||||
|
||||
# Entity search
|
||||
if query.use_entity and query.entities:
|
||||
entity_results = await self._indexer.entity_index.search(
|
||||
query=None,
|
||||
limit=query.limit * 3,
|
||||
entities=query.entities,
|
||||
match_all=query.entity_match_all,
|
||||
memory_type=query.memory_types[0] if query.memory_types else None,
|
||||
)
|
||||
for entity_entry in entity_results:
|
||||
if entity_entry.memory_id not in candidates:
|
||||
candidates[entity_entry.memory_id] = {
|
||||
"memory_type": entity_entry.memory_type,
|
||||
"sources": [],
|
||||
}
|
||||
# Count entity matches
|
||||
entity_count = candidates[entity_entry.memory_id].get(
|
||||
"entity_match_count", 0
|
||||
)
|
||||
candidates[entity_entry.memory_id]["entity_match_count"] = (
|
||||
entity_count + 1
|
||||
)
|
||||
candidates[entity_entry.memory_id]["sources"].append("entity")
|
||||
|
||||
# Outcome search
|
||||
if query.use_outcome and query.outcomes:
|
||||
outcome_results = await self._indexer.outcome_index.search(
|
||||
query=None,
|
||||
limit=query.limit * 3,
|
||||
outcomes=query.outcomes,
|
||||
memory_type=query.memory_types[0] if query.memory_types else None,
|
||||
)
|
||||
for outcome_entry in outcome_results:
|
||||
if outcome_entry.memory_id not in candidates:
|
||||
candidates[outcome_entry.memory_id] = {
|
||||
"memory_type": outcome_entry.memory_type,
|
||||
"sources": [],
|
||||
}
|
||||
candidates[outcome_entry.memory_id]["outcome"] = outcome_entry.outcome
|
||||
candidates[outcome_entry.memory_id]["sources"].append("outcome")
|
||||
|
||||
# Score and rank candidates
|
||||
scored_results: list[ScoredResult] = []
|
||||
entity_total = len(query.entities) if query.entities else 1
|
||||
|
||||
for memory_id, data in candidates.items():
|
||||
scored = self._scorer.score(
|
||||
memory_id=memory_id,
|
||||
memory_type=data["memory_type"],
|
||||
vector_similarity=data.get("vector_similarity"),
|
||||
timestamp=data.get("timestamp"),
|
||||
entity_match_count=data.get("entity_match_count", 0),
|
||||
entity_total=entity_total,
|
||||
outcome=data.get("outcome"),
|
||||
preferred_outcomes=query.outcomes,
|
||||
)
|
||||
scored.metadata["sources"] = data.get("sources", [])
|
||||
|
||||
# Filter by minimum relevance
|
||||
if scored.relevance_score >= query.min_relevance:
|
||||
scored_results.append(scored)
|
||||
|
||||
# Sort by relevance score
|
||||
scored_results.sort(key=lambda x: x.relevance_score, reverse=True)
|
||||
|
||||
# Apply limit
|
||||
final_results = scored_results[: query.limit]
|
||||
|
||||
# Cache results
|
||||
if use_cache and self._cache and final_results:
|
||||
self._cache.put(cache_key, final_results)
|
||||
|
||||
latency = (_utcnow() - start_time).total_seconds() * 1000
|
||||
|
||||
logger.info(
|
||||
f"Retrieved {len(final_results)} results from {len(candidates)} candidates "
|
||||
f"in {latency:.2f}ms"
|
||||
)
|
||||
|
||||
return RetrievalResult(
|
||||
items=final_results,
|
||||
total_count=len(candidates),
|
||||
query=query.query_text or "",
|
||||
retrieval_type="hybrid",
|
||||
latency_ms=latency,
|
||||
metadata={
|
||||
"cache_hit": False,
|
||||
"candidates_count": len(candidates),
|
||||
"filtered_count": len(scored_results),
|
||||
},
|
||||
)
|
||||
|
||||
async def retrieve_similar(
|
||||
self,
|
||||
embedding: list[float],
|
||||
limit: int = 10,
|
||||
min_similarity: float = 0.5,
|
||||
memory_types: list[MemoryType] | None = None,
|
||||
) -> RetrievalResult[ScoredResult]:
|
||||
"""
|
||||
Retrieve memories similar to a given embedding.
|
||||
|
||||
Args:
|
||||
embedding: Query embedding
|
||||
limit: Maximum results
|
||||
min_similarity: Minimum similarity threshold
|
||||
memory_types: Filter by memory types
|
||||
|
||||
Returns:
|
||||
Retrieval result with scored items
|
||||
"""
|
||||
query = RetrievalQuery(
|
||||
query_embedding=embedding,
|
||||
limit=limit,
|
||||
min_relevance=min_similarity,
|
||||
memory_types=memory_types,
|
||||
use_temporal=False,
|
||||
use_entity=False,
|
||||
use_outcome=False,
|
||||
)
|
||||
return await self.retrieve(query)
|
||||
|
||||
async def retrieve_recent(
|
||||
self,
|
||||
hours: float = 24,
|
||||
limit: int = 10,
|
||||
memory_types: list[MemoryType] | None = None,
|
||||
) -> RetrievalResult[ScoredResult]:
|
||||
"""
|
||||
Retrieve recent memories.
|
||||
|
||||
Args:
|
||||
hours: Number of hours to look back
|
||||
limit: Maximum results
|
||||
memory_types: Filter by memory types
|
||||
|
||||
Returns:
|
||||
Retrieval result with scored items
|
||||
"""
|
||||
query = RetrievalQuery(
|
||||
recent_seconds=hours * 3600,
|
||||
limit=limit,
|
||||
memory_types=memory_types,
|
||||
use_vector=False,
|
||||
use_entity=False,
|
||||
use_outcome=False,
|
||||
)
|
||||
return await self.retrieve(query)
|
||||
|
||||
async def retrieve_by_entity(
|
||||
self,
|
||||
entity_type: str,
|
||||
entity_value: str,
|
||||
limit: int = 10,
|
||||
memory_types: list[MemoryType] | None = None,
|
||||
) -> RetrievalResult[ScoredResult]:
|
||||
"""
|
||||
Retrieve memories by entity.
|
||||
|
||||
Args:
|
||||
entity_type: Type of entity
|
||||
entity_value: Entity value
|
||||
limit: Maximum results
|
||||
memory_types: Filter by memory types
|
||||
|
||||
Returns:
|
||||
Retrieval result with scored items
|
||||
"""
|
||||
query = RetrievalQuery(
|
||||
entities=[(entity_type, entity_value)],
|
||||
limit=limit,
|
||||
memory_types=memory_types,
|
||||
use_vector=False,
|
||||
use_temporal=False,
|
||||
use_outcome=False,
|
||||
)
|
||||
return await self.retrieve(query)
|
||||
|
||||
async def retrieve_successful(
|
||||
self,
|
||||
limit: int = 10,
|
||||
memory_types: list[MemoryType] | None = None,
|
||||
) -> RetrievalResult[ScoredResult]:
|
||||
"""
|
||||
Retrieve successful memories.
|
||||
|
||||
Args:
|
||||
limit: Maximum results
|
||||
memory_types: Filter by memory types
|
||||
|
||||
Returns:
|
||||
Retrieval result with scored items
|
||||
"""
|
||||
query = RetrievalQuery(
|
||||
outcomes=[Outcome.SUCCESS],
|
||||
limit=limit,
|
||||
memory_types=memory_types,
|
||||
use_vector=False,
|
||||
use_temporal=False,
|
||||
use_entity=False,
|
||||
)
|
||||
return await self.retrieve(query)
|
||||
|
||||
def invalidate_cache(self) -> int:
|
||||
"""
|
||||
Invalidate all cached results.
|
||||
|
||||
Returns:
|
||||
Number of entries invalidated
|
||||
"""
|
||||
if self._cache:
|
||||
return self._cache.clear()
|
||||
return 0
|
||||
|
||||
def invalidate_cache_for_memory(self, memory_id: UUID) -> int:
|
||||
"""
|
||||
Invalidate cache entries containing a specific memory.
|
||||
|
||||
Args:
|
||||
memory_id: Memory ID to invalidate
|
||||
|
||||
Returns:
|
||||
Number of entries invalidated
|
||||
"""
|
||||
if self._cache:
|
||||
return self._cache.invalidate_by_memory(memory_id)
|
||||
return 0
|
||||
|
||||
def get_cache_stats(self) -> dict[str, Any]:
|
||||
"""Get cache statistics."""
|
||||
if self._cache:
|
||||
return self._cache.get_stats()
|
||||
return {"enabled": False}
|
||||
|
||||
|
||||
# Singleton retrieval engine instance
|
||||
_engine: RetrievalEngine | None = None
|
||||
|
||||
|
||||
def get_retrieval_engine() -> RetrievalEngine:
|
||||
"""Get the singleton retrieval engine instance."""
|
||||
global _engine
|
||||
if _engine is None:
|
||||
_engine = RetrievalEngine()
|
||||
return _engine
|
||||
2
backend/tests/unit/services/memory/indexing/__init__.py
Normal file
2
backend/tests/unit/services/memory/indexing/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
||||
# tests/unit/services/memory/indexing/__init__.py
|
||||
"""Unit tests for memory indexing."""
|
||||
497
backend/tests/unit/services/memory/indexing/test_index.py
Normal file
497
backend/tests/unit/services/memory/indexing/test_index.py
Normal file
@@ -0,0 +1,497 @@
|
||||
# tests/unit/services/memory/indexing/test_index.py
|
||||
"""Unit tests for memory indexing."""
|
||||
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
from app.services.memory.indexing.index import (
|
||||
EntityIndex,
|
||||
MemoryIndexer,
|
||||
OutcomeIndex,
|
||||
TemporalIndex,
|
||||
VectorIndex,
|
||||
get_memory_indexer,
|
||||
)
|
||||
from app.services.memory.types import Episode, Fact, MemoryType, Outcome, Procedure
|
||||
|
||||
|
||||
def _utcnow() -> datetime:
|
||||
"""Get current UTC time."""
|
||||
return datetime.now(UTC)
|
||||
|
||||
|
||||
def make_episode(
|
||||
embedding: list[float] | None = None,
|
||||
outcome: Outcome = Outcome.SUCCESS,
|
||||
occurred_at: datetime | None = None,
|
||||
) -> Episode:
|
||||
"""Create a test episode."""
|
||||
return Episode(
|
||||
id=uuid4(),
|
||||
project_id=uuid4(),
|
||||
agent_instance_id=uuid4(),
|
||||
agent_type_id=uuid4(),
|
||||
session_id="test-session",
|
||||
task_type="test_task",
|
||||
task_description="Test task description",
|
||||
actions=[{"action": "test"}],
|
||||
context_summary="Test context",
|
||||
outcome=outcome,
|
||||
outcome_details="Test outcome",
|
||||
duration_seconds=10.0,
|
||||
tokens_used=100,
|
||||
lessons_learned=["lesson1"],
|
||||
importance_score=0.8,
|
||||
embedding=embedding,
|
||||
occurred_at=occurred_at or _utcnow(),
|
||||
created_at=_utcnow(),
|
||||
updated_at=_utcnow(),
|
||||
)
|
||||
|
||||
|
||||
def make_fact(
|
||||
embedding: list[float] | None = None,
|
||||
subject: str = "test_subject",
|
||||
predicate: str = "has_property",
|
||||
obj: str = "test_value",
|
||||
) -> Fact:
|
||||
"""Create a test fact."""
|
||||
return Fact(
|
||||
id=uuid4(),
|
||||
project_id=uuid4(),
|
||||
subject=subject,
|
||||
predicate=predicate,
|
||||
object=obj,
|
||||
confidence=0.9,
|
||||
source_episode_ids=[uuid4()],
|
||||
first_learned=_utcnow(),
|
||||
last_reinforced=_utcnow(),
|
||||
reinforcement_count=1,
|
||||
embedding=embedding,
|
||||
created_at=_utcnow(),
|
||||
updated_at=_utcnow(),
|
||||
)
|
||||
|
||||
|
||||
def make_procedure(
|
||||
embedding: list[float] | None = None,
|
||||
success_count: int = 8,
|
||||
failure_count: int = 2,
|
||||
) -> Procedure:
|
||||
"""Create a test procedure."""
|
||||
return Procedure(
|
||||
id=uuid4(),
|
||||
project_id=uuid4(),
|
||||
agent_type_id=uuid4(),
|
||||
name="test_procedure",
|
||||
trigger_pattern="test.*",
|
||||
steps=[{"step": 1, "action": "test"}],
|
||||
success_count=success_count,
|
||||
failure_count=failure_count,
|
||||
last_used=_utcnow(),
|
||||
embedding=embedding,
|
||||
created_at=_utcnow(),
|
||||
updated_at=_utcnow(),
|
||||
)
|
||||
|
||||
|
||||
class TestVectorIndex:
|
||||
"""Tests for VectorIndex."""
|
||||
|
||||
@pytest.fixture
|
||||
def index(self) -> VectorIndex[Episode]:
|
||||
"""Create a vector index."""
|
||||
return VectorIndex[Episode](dimension=4)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_add_item(self, index: VectorIndex[Episode]) -> None:
|
||||
"""Test adding an item to the index."""
|
||||
episode = make_episode(embedding=[1.0, 0.0, 0.0, 0.0])
|
||||
|
||||
entry = await index.add(episode)
|
||||
|
||||
assert entry.memory_id == episode.id
|
||||
assert entry.memory_type == MemoryType.EPISODIC
|
||||
assert entry.dimension == 4
|
||||
assert await index.count() == 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_remove_item(self, index: VectorIndex[Episode]) -> None:
|
||||
"""Test removing an item from the index."""
|
||||
episode = make_episode(embedding=[1.0, 0.0, 0.0, 0.0])
|
||||
await index.add(episode)
|
||||
|
||||
result = await index.remove(episode.id)
|
||||
|
||||
assert result is True
|
||||
assert await index.count() == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_remove_nonexistent(self, index: VectorIndex[Episode]) -> None:
|
||||
"""Test removing a nonexistent item."""
|
||||
result = await index.remove(uuid4())
|
||||
assert result is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_search_similar(self, index: VectorIndex[Episode]) -> None:
|
||||
"""Test searching for similar items."""
|
||||
# Add items with different embeddings
|
||||
e1 = make_episode(embedding=[1.0, 0.0, 0.0, 0.0])
|
||||
e2 = make_episode(embedding=[0.9, 0.1, 0.0, 0.0])
|
||||
e3 = make_episode(embedding=[0.0, 1.0, 0.0, 0.0])
|
||||
|
||||
await index.add(e1)
|
||||
await index.add(e2)
|
||||
await index.add(e3)
|
||||
|
||||
# Search for similar to first
|
||||
results = await index.search([1.0, 0.0, 0.0, 0.0], limit=2)
|
||||
|
||||
assert len(results) == 2
|
||||
# First result should be most similar
|
||||
assert results[0].memory_id == e1.id
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_search_min_similarity(self, index: VectorIndex[Episode]) -> None:
|
||||
"""Test minimum similarity threshold."""
|
||||
e1 = make_episode(embedding=[1.0, 0.0, 0.0, 0.0])
|
||||
e2 = make_episode(embedding=[0.0, 1.0, 0.0, 0.0]) # Orthogonal
|
||||
|
||||
await index.add(e1)
|
||||
await index.add(e2)
|
||||
|
||||
# Search with high threshold
|
||||
results = await index.search([1.0, 0.0, 0.0, 0.0], min_similarity=0.9)
|
||||
|
||||
assert len(results) == 1
|
||||
assert results[0].memory_id == e1.id
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_search_empty_query(self, index: VectorIndex[Episode]) -> None:
|
||||
"""Test search with empty query."""
|
||||
e1 = make_episode(embedding=[1.0, 0.0, 0.0, 0.0])
|
||||
await index.add(e1)
|
||||
|
||||
results = await index.search([], limit=10)
|
||||
assert len(results) == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_clear(self, index: VectorIndex[Episode]) -> None:
|
||||
"""Test clearing the index."""
|
||||
await index.add(make_episode(embedding=[1.0, 0.0, 0.0, 0.0]))
|
||||
await index.add(make_episode(embedding=[0.0, 1.0, 0.0, 0.0]))
|
||||
|
||||
count = await index.clear()
|
||||
|
||||
assert count == 2
|
||||
assert await index.count() == 0
|
||||
|
||||
|
||||
class TestTemporalIndex:
|
||||
"""Tests for TemporalIndex."""
|
||||
|
||||
@pytest.fixture
|
||||
def index(self) -> TemporalIndex[Episode]:
|
||||
"""Create a temporal index."""
|
||||
return TemporalIndex[Episode]()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_add_item(self, index: TemporalIndex[Episode]) -> None:
|
||||
"""Test adding an item."""
|
||||
episode = make_episode()
|
||||
entry = await index.add(episode)
|
||||
|
||||
assert entry.memory_id == episode.id
|
||||
assert await index.count() == 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_search_by_time_range(self, index: TemporalIndex[Episode]) -> None:
|
||||
"""Test searching by time range."""
|
||||
now = _utcnow()
|
||||
old = make_episode(occurred_at=now - timedelta(hours=2))
|
||||
recent = make_episode(occurred_at=now - timedelta(hours=1))
|
||||
newest = make_episode(occurred_at=now)
|
||||
|
||||
await index.add(old)
|
||||
await index.add(recent)
|
||||
await index.add(newest)
|
||||
|
||||
# Search last hour
|
||||
results = await index.search(
|
||||
query=None,
|
||||
start_time=now - timedelta(hours=1, minutes=30),
|
||||
end_time=now,
|
||||
)
|
||||
|
||||
assert len(results) == 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_search_recent(self, index: TemporalIndex[Episode]) -> None:
|
||||
"""Test searching for recent items."""
|
||||
now = _utcnow()
|
||||
old = make_episode(occurred_at=now - timedelta(hours=2))
|
||||
recent = make_episode(occurred_at=now - timedelta(minutes=30))
|
||||
|
||||
await index.add(old)
|
||||
await index.add(recent)
|
||||
|
||||
# Search last hour (3600 seconds)
|
||||
results = await index.search(query=None, recent_seconds=3600)
|
||||
|
||||
assert len(results) == 1
|
||||
assert results[0].memory_id == recent.id
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_search_order(self, index: TemporalIndex[Episode]) -> None:
|
||||
"""Test result ordering."""
|
||||
now = _utcnow()
|
||||
e1 = make_episode(occurred_at=now - timedelta(hours=2))
|
||||
e2 = make_episode(occurred_at=now - timedelta(hours=1))
|
||||
e3 = make_episode(occurred_at=now)
|
||||
|
||||
await index.add(e1)
|
||||
await index.add(e2)
|
||||
await index.add(e3)
|
||||
|
||||
# Descending order (newest first)
|
||||
results_desc = await index.search(query=None, order="desc", limit=10)
|
||||
assert results_desc[0].memory_id == e3.id
|
||||
|
||||
# Ascending order (oldest first)
|
||||
results_asc = await index.search(query=None, order="asc", limit=10)
|
||||
assert results_asc[0].memory_id == e1.id
|
||||
|
||||
|
||||
class TestEntityIndex:
|
||||
"""Tests for EntityIndex."""
|
||||
|
||||
@pytest.fixture
|
||||
def index(self) -> EntityIndex[Fact]:
|
||||
"""Create an entity index."""
|
||||
return EntityIndex[Fact]()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_add_item(self, index: EntityIndex[Fact]) -> None:
|
||||
"""Test adding an item."""
|
||||
fact = make_fact(subject="user", obj="admin")
|
||||
entry = await index.add(fact)
|
||||
|
||||
assert entry.memory_id == fact.id
|
||||
assert await index.count() == 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_search_by_entity(self, index: EntityIndex[Fact]) -> None:
|
||||
"""Test searching by entity."""
|
||||
f1 = make_fact(subject="user", obj="admin")
|
||||
f2 = make_fact(subject="system", obj="config")
|
||||
|
||||
await index.add(f1)
|
||||
await index.add(f2)
|
||||
|
||||
results = await index.search(
|
||||
query=None,
|
||||
entity_type="subject",
|
||||
entity_value="user",
|
||||
)
|
||||
|
||||
assert len(results) == 1
|
||||
assert results[0].memory_id == f1.id
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_search_multiple_entities(self, index: EntityIndex[Fact]) -> None:
|
||||
"""Test searching with multiple entities."""
|
||||
f1 = make_fact(subject="user", obj="admin")
|
||||
f2 = make_fact(subject="user", obj="guest")
|
||||
|
||||
await index.add(f1)
|
||||
await index.add(f2)
|
||||
|
||||
# Search for facts about "user" subject
|
||||
results = await index.search(
|
||||
query=None,
|
||||
entities=[("subject", "user")],
|
||||
)
|
||||
|
||||
assert len(results) == 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_search_match_all(self, index: EntityIndex[Fact]) -> None:
|
||||
"""Test matching all entities."""
|
||||
f1 = make_fact(subject="user", obj="admin")
|
||||
f2 = make_fact(subject="user", obj="guest")
|
||||
|
||||
await index.add(f1)
|
||||
await index.add(f2)
|
||||
|
||||
# Search for user+admin (match all)
|
||||
results = await index.search(
|
||||
query=None,
|
||||
entities=[("subject", "user"), ("object", "admin")],
|
||||
match_all=True,
|
||||
)
|
||||
|
||||
assert len(results) == 1
|
||||
assert results[0].memory_id == f1.id
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_entities(self, index: EntityIndex[Fact]) -> None:
|
||||
"""Test getting entities for a memory."""
|
||||
fact = make_fact(subject="user", obj="admin")
|
||||
await index.add(fact)
|
||||
|
||||
entities = await index.get_entities(fact.id)
|
||||
|
||||
assert ("subject", "user") in entities
|
||||
assert ("object", "admin") in entities
|
||||
|
||||
|
||||
class TestOutcomeIndex:
|
||||
"""Tests for OutcomeIndex."""
|
||||
|
||||
@pytest.fixture
|
||||
def index(self) -> OutcomeIndex[Episode]:
|
||||
"""Create an outcome index."""
|
||||
return OutcomeIndex[Episode]()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_add_item(self, index: OutcomeIndex[Episode]) -> None:
|
||||
"""Test adding an item."""
|
||||
episode = make_episode(outcome=Outcome.SUCCESS)
|
||||
entry = await index.add(episode)
|
||||
|
||||
assert entry.memory_id == episode.id
|
||||
assert entry.outcome == Outcome.SUCCESS
|
||||
assert await index.count() == 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_search_by_outcome(self, index: OutcomeIndex[Episode]) -> None:
|
||||
"""Test searching by outcome."""
|
||||
success = make_episode(outcome=Outcome.SUCCESS)
|
||||
failure = make_episode(outcome=Outcome.FAILURE)
|
||||
|
||||
await index.add(success)
|
||||
await index.add(failure)
|
||||
|
||||
results = await index.search(query=None, outcome=Outcome.SUCCESS)
|
||||
|
||||
assert len(results) == 1
|
||||
assert results[0].memory_id == success.id
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_search_multiple_outcomes(self, index: OutcomeIndex[Episode]) -> None:
|
||||
"""Test searching with multiple outcomes."""
|
||||
success = make_episode(outcome=Outcome.SUCCESS)
|
||||
partial = make_episode(outcome=Outcome.PARTIAL)
|
||||
failure = make_episode(outcome=Outcome.FAILURE)
|
||||
|
||||
await index.add(success)
|
||||
await index.add(partial)
|
||||
await index.add(failure)
|
||||
|
||||
results = await index.search(
|
||||
query=None,
|
||||
outcomes=[Outcome.SUCCESS, Outcome.PARTIAL],
|
||||
)
|
||||
|
||||
assert len(results) == 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_outcome_stats(self, index: OutcomeIndex[Episode]) -> None:
|
||||
"""Test getting outcome statistics."""
|
||||
await index.add(make_episode(outcome=Outcome.SUCCESS))
|
||||
await index.add(make_episode(outcome=Outcome.SUCCESS))
|
||||
await index.add(make_episode(outcome=Outcome.FAILURE))
|
||||
|
||||
stats = await index.get_outcome_stats()
|
||||
|
||||
assert stats[Outcome.SUCCESS] == 2
|
||||
assert stats[Outcome.FAILURE] == 1
|
||||
assert stats[Outcome.PARTIAL] == 0
|
||||
|
||||
|
||||
class TestMemoryIndexer:
|
||||
"""Tests for MemoryIndexer."""
|
||||
|
||||
@pytest.fixture
|
||||
def indexer(self) -> MemoryIndexer:
|
||||
"""Create a memory indexer."""
|
||||
return MemoryIndexer()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_index_episode(self, indexer: MemoryIndexer) -> None:
|
||||
"""Test indexing an episode."""
|
||||
episode = make_episode(embedding=[1.0, 0.0, 0.0, 0.0])
|
||||
|
||||
results = await indexer.index(episode)
|
||||
|
||||
assert "vector" in results
|
||||
assert "temporal" in results
|
||||
assert "entity" in results
|
||||
assert "outcome" in results
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_index_fact(self, indexer: MemoryIndexer) -> None:
|
||||
"""Test indexing a fact."""
|
||||
fact = make_fact(embedding=[1.0, 0.0, 0.0, 0.0])
|
||||
|
||||
results = await indexer.index(fact)
|
||||
|
||||
# Facts don't have outcomes
|
||||
assert "vector" in results
|
||||
assert "temporal" in results
|
||||
assert "entity" in results
|
||||
assert "outcome" not in results
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_remove_from_all(self, indexer: MemoryIndexer) -> None:
|
||||
"""Test removing from all indices."""
|
||||
episode = make_episode(embedding=[1.0, 0.0, 0.0, 0.0])
|
||||
await indexer.index(episode)
|
||||
|
||||
results = await indexer.remove(episode.id)
|
||||
|
||||
assert results["vector"] is True
|
||||
assert results["temporal"] is True
|
||||
assert results["entity"] is True
|
||||
assert results["outcome"] is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_clear_all(self, indexer: MemoryIndexer) -> None:
|
||||
"""Test clearing all indices."""
|
||||
await indexer.index(make_episode(embedding=[1.0, 0.0, 0.0, 0.0]))
|
||||
await indexer.index(make_episode(embedding=[0.0, 1.0, 0.0, 0.0]))
|
||||
|
||||
counts = await indexer.clear_all()
|
||||
|
||||
assert counts["vector"] == 2
|
||||
assert counts["temporal"] == 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_stats(self, indexer: MemoryIndexer) -> None:
|
||||
"""Test getting index statistics."""
|
||||
await indexer.index(make_episode(embedding=[1.0, 0.0, 0.0, 0.0]))
|
||||
|
||||
stats = await indexer.get_stats()
|
||||
|
||||
assert stats["vector"] == 1
|
||||
assert stats["temporal"] == 1
|
||||
assert stats["entity"] == 1
|
||||
assert stats["outcome"] == 1
|
||||
|
||||
|
||||
class TestGetMemoryIndexer:
|
||||
"""Tests for singleton getter."""
|
||||
|
||||
def test_returns_instance(self) -> None:
|
||||
"""Test that getter returns instance."""
|
||||
indexer = get_memory_indexer()
|
||||
assert indexer is not None
|
||||
assert isinstance(indexer, MemoryIndexer)
|
||||
|
||||
def test_returns_same_instance(self) -> None:
|
||||
"""Test that getter returns same instance."""
|
||||
indexer1 = get_memory_indexer()
|
||||
indexer2 = get_memory_indexer()
|
||||
assert indexer1 is indexer2
|
||||
450
backend/tests/unit/services/memory/indexing/test_retrieval.py
Normal file
450
backend/tests/unit/services/memory/indexing/test_retrieval.py
Normal file
@@ -0,0 +1,450 @@
|
||||
# tests/unit/services/memory/indexing/test_retrieval.py
|
||||
"""Unit tests for memory retrieval."""
|
||||
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
from app.services.memory.indexing.index import MemoryIndexer
|
||||
from app.services.memory.indexing.retrieval import (
|
||||
RelevanceScorer,
|
||||
RetrievalCache,
|
||||
RetrievalEngine,
|
||||
RetrievalQuery,
|
||||
ScoredResult,
|
||||
get_retrieval_engine,
|
||||
)
|
||||
from app.services.memory.types import Episode, MemoryType, Outcome
|
||||
|
||||
|
||||
def _utcnow() -> datetime:
|
||||
"""Get current UTC time."""
|
||||
return datetime.now(UTC)
|
||||
|
||||
|
||||
def make_episode(
|
||||
embedding: list[float] | None = None,
|
||||
outcome: Outcome = Outcome.SUCCESS,
|
||||
occurred_at: datetime | None = None,
|
||||
task_type: str = "test_task",
|
||||
) -> Episode:
|
||||
"""Create a test episode."""
|
||||
return Episode(
|
||||
id=uuid4(),
|
||||
project_id=uuid4(),
|
||||
agent_instance_id=uuid4(),
|
||||
agent_type_id=uuid4(),
|
||||
session_id="test-session",
|
||||
task_type=task_type,
|
||||
task_description="Test task description",
|
||||
actions=[{"action": "test"}],
|
||||
context_summary="Test context",
|
||||
outcome=outcome,
|
||||
outcome_details="Test outcome",
|
||||
duration_seconds=10.0,
|
||||
tokens_used=100,
|
||||
lessons_learned=["lesson1"],
|
||||
importance_score=0.8,
|
||||
embedding=embedding,
|
||||
occurred_at=occurred_at or _utcnow(),
|
||||
created_at=_utcnow(),
|
||||
updated_at=_utcnow(),
|
||||
)
|
||||
|
||||
|
||||
class TestRetrievalQuery:
|
||||
"""Tests for RetrievalQuery."""
|
||||
|
||||
def test_default_values(self) -> None:
|
||||
"""Test default query values."""
|
||||
query = RetrievalQuery()
|
||||
|
||||
assert query.query_text is None
|
||||
assert query.limit == 10
|
||||
assert query.min_relevance == 0.0
|
||||
assert query.use_vector is True
|
||||
assert query.use_temporal is True
|
||||
|
||||
def test_cache_key_generation(self) -> None:
|
||||
"""Test cache key generation."""
|
||||
query1 = RetrievalQuery(query_text="test", limit=10)
|
||||
query2 = RetrievalQuery(query_text="test", limit=10)
|
||||
query3 = RetrievalQuery(query_text="different", limit=10)
|
||||
|
||||
# Same queries should have same key
|
||||
assert query1.to_cache_key() == query2.to_cache_key()
|
||||
# Different queries should have different keys
|
||||
assert query1.to_cache_key() != query3.to_cache_key()
|
||||
|
||||
|
||||
class TestScoredResult:
|
||||
"""Tests for ScoredResult."""
|
||||
|
||||
def test_creation(self) -> None:
|
||||
"""Test creating a scored result."""
|
||||
result = ScoredResult(
|
||||
memory_id=uuid4(),
|
||||
memory_type=MemoryType.EPISODIC,
|
||||
relevance_score=0.85,
|
||||
score_breakdown={"vector": 0.9, "recency": 0.8},
|
||||
)
|
||||
|
||||
assert result.relevance_score == 0.85
|
||||
assert result.score_breakdown["vector"] == 0.9
|
||||
|
||||
|
||||
class TestRelevanceScorer:
|
||||
"""Tests for RelevanceScorer."""
|
||||
|
||||
@pytest.fixture
|
||||
def scorer(self) -> RelevanceScorer:
|
||||
"""Create a relevance scorer."""
|
||||
return RelevanceScorer()
|
||||
|
||||
def test_score_with_vector(self, scorer: RelevanceScorer) -> None:
|
||||
"""Test scoring with vector similarity."""
|
||||
result = scorer.score(
|
||||
memory_id=uuid4(),
|
||||
memory_type=MemoryType.EPISODIC,
|
||||
vector_similarity=0.9,
|
||||
)
|
||||
|
||||
assert result.relevance_score > 0
|
||||
assert result.score_breakdown["vector"] == 0.9
|
||||
|
||||
def test_score_with_recency(self, scorer: RelevanceScorer) -> None:
|
||||
"""Test scoring with recency."""
|
||||
recent_result = scorer.score(
|
||||
memory_id=uuid4(),
|
||||
memory_type=MemoryType.EPISODIC,
|
||||
timestamp=_utcnow(),
|
||||
)
|
||||
|
||||
old_result = scorer.score(
|
||||
memory_id=uuid4(),
|
||||
memory_type=MemoryType.EPISODIC,
|
||||
timestamp=_utcnow() - timedelta(days=7),
|
||||
)
|
||||
|
||||
# Recent should have higher recency score
|
||||
assert (
|
||||
recent_result.score_breakdown["recency"]
|
||||
> old_result.score_breakdown["recency"]
|
||||
)
|
||||
|
||||
def test_score_with_outcome_preference(self, scorer: RelevanceScorer) -> None:
|
||||
"""Test scoring with outcome preference."""
|
||||
success_result = scorer.score(
|
||||
memory_id=uuid4(),
|
||||
memory_type=MemoryType.EPISODIC,
|
||||
outcome=Outcome.SUCCESS,
|
||||
preferred_outcomes=[Outcome.SUCCESS],
|
||||
)
|
||||
|
||||
failure_result = scorer.score(
|
||||
memory_id=uuid4(),
|
||||
memory_type=MemoryType.EPISODIC,
|
||||
outcome=Outcome.FAILURE,
|
||||
preferred_outcomes=[Outcome.SUCCESS],
|
||||
)
|
||||
|
||||
assert success_result.score_breakdown["outcome"] == 1.0
|
||||
assert failure_result.score_breakdown["outcome"] == 0.0
|
||||
|
||||
def test_score_with_entity_match(self, scorer: RelevanceScorer) -> None:
|
||||
"""Test scoring with entity matches."""
|
||||
full_match = scorer.score(
|
||||
memory_id=uuid4(),
|
||||
memory_type=MemoryType.EPISODIC,
|
||||
entity_match_count=3,
|
||||
entity_total=3,
|
||||
)
|
||||
|
||||
partial_match = scorer.score(
|
||||
memory_id=uuid4(),
|
||||
memory_type=MemoryType.EPISODIC,
|
||||
entity_match_count=1,
|
||||
entity_total=3,
|
||||
)
|
||||
|
||||
assert (
|
||||
full_match.score_breakdown["entity"]
|
||||
> partial_match.score_breakdown["entity"]
|
||||
)
|
||||
|
||||
|
||||
class TestRetrievalCache:
|
||||
"""Tests for RetrievalCache."""
|
||||
|
||||
@pytest.fixture
|
||||
def cache(self) -> RetrievalCache:
|
||||
"""Create a retrieval cache."""
|
||||
return RetrievalCache(max_entries=10, default_ttl_seconds=60)
|
||||
|
||||
def test_put_and_get(self, cache: RetrievalCache) -> None:
|
||||
"""Test putting and getting from cache."""
|
||||
results = [
|
||||
ScoredResult(
|
||||
memory_id=uuid4(),
|
||||
memory_type=MemoryType.EPISODIC,
|
||||
relevance_score=0.8,
|
||||
)
|
||||
]
|
||||
|
||||
cache.put("test_key", results)
|
||||
cached = cache.get("test_key")
|
||||
|
||||
assert cached is not None
|
||||
assert len(cached) == 1
|
||||
|
||||
def test_get_nonexistent(self, cache: RetrievalCache) -> None:
|
||||
"""Test getting nonexistent entry."""
|
||||
result = cache.get("nonexistent")
|
||||
assert result is None
|
||||
|
||||
def test_lru_eviction(self) -> None:
|
||||
"""Test LRU eviction when at capacity."""
|
||||
cache = RetrievalCache(max_entries=2, default_ttl_seconds=60)
|
||||
|
||||
results = [
|
||||
ScoredResult(
|
||||
memory_id=uuid4(),
|
||||
memory_type=MemoryType.EPISODIC,
|
||||
relevance_score=0.8,
|
||||
)
|
||||
]
|
||||
|
||||
cache.put("key1", results)
|
||||
cache.put("key2", results)
|
||||
cache.put("key3", results) # Should evict key1
|
||||
|
||||
assert cache.get("key1") is None
|
||||
assert cache.get("key2") is not None
|
||||
assert cache.get("key3") is not None
|
||||
|
||||
def test_invalidate(self, cache: RetrievalCache) -> None:
|
||||
"""Test invalidating a cache entry."""
|
||||
results = [
|
||||
ScoredResult(
|
||||
memory_id=uuid4(),
|
||||
memory_type=MemoryType.EPISODIC,
|
||||
relevance_score=0.8,
|
||||
)
|
||||
]
|
||||
|
||||
cache.put("test_key", results)
|
||||
removed = cache.invalidate("test_key")
|
||||
|
||||
assert removed is True
|
||||
assert cache.get("test_key") is None
|
||||
|
||||
def test_invalidate_by_memory(self, cache: RetrievalCache) -> None:
|
||||
"""Test invalidating by memory ID."""
|
||||
memory_id = uuid4()
|
||||
results = [
|
||||
ScoredResult(
|
||||
memory_id=memory_id,
|
||||
memory_type=MemoryType.EPISODIC,
|
||||
relevance_score=0.8,
|
||||
)
|
||||
]
|
||||
|
||||
cache.put("key1", results)
|
||||
cache.put("key2", results)
|
||||
|
||||
count = cache.invalidate_by_memory(memory_id)
|
||||
|
||||
assert count == 2
|
||||
assert cache.get("key1") is None
|
||||
assert cache.get("key2") is None
|
||||
|
||||
def test_clear(self, cache: RetrievalCache) -> None:
|
||||
"""Test clearing the cache."""
|
||||
results = [
|
||||
ScoredResult(
|
||||
memory_id=uuid4(),
|
||||
memory_type=MemoryType.EPISODIC,
|
||||
relevance_score=0.8,
|
||||
)
|
||||
]
|
||||
|
||||
cache.put("key1", results)
|
||||
cache.put("key2", results)
|
||||
|
||||
count = cache.clear()
|
||||
|
||||
assert count == 2
|
||||
assert cache.get("key1") is None
|
||||
|
||||
def test_get_stats(self, cache: RetrievalCache) -> None:
|
||||
"""Test getting cache statistics."""
|
||||
stats = cache.get_stats()
|
||||
|
||||
assert "total_entries" in stats
|
||||
assert "max_entries" in stats
|
||||
assert stats["max_entries"] == 10
|
||||
|
||||
|
||||
class TestRetrievalEngine:
|
||||
"""Tests for RetrievalEngine."""
|
||||
|
||||
@pytest.fixture
|
||||
def indexer(self) -> MemoryIndexer:
|
||||
"""Create a memory indexer."""
|
||||
return MemoryIndexer()
|
||||
|
||||
@pytest.fixture
|
||||
def engine(self, indexer: MemoryIndexer) -> RetrievalEngine:
|
||||
"""Create a retrieval engine."""
|
||||
return RetrievalEngine(indexer=indexer, enable_cache=True)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_retrieve_by_vector(
|
||||
self, engine: RetrievalEngine, indexer: MemoryIndexer
|
||||
) -> None:
|
||||
"""Test retrieval by vector similarity."""
|
||||
e1 = make_episode(embedding=[1.0, 0.0, 0.0, 0.0])
|
||||
e2 = make_episode(embedding=[0.9, 0.1, 0.0, 0.0])
|
||||
e3 = make_episode(embedding=[0.0, 1.0, 0.0, 0.0])
|
||||
|
||||
await indexer.index(e1)
|
||||
await indexer.index(e2)
|
||||
await indexer.index(e3)
|
||||
|
||||
query = RetrievalQuery(
|
||||
query_embedding=[1.0, 0.0, 0.0, 0.0],
|
||||
limit=2,
|
||||
use_temporal=False,
|
||||
use_entity=False,
|
||||
use_outcome=False,
|
||||
)
|
||||
|
||||
result = await engine.retrieve(query)
|
||||
|
||||
assert len(result.items) > 0
|
||||
assert result.retrieval_type == "hybrid"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_retrieve_recent(
|
||||
self, engine: RetrievalEngine, indexer: MemoryIndexer
|
||||
) -> None:
|
||||
"""Test retrieval of recent items."""
|
||||
now = _utcnow()
|
||||
old = make_episode(occurred_at=now - timedelta(hours=2))
|
||||
recent = make_episode(occurred_at=now - timedelta(minutes=30))
|
||||
|
||||
await indexer.index(old)
|
||||
await indexer.index(recent)
|
||||
|
||||
result = await engine.retrieve_recent(hours=1)
|
||||
|
||||
assert len(result.items) == 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_retrieve_by_entity(
|
||||
self, engine: RetrievalEngine, indexer: MemoryIndexer
|
||||
) -> None:
|
||||
"""Test retrieval by entity."""
|
||||
e1 = make_episode(task_type="deploy")
|
||||
e2 = make_episode(task_type="test")
|
||||
|
||||
await indexer.index(e1)
|
||||
await indexer.index(e2)
|
||||
|
||||
result = await engine.retrieve_by_entity("task_type", "deploy")
|
||||
|
||||
assert len(result.items) == 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_retrieve_successful(
|
||||
self, engine: RetrievalEngine, indexer: MemoryIndexer
|
||||
) -> None:
|
||||
"""Test retrieval of successful items."""
|
||||
success = make_episode(outcome=Outcome.SUCCESS)
|
||||
failure = make_episode(outcome=Outcome.FAILURE)
|
||||
|
||||
await indexer.index(success)
|
||||
await indexer.index(failure)
|
||||
|
||||
result = await engine.retrieve_successful()
|
||||
|
||||
assert len(result.items) == 1
|
||||
# Check outcome index was used
|
||||
assert result.items[0].memory_id == success.id
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_retrieve_with_cache(
|
||||
self, engine: RetrievalEngine, indexer: MemoryIndexer
|
||||
) -> None:
|
||||
"""Test that retrieval uses cache."""
|
||||
episode = make_episode(embedding=[1.0, 0.0, 0.0, 0.0])
|
||||
await indexer.index(episode)
|
||||
|
||||
query = RetrievalQuery(
|
||||
query_embedding=[1.0, 0.0, 0.0, 0.0],
|
||||
limit=10,
|
||||
)
|
||||
|
||||
# First retrieval
|
||||
result1 = await engine.retrieve(query)
|
||||
assert result1.metadata.get("cache_hit") is False
|
||||
|
||||
# Second retrieval should be cached
|
||||
result2 = await engine.retrieve(query)
|
||||
assert result2.metadata.get("cache_hit") is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invalidate_cache(
|
||||
self, engine: RetrievalEngine, indexer: MemoryIndexer
|
||||
) -> None:
|
||||
"""Test cache invalidation."""
|
||||
episode = make_episode(embedding=[1.0, 0.0, 0.0, 0.0])
|
||||
await indexer.index(episode)
|
||||
|
||||
query = RetrievalQuery(query_embedding=[1.0, 0.0, 0.0, 0.0])
|
||||
await engine.retrieve(query)
|
||||
|
||||
count = engine.invalidate_cache()
|
||||
|
||||
assert count > 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_retrieve_similar(
|
||||
self, engine: RetrievalEngine, indexer: MemoryIndexer
|
||||
) -> None:
|
||||
"""Test retrieve_similar convenience method."""
|
||||
e1 = make_episode(embedding=[1.0, 0.0, 0.0, 0.0])
|
||||
e2 = make_episode(embedding=[0.0, 1.0, 0.0, 0.0])
|
||||
|
||||
await indexer.index(e1)
|
||||
await indexer.index(e2)
|
||||
|
||||
result = await engine.retrieve_similar(
|
||||
embedding=[1.0, 0.0, 0.0, 0.0],
|
||||
limit=1,
|
||||
)
|
||||
|
||||
assert len(result.items) == 1
|
||||
|
||||
def test_get_cache_stats(self, engine: RetrievalEngine) -> None:
|
||||
"""Test getting cache statistics."""
|
||||
stats = engine.get_cache_stats()
|
||||
|
||||
assert "total_entries" in stats
|
||||
|
||||
|
||||
class TestGetRetrievalEngine:
|
||||
"""Tests for singleton getter."""
|
||||
|
||||
def test_returns_instance(self) -> None:
|
||||
"""Test that getter returns instance."""
|
||||
engine = get_retrieval_engine()
|
||||
assert engine is not None
|
||||
assert isinstance(engine, RetrievalEngine)
|
||||
|
||||
def test_returns_same_instance(self) -> None:
|
||||
"""Test that getter returns same instance."""
|
||||
engine1 = get_retrieval_engine()
|
||||
engine2 = get_retrieval_engine()
|
||||
assert engine1 is engine2
|
||||
Reference in New Issue
Block a user