feat(memory): implement memory indexing and retrieval engine (#94)

Add comprehensive indexing and retrieval system for memory search:
- VectorIndex for semantic similarity search using cosine similarity
- TemporalIndex for time-based queries with range and recency support
- EntityIndex for entity-based lookups with multi-entity intersection
- OutcomeIndex for success/failure filtering on episodes
- MemoryIndexer as unified interface for all index types
- RetrievalEngine with hybrid search combining all indices
- RelevanceScorer for multi-signal relevance scoring
- RetrievalCache for LRU caching of search results
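A minimal usage sketch of the new API (illustrative only, not part of this diff; the helper name and the already-constructed Episode with an embedding are assumptions):

    from app.services.memory.indexing import (
        RetrievalQuery,
        get_memory_indexer,
        get_retrieval_engine,
    )
    from app.services.memory.types import MemoryType, Outcome

    async def recall_similar_successes(episode) -> None:
        # Fan the episode out to the vector/temporal/entity/outcome indices.
        await get_memory_indexer().index(episode)

        # Hybrid retrieval: semantic similarity + recency + outcome, ranked by relevance.
        result = await get_retrieval_engine().retrieve(
            RetrievalQuery(
                query_embedding=episode.embedding,
                recent_seconds=24 * 3600,
                outcomes=[Outcome.SUCCESS],
                memory_types=[MemoryType.EPISODIC],
                limit=5,
            )
        )
        for item in result.items:
            print(item.memory_id, round(item.relevance_score, 3), item.score_breakdown)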

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-05 02:50:13 +01:00
parent 48ecb40f18
commit 999b7ac03f
6 changed files with 2602 additions and 3 deletions


@@ -1,7 +1,56 @@
# app/services/memory/indexing/__init__.py
"""
Memory Indexing & Retrieval.
Provides vector embeddings and multiple index types for efficient memory search:
- Vector index for semantic similarity
- Temporal index for time-based queries
- Entity index for entity lookups
- Outcome index for success/failure filtering
"""
from .index import (
EntityIndex,
EntityIndexEntry,
IndexEntry,
MemoryIndex,
MemoryIndexer,
OutcomeIndex,
OutcomeIndexEntry,
TemporalIndex,
TemporalIndexEntry,
VectorIndex,
VectorIndexEntry,
get_memory_indexer,
)
from .retrieval import (
CacheEntry,
RelevanceScorer,
RetrievalCache,
RetrievalEngine,
RetrievalQuery,
ScoredResult,
get_retrieval_engine,
)
__all__ = [
"CacheEntry",
"EntityIndex",
"EntityIndexEntry",
"IndexEntry",
"MemoryIndex",
"MemoryIndexer",
"OutcomeIndex",
"OutcomeIndexEntry",
"RelevanceScorer",
"RetrievalCache",
"RetrievalEngine",
"RetrievalQuery",
"ScoredResult",
"TemporalIndex",
"TemporalIndexEntry",
"VectorIndex",
"VectorIndexEntry",
"get_memory_indexer",
"get_retrieval_engine",
]


@@ -0,0 +1,851 @@
# app/services/memory/indexing/index.py
"""
Memory Indexing.
Provides multiple indexing strategies for efficient memory retrieval:
- Vector embeddings for semantic search
- Temporal index for time-based queries
- Entity index for entity-based lookups
- Outcome index for success/failure filtering
"""
import logging
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from datetime import UTC, datetime, timedelta
from typing import Any, TypeVar
from uuid import UUID
from app.services.memory.types import Episode, Fact, MemoryType, Outcome, Procedure
logger = logging.getLogger(__name__)
T = TypeVar("T", Episode, Fact, Procedure)
def _utcnow() -> datetime:
"""Get current UTC time as timezone-aware datetime."""
return datetime.now(UTC)
@dataclass
class IndexEntry:
"""A single entry in an index."""
memory_id: UUID
memory_type: MemoryType
indexed_at: datetime = field(default_factory=_utcnow)
metadata: dict[str, Any] = field(default_factory=dict)
@dataclass
class VectorIndexEntry(IndexEntry):
"""An entry with vector embedding."""
embedding: list[float] = field(default_factory=list)
dimension: int = 0
def __post_init__(self) -> None:
"""Set dimension from embedding."""
if self.embedding:
self.dimension = len(self.embedding)
@dataclass
class TemporalIndexEntry(IndexEntry):
"""An entry indexed by time."""
timestamp: datetime = field(default_factory=_utcnow)
@dataclass
class EntityIndexEntry(IndexEntry):
"""An entry indexed by entity."""
entity_type: str = ""
entity_value: str = ""
@dataclass
class OutcomeIndexEntry(IndexEntry):
"""An entry indexed by outcome."""
outcome: Outcome = Outcome.SUCCESS
class MemoryIndex[T](ABC):
"""Abstract base class for memory indices."""
@abstractmethod
async def add(self, item: T) -> IndexEntry:
"""Add an item to the index."""
...
@abstractmethod
async def remove(self, memory_id: UUID) -> bool:
"""Remove an item from the index."""
...
@abstractmethod
async def search(
self,
query: Any,
limit: int = 10,
**kwargs: Any,
) -> list[IndexEntry]:
"""Search the index."""
...
@abstractmethod
async def clear(self) -> int:
"""Clear all entries from the index."""
...
@abstractmethod
async def count(self) -> int:
"""Get the number of entries in the index."""
...
class VectorIndex(MemoryIndex[T]):
"""
Vector-based index using embeddings for semantic similarity search.
Uses cosine similarity for matching.
"""
def __init__(self, dimension: int = 1536) -> None:
"""
Initialize the vector index.
Args:
dimension: Embedding dimension (default 1536 for OpenAI)
"""
self._dimension = dimension
self._entries: dict[UUID, VectorIndexEntry] = {}
logger.info(f"Initialized VectorIndex with dimension={dimension}")
async def add(self, item: T) -> VectorIndexEntry:
"""
Add an item to the vector index.
Args:
item: Memory item with embedding
Returns:
The created index entry
"""
embedding = getattr(item, "embedding", None) or []
entry = VectorIndexEntry(
memory_id=item.id,
memory_type=self._get_memory_type(item),
embedding=embedding,
dimension=len(embedding),
)
self._entries[item.id] = entry
logger.debug(f"Added {item.id} to vector index")
return entry
async def remove(self, memory_id: UUID) -> bool:
"""Remove an item from the vector index."""
if memory_id in self._entries:
del self._entries[memory_id]
logger.debug(f"Removed {memory_id} from vector index")
return True
return False
async def search( # type: ignore[override]
self,
query: Any,
limit: int = 10,
min_similarity: float = 0.0,
**kwargs: Any,
) -> list[VectorIndexEntry]:
"""
Search for similar items using vector similarity.
Args:
query: Query embedding vector
limit: Maximum results to return
min_similarity: Minimum similarity threshold (0-1)
**kwargs: Additional filter parameters
Returns:
List of matching entries sorted by similarity
"""
if not isinstance(query, list) or not query:
return []
results: list[tuple[float, VectorIndexEntry]] = []
for entry in self._entries.values():
if not entry.embedding:
continue
similarity = self._cosine_similarity(query, entry.embedding)
if similarity >= min_similarity:
results.append((similarity, entry))
# Sort by similarity descending
results.sort(key=lambda x: x[0], reverse=True)
# Apply memory type filter if provided
memory_type = kwargs.get("memory_type")
if memory_type:
results = [(s, e) for s, e in results if e.memory_type == memory_type]
# Store similarity in metadata for the returned entries
output = []
for similarity, entry in results[:limit]:
entry.metadata["similarity"] = similarity
output.append(entry)
logger.debug(f"Vector search returned {len(output)} results")
return output
async def clear(self) -> int:
"""Clear all entries from the index."""
count = len(self._entries)
self._entries.clear()
logger.info(f"Cleared {count} entries from vector index")
return count
async def count(self) -> int:
"""Get the number of entries in the index."""
return len(self._entries)
def _cosine_similarity(self, a: list[float], b: list[float]) -> float:
"""Calculate cosine similarity between two vectors."""
if len(a) != len(b) or len(a) == 0:
return 0.0
dot_product = sum(x * y for x, y in zip(a, b, strict=True))
norm_a = sum(x * x for x in a) ** 0.5
norm_b = sum(x * x for x in b) ** 0.5
if norm_a == 0 or norm_b == 0:
return 0.0
return dot_product / (norm_a * norm_b)
def _get_memory_type(self, item: T) -> MemoryType:
"""Get the memory type for an item."""
if isinstance(item, Episode):
return MemoryType.EPISODIC
elif isinstance(item, Fact):
return MemoryType.SEMANTIC
elif isinstance(item, Procedure):
return MemoryType.PROCEDURAL
return MemoryType.WORKING
class TemporalIndex(MemoryIndex[T]):
"""
Time-based index for efficient temporal queries.
Supports:
- Range queries (between timestamps)
- Recent items (within last N seconds/hours/days)
- Oldest/newest sorting
"""
def __init__(self) -> None:
"""Initialize the temporal index."""
self._entries: dict[UUID, TemporalIndexEntry] = {}
# Sorted list for efficient range queries
self._sorted_entries: list[tuple[datetime, UUID]] = []
logger.info("Initialized TemporalIndex")
async def add(self, item: T) -> TemporalIndexEntry:
"""
Add an item to the temporal index.
Args:
item: Memory item with timestamp
Returns:
The created index entry
"""
# Get timestamp from various possible fields
timestamp = self._get_timestamp(item)
entry = TemporalIndexEntry(
memory_id=item.id,
memory_type=self._get_memory_type(item),
timestamp=timestamp,
)
self._entries[item.id] = entry
self._insert_sorted(timestamp, item.id)
logger.debug(f"Added {item.id} to temporal index at {timestamp}")
return entry
async def remove(self, memory_id: UUID) -> bool:
"""Remove an item from the temporal index."""
if memory_id not in self._entries:
return False
self._entries.pop(memory_id)
self._sorted_entries = [
(ts, mid) for ts, mid in self._sorted_entries if mid != memory_id
]
logger.debug(f"Removed {memory_id} from temporal index")
return True
async def search( # type: ignore[override]
self,
query: Any,
limit: int = 10,
start_time: datetime | None = None,
end_time: datetime | None = None,
recent_seconds: float | None = None,
order: str = "desc",
**kwargs: Any,
) -> list[TemporalIndexEntry]:
"""
Search for items by time.
Args:
query: Ignored for temporal search
limit: Maximum results to return
start_time: Start of time range
end_time: End of time range
recent_seconds: Get items from last N seconds
order: Sort order ("asc" or "desc")
**kwargs: Additional filter parameters
Returns:
List of matching entries sorted by time
"""
if recent_seconds is not None:
start_time = _utcnow() - timedelta(seconds=recent_seconds)
end_time = _utcnow()
# Filter by time range
results: list[TemporalIndexEntry] = []
for entry in self._entries.values():
if start_time and entry.timestamp < start_time:
continue
if end_time and entry.timestamp > end_time:
continue
results.append(entry)
# Apply memory type filter if provided
memory_type = kwargs.get("memory_type")
if memory_type:
results = [e for e in results if e.memory_type == memory_type]
# Sort by timestamp
results.sort(key=lambda e: e.timestamp, reverse=(order == "desc"))
logger.debug(f"Temporal search returned {min(len(results), limit)} results")
return results[:limit]
async def clear(self) -> int:
"""Clear all entries from the index."""
count = len(self._entries)
self._entries.clear()
self._sorted_entries.clear()
logger.info(f"Cleared {count} entries from temporal index")
return count
async def count(self) -> int:
"""Get the number of entries in the index."""
return len(self._entries)
def _insert_sorted(self, timestamp: datetime, memory_id: UUID) -> None:
"""Insert entry maintaining sorted order."""
# Binary search insert for efficiency
low, high = 0, len(self._sorted_entries)
while low < high:
mid = (low + high) // 2
if self._sorted_entries[mid][0] < timestamp:
low = mid + 1
else:
high = mid
self._sorted_entries.insert(low, (timestamp, memory_id))
def _get_timestamp(self, item: T) -> datetime:
"""Get the relevant timestamp for an item."""
if hasattr(item, "occurred_at"):
return item.occurred_at
if hasattr(item, "first_learned"):
return item.first_learned
if hasattr(item, "last_used") and item.last_used:
return item.last_used
if hasattr(item, "created_at"):
return item.created_at
return _utcnow()
def _get_memory_type(self, item: T) -> MemoryType:
"""Get the memory type for an item."""
if isinstance(item, Episode):
return MemoryType.EPISODIC
elif isinstance(item, Fact):
return MemoryType.SEMANTIC
elif isinstance(item, Procedure):
return MemoryType.PROCEDURAL
return MemoryType.WORKING
class EntityIndex(MemoryIndex[T]):
"""
Entity-based index for lookups by entities mentioned in memories.
Supports:
- Single entity lookup
- Multi-entity intersection
- Entity type filtering
"""
def __init__(self) -> None:
"""Initialize the entity index."""
# Main storage
self._entries: dict[UUID, EntityIndexEntry] = {}
# Inverted index: entity -> set of memory IDs
self._entity_to_memories: dict[str, set[UUID]] = {}
# Memory to entities mapping
self._memory_to_entities: dict[UUID, set[str]] = {}
logger.info("Initialized EntityIndex")
async def add(self, item: T) -> EntityIndexEntry:
"""
Add an item to the entity index.
Args:
item: Memory item with entity information
Returns:
The created index entry
"""
entities = self._extract_entities(item)
# Create entry for the primary entity (or first one)
primary_entity = entities[0] if entities else ("unknown", "unknown")
entry = EntityIndexEntry(
memory_id=item.id,
memory_type=self._get_memory_type(item),
entity_type=primary_entity[0],
entity_value=primary_entity[1],
)
self._entries[item.id] = entry
# Update inverted indices
entity_keys = {f"{etype}:{evalue}" for etype, evalue in entities}
self._memory_to_entities[item.id] = entity_keys
for entity_key in entity_keys:
if entity_key not in self._entity_to_memories:
self._entity_to_memories[entity_key] = set()
self._entity_to_memories[entity_key].add(item.id)
logger.debug(f"Added {item.id} to entity index with {len(entities)} entities")
return entry
async def remove(self, memory_id: UUID) -> bool:
"""Remove an item from the entity index."""
if memory_id not in self._entries:
return False
# Remove from inverted index
if memory_id in self._memory_to_entities:
for entity_key in self._memory_to_entities[memory_id]:
if entity_key in self._entity_to_memories:
self._entity_to_memories[entity_key].discard(memory_id)
if not self._entity_to_memories[entity_key]:
del self._entity_to_memories[entity_key]
del self._memory_to_entities[memory_id]
del self._entries[memory_id]
logger.debug(f"Removed {memory_id} from entity index")
return True
async def search( # type: ignore[override]
self,
query: Any,
limit: int = 10,
entity_type: str | None = None,
entity_value: str | None = None,
entities: list[tuple[str, str]] | None = None,
match_all: bool = False,
**kwargs: Any,
) -> list[EntityIndexEntry]:
"""
Search for items by entity.
Args:
query: Entity value to search (if entity_type not specified)
limit: Maximum results to return
entity_type: Type of entity to filter
entity_value: Specific entity value
entities: List of (type, value) tuples to match
match_all: If True, require all entities to match
**kwargs: Additional filter parameters
Returns:
List of matching entries
"""
matching_ids: set[UUID] | None = None
# Handle single entity query
if entity_type and entity_value:
entities = [(entity_type, entity_value)]
elif entity_value is None and isinstance(query, str):
# Search across all entity types
entity_value = query
if entities:
for etype, evalue in entities:
entity_key = f"{etype}:{evalue}"
if entity_key in self._entity_to_memories:
ids = self._entity_to_memories[entity_key]
if matching_ids is None:
matching_ids = ids.copy()
elif match_all:
matching_ids &= ids
else:
matching_ids |= ids
elif match_all:
# Required entity not found
matching_ids = set()
break
elif entity_value:
# Search for value across all types
matching_ids = set()
for entity_key, ids in self._entity_to_memories.items():
if entity_value.lower() in entity_key.lower():
matching_ids |= ids
if matching_ids is None:
matching_ids = set(self._entries.keys())
# Apply memory type filter if provided
memory_type = kwargs.get("memory_type")
results = []
for mid in matching_ids:
if mid in self._entries:
entry = self._entries[mid]
if memory_type and entry.memory_type != memory_type:
continue
results.append(entry)
logger.debug(f"Entity search returned {min(len(results), limit)} results")
return results[:limit]
async def clear(self) -> int:
"""Clear all entries from the index."""
count = len(self._entries)
self._entries.clear()
self._entity_to_memories.clear()
self._memory_to_entities.clear()
logger.info(f"Cleared {count} entries from entity index")
return count
async def count(self) -> int:
"""Get the number of entries in the index."""
return len(self._entries)
async def get_entities(self, memory_id: UUID) -> list[tuple[str, str]]:
"""Get all entities for a memory item."""
if memory_id not in self._memory_to_entities:
return []
entities = []
for entity_key in self._memory_to_entities[memory_id]:
if ":" in entity_key:
etype, evalue = entity_key.split(":", 1)
entities.append((etype, evalue))
return entities
def _extract_entities(self, item: T) -> list[tuple[str, str]]:
"""Extract entities from a memory item."""
entities: list[tuple[str, str]] = []
if isinstance(item, Episode):
# Extract from task type and context
entities.append(("task_type", item.task_type))
if item.project_id:
entities.append(("project", str(item.project_id)))
if item.agent_instance_id:
entities.append(("agent_instance", str(item.agent_instance_id)))
if item.agent_type_id:
entities.append(("agent_type", str(item.agent_type_id)))
elif isinstance(item, Fact):
# Subject and object are entities
entities.append(("subject", item.subject))
entities.append(("object", item.object))
if item.project_id:
entities.append(("project", str(item.project_id)))
elif isinstance(item, Procedure):
entities.append(("procedure", item.name))
if item.project_id:
entities.append(("project", str(item.project_id)))
if item.agent_type_id:
entities.append(("agent_type", str(item.agent_type_id)))
return entities
def _get_memory_type(self, item: T) -> MemoryType:
"""Get the memory type for an item."""
if isinstance(item, Episode):
return MemoryType.EPISODIC
elif isinstance(item, Fact):
return MemoryType.SEMANTIC
elif isinstance(item, Procedure):
return MemoryType.PROCEDURAL
return MemoryType.WORKING
class OutcomeIndex(MemoryIndex[T]):
"""
Outcome-based index for filtering by success/failure.
Primarily used for episodes and procedures.
"""
def __init__(self) -> None:
"""Initialize the outcome index."""
self._entries: dict[UUID, OutcomeIndexEntry] = {}
# Inverted index by outcome
self._outcome_to_memories: dict[Outcome, set[UUID]] = {
Outcome.SUCCESS: set(),
Outcome.FAILURE: set(),
Outcome.PARTIAL: set(),
}
logger.info("Initialized OutcomeIndex")
async def add(self, item: T) -> OutcomeIndexEntry:
"""
Add an item to the outcome index.
Args:
item: Memory item with outcome information
Returns:
The created index entry
"""
outcome = self._get_outcome(item)
entry = OutcomeIndexEntry(
memory_id=item.id,
memory_type=self._get_memory_type(item),
outcome=outcome,
)
self._entries[item.id] = entry
self._outcome_to_memories[outcome].add(item.id)
logger.debug(f"Added {item.id} to outcome index with {outcome.value}")
return entry
async def remove(self, memory_id: UUID) -> bool:
"""Remove an item from the outcome index."""
if memory_id not in self._entries:
return False
entry = self._entries.pop(memory_id)
self._outcome_to_memories[entry.outcome].discard(memory_id)
logger.debug(f"Removed {memory_id} from outcome index")
return True
async def search( # type: ignore[override]
self,
query: Any,
limit: int = 10,
outcome: Outcome | None = None,
outcomes: list[Outcome] | None = None,
**kwargs: Any,
) -> list[OutcomeIndexEntry]:
"""
Search for items by outcome.
Args:
query: Ignored for outcome search
limit: Maximum results to return
outcome: Single outcome to filter
outcomes: Multiple outcomes to filter (OR)
**kwargs: Additional filter parameters
Returns:
List of matching entries
"""
if outcome:
outcomes = [outcome]
if outcomes:
matching_ids: set[UUID] = set()
for o in outcomes:
matching_ids |= self._outcome_to_memories.get(o, set())
else:
matching_ids = set(self._entries.keys())
# Apply memory type filter if provided
memory_type = kwargs.get("memory_type")
results = []
for mid in matching_ids:
if mid in self._entries:
entry = self._entries[mid]
if memory_type and entry.memory_type != memory_type:
continue
results.append(entry)
logger.debug(f"Outcome search returned {min(len(results), limit)} results")
return results[:limit]
async def clear(self) -> int:
"""Clear all entries from the index."""
count = len(self._entries)
self._entries.clear()
for outcome in self._outcome_to_memories:
self._outcome_to_memories[outcome].clear()
logger.info(f"Cleared {count} entries from outcome index")
return count
async def count(self) -> int:
"""Get the number of entries in the index."""
return len(self._entries)
async def get_outcome_stats(self) -> dict[Outcome, int]:
"""Get statistics on outcomes."""
return {outcome: len(ids) for outcome, ids in self._outcome_to_memories.items()}
def _get_outcome(self, item: T) -> Outcome:
"""Get the outcome for an item."""
if isinstance(item, Episode):
return item.outcome
elif isinstance(item, Procedure):
# Derive from success rate
if item.success_rate >= 0.8:
return Outcome.SUCCESS
elif item.success_rate <= 0.2:
return Outcome.FAILURE
return Outcome.PARTIAL
return Outcome.SUCCESS
def _get_memory_type(self, item: T) -> MemoryType:
"""Get the memory type for an item."""
if isinstance(item, Episode):
return MemoryType.EPISODIC
elif isinstance(item, Fact):
return MemoryType.SEMANTIC
elif isinstance(item, Procedure):
return MemoryType.PROCEDURAL
return MemoryType.WORKING
@dataclass
class MemoryIndexer:
"""
Unified indexer that manages all index types.
Provides a single interface for indexing and searching across
multiple index types.
"""
vector_index: VectorIndex[Any] = field(default_factory=VectorIndex)
temporal_index: TemporalIndex[Any] = field(default_factory=TemporalIndex)
entity_index: EntityIndex[Any] = field(default_factory=EntityIndex)
outcome_index: OutcomeIndex[Any] = field(default_factory=OutcomeIndex)
async def index(self, item: Episode | Fact | Procedure) -> dict[str, IndexEntry]:
"""
Index an item across all applicable indices.
Args:
item: Memory item to index
Returns:
Dictionary of index type to entry
"""
results: dict[str, IndexEntry] = {}
# Vector index (if embedding present)
if getattr(item, "embedding", None):
results["vector"] = await self.vector_index.add(item)
# Temporal index
results["temporal"] = await self.temporal_index.add(item)
# Entity index
results["entity"] = await self.entity_index.add(item)
# Outcome index (for episodes and procedures)
if isinstance(item, (Episode, Procedure)):
results["outcome"] = await self.outcome_index.add(item)
logger.info(
f"Indexed {item.id} across {len(results)} indices: {list(results.keys())}"
)
return results
async def remove(self, memory_id: UUID) -> dict[str, bool]:
"""
Remove an item from all indices.
Args:
memory_id: ID of the memory to remove
Returns:
Dictionary of index type to removal success
"""
results = {
"vector": await self.vector_index.remove(memory_id),
"temporal": await self.temporal_index.remove(memory_id),
"entity": await self.entity_index.remove(memory_id),
"outcome": await self.outcome_index.remove(memory_id),
}
removed_from = [k for k, v in results.items() if v]
if removed_from:
logger.info(f"Removed {memory_id} from indices: {removed_from}")
return results
async def clear_all(self) -> dict[str, int]:
"""
Clear all indices.
Returns:
Dictionary of index type to count cleared
"""
return {
"vector": await self.vector_index.clear(),
"temporal": await self.temporal_index.clear(),
"entity": await self.entity_index.clear(),
"outcome": await self.outcome_index.clear(),
}
async def get_stats(self) -> dict[str, int]:
"""
Get statistics for all indices.
Returns:
Dictionary of index type to entry count
"""
return {
"vector": await self.vector_index.count(),
"temporal": await self.temporal_index.count(),
"entity": await self.entity_index.count(),
"outcome": await self.outcome_index.count(),
}
# Singleton indexer instance
_indexer: MemoryIndexer | None = None
def get_memory_indexer() -> MemoryIndexer:
"""Get the singleton memory indexer instance."""
global _indexer
if _indexer is None:
_indexer = MemoryIndexer()
return _indexer


@@ -0,0 +1,750 @@
# app/services/memory/indexing/retrieval.py
"""
Memory Retrieval Engine.
Provides hybrid retrieval capabilities combining:
- Vector similarity search
- Temporal filtering
- Entity filtering
- Outcome filtering
- Relevance scoring
- Result caching
"""
import hashlib
import logging
from dataclasses import dataclass, field
from datetime import UTC, datetime
from typing import Any, TypeVar
from uuid import UUID
from app.services.memory.types import (
Episode,
Fact,
MemoryType,
Outcome,
Procedure,
RetrievalResult,
)
from .index import (
MemoryIndexer,
get_memory_indexer,
)
logger = logging.getLogger(__name__)
T = TypeVar("T", Episode, Fact, Procedure)
def _utcnow() -> datetime:
"""Get current UTC time as timezone-aware datetime."""
return datetime.now(UTC)
@dataclass
class RetrievalQuery:
"""Query parameters for memory retrieval."""
# Text/semantic query
query_text: str | None = None
query_embedding: list[float] | None = None
# Temporal filters
start_time: datetime | None = None
end_time: datetime | None = None
recent_seconds: float | None = None
# Entity filters
entities: list[tuple[str, str]] | None = None
entity_match_all: bool = False
# Outcome filters
outcomes: list[Outcome] | None = None
# Memory type filter
memory_types: list[MemoryType] | None = None
# Result options
limit: int = 10
min_relevance: float = 0.0
# Retrieval mode
use_vector: bool = True
use_temporal: bool = True
use_entity: bool = True
use_outcome: bool = True
def to_cache_key(self) -> str:
"""Generate a cache key for this query."""
key_parts = [
self.query_text or "",
# Include the embedding so vector-only queries (e.g. retrieve_similar)
# with different embeddings do not collide on the same cache key.
str(self.query_embedding),
str(self.start_time),
str(self.end_time),
str(self.recent_seconds),
str(self.entities),
str(self.outcomes),
str(self.memory_types),
str(self.limit),
str(self.min_relevance),
]
key_string = "|".join(key_parts)
return hashlib.sha256(key_string.encode()).hexdigest()[:32]
@dataclass
class ScoredResult:
"""A retrieval result with relevance score."""
memory_id: UUID
memory_type: MemoryType
relevance_score: float
score_breakdown: dict[str, float] = field(default_factory=dict)
metadata: dict[str, Any] = field(default_factory=dict)
@dataclass
class CacheEntry:
"""A cached retrieval result."""
results: list[ScoredResult]
created_at: datetime
ttl_seconds: float
query_key: str
def is_expired(self) -> bool:
"""Check if this cache entry has expired."""
age = (_utcnow() - self.created_at).total_seconds()
return age > self.ttl_seconds
class RelevanceScorer:
"""
Calculates relevance scores for retrieved memories.
Combines multiple signals:
- Vector similarity (if available)
- Temporal recency
- Entity match count
- Outcome preference
- Importance/confidence
"""
def __init__(
self,
vector_weight: float = 0.4,
recency_weight: float = 0.2,
entity_weight: float = 0.2,
outcome_weight: float = 0.1,
importance_weight: float = 0.1,
) -> None:
"""
Initialize the relevance scorer.
Args:
vector_weight: Weight for vector similarity (0-1)
recency_weight: Weight for temporal recency (0-1)
entity_weight: Weight for entity matches (0-1)
outcome_weight: Weight for outcome preference (0-1)
importance_weight: Weight for importance score (0-1)
"""
total = (
vector_weight
+ recency_weight
+ entity_weight
+ outcome_weight
+ importance_weight
)
# Normalize weights
self.vector_weight = vector_weight / total
self.recency_weight = recency_weight / total
self.entity_weight = entity_weight / total
self.outcome_weight = outcome_weight / total
self.importance_weight = importance_weight / total
def score(
self,
memory_id: UUID,
memory_type: MemoryType,
vector_similarity: float | None = None,
timestamp: datetime | None = None,
entity_match_count: int = 0,
entity_total: int = 1,
outcome: Outcome | None = None,
importance: float = 0.5,
preferred_outcomes: list[Outcome] | None = None,
) -> ScoredResult:
"""
Calculate a relevance score for a memory.
Args:
memory_id: ID of the memory
memory_type: Type of memory
vector_similarity: Similarity score from vector search (0-1)
timestamp: Timestamp of the memory
entity_match_count: Number of matching entities
entity_total: Total entities in query
outcome: Outcome of the memory
importance: Importance score of the memory (0-1)
preferred_outcomes: Outcomes to prefer
Returns:
Scored result with breakdown
"""
breakdown: dict[str, float] = {}
# Vector similarity score
if vector_similarity is not None:
breakdown["vector"] = vector_similarity
else:
breakdown["vector"] = 0.5 # Neutral if no vector
# Recency score (exponential decay)
if timestamp:
age_hours = (_utcnow() - timestamp).total_seconds() / 3600
# Decay with half-life of 24 hours
breakdown["recency"] = 2 ** (-age_hours / 24)
else:
breakdown["recency"] = 0.5
# Entity match score
if entity_total > 0:
breakdown["entity"] = entity_match_count / entity_total
else:
breakdown["entity"] = 1.0 # No entity filter = full score
# Outcome score
if preferred_outcomes and outcome:
breakdown["outcome"] = 1.0 if outcome in preferred_outcomes else 0.0
else:
breakdown["outcome"] = 0.5 # Neutral if no preference
# Importance score
breakdown["importance"] = importance
# Calculate weighted sum
total_score = (
breakdown["vector"] * self.vector_weight
+ breakdown["recency"] * self.recency_weight
+ breakdown["entity"] * self.entity_weight
+ breakdown["outcome"] * self.outcome_weight
+ breakdown["importance"] * self.importance_weight
)
return ScoredResult(
memory_id=memory_id,
memory_type=memory_type,
relevance_score=total_score,
score_breakdown=breakdown,
)
class RetrievalCache:
"""
In-memory cache for retrieval results.
Supports TTL-based expiration and LRU eviction.
"""
def __init__(
self,
max_entries: int = 1000,
default_ttl_seconds: float = 300,
) -> None:
"""
Initialize the cache.
Args:
max_entries: Maximum cache entries
default_ttl_seconds: Default TTL for entries
"""
self._cache: dict[str, CacheEntry] = {}
self._max_entries = max_entries
self._default_ttl = default_ttl_seconds
self._access_order: list[str] = []
logger.info(
f"Initialized RetrievalCache with max_entries={max_entries}, "
f"ttl={default_ttl_seconds}s"
)
def get(self, query_key: str) -> list[ScoredResult] | None:
"""
Get cached results for a query.
Args:
query_key: Cache key for the query
Returns:
Cached results or None if not found/expired
"""
if query_key not in self._cache:
return None
entry = self._cache[query_key]
if entry.is_expired():
del self._cache[query_key]
if query_key in self._access_order:
self._access_order.remove(query_key)
return None
# Update access order (LRU)
if query_key in self._access_order:
self._access_order.remove(query_key)
self._access_order.append(query_key)
logger.debug(f"Cache hit for {query_key}")
return entry.results
def put(
self,
query_key: str,
results: list[ScoredResult],
ttl_seconds: float | None = None,
) -> None:
"""
Cache results for a query.
Args:
query_key: Cache key for the query
results: Results to cache
ttl_seconds: TTL for this entry (or default)
"""
# Evict if at capacity
while len(self._cache) >= self._max_entries and self._access_order:
oldest_key = self._access_order.pop(0)
if oldest_key in self._cache:
del self._cache[oldest_key]
entry = CacheEntry(
results=results,
created_at=_utcnow(),
ttl_seconds=ttl_seconds or self._default_ttl,
query_key=query_key,
)
self._cache[query_key] = entry
self._access_order.append(query_key)
logger.debug(f"Cached {len(results)} results for {query_key}")
def invalidate(self, query_key: str) -> bool:
"""
Invalidate a specific cache entry.
Args:
query_key: Cache key to invalidate
Returns:
True if entry was found and removed
"""
if query_key in self._cache:
del self._cache[query_key]
if query_key in self._access_order:
self._access_order.remove(query_key)
return True
return False
def invalidate_by_memory(self, memory_id: UUID) -> int:
"""
Invalidate all cache entries containing a specific memory.
Args:
memory_id: Memory ID to invalidate
Returns:
Number of entries invalidated
"""
keys_to_remove = []
for key, entry in self._cache.items():
if any(r.memory_id == memory_id for r in entry.results):
keys_to_remove.append(key)
for key in keys_to_remove:
self.invalidate(key)
if keys_to_remove:
logger.debug(
f"Invalidated {len(keys_to_remove)} cache entries for {memory_id}"
)
return len(keys_to_remove)
def clear(self) -> int:
"""
Clear all cache entries.
Returns:
Number of entries cleared
"""
count = len(self._cache)
self._cache.clear()
self._access_order.clear()
logger.info(f"Cleared {count} cache entries")
return count
def get_stats(self) -> dict[str, Any]:
"""Get cache statistics."""
expired_count = sum(1 for e in self._cache.values() if e.is_expired())
return {
"total_entries": len(self._cache),
"expired_entries": expired_count,
"max_entries": self._max_entries,
"default_ttl_seconds": self._default_ttl,
}
class RetrievalEngine:
"""
Hybrid retrieval engine for memory search.
Combines multiple index types for comprehensive retrieval:
- Vector search for semantic similarity
- Temporal index for time-based filtering
- Entity index for entity-based lookups
- Outcome index for success/failure filtering
Results are scored and ranked using relevance scoring.
"""
def __init__(
self,
indexer: MemoryIndexer | None = None,
scorer: RelevanceScorer | None = None,
cache: RetrievalCache | None = None,
enable_cache: bool = True,
) -> None:
"""
Initialize the retrieval engine.
Args:
indexer: Memory indexer (defaults to singleton)
scorer: Relevance scorer (defaults to new instance)
cache: Retrieval cache (defaults to new instance)
enable_cache: Whether to enable result caching
"""
self._indexer = indexer or get_memory_indexer()
self._scorer = scorer or RelevanceScorer()
self._cache = (cache or RetrievalCache()) if enable_cache else None
self._enable_cache = enable_cache
logger.info(f"Initialized RetrievalEngine with cache={enable_cache}")
async def retrieve(
self,
query: RetrievalQuery,
use_cache: bool = True,
) -> RetrievalResult[ScoredResult]:
"""
Retrieve relevant memories using hybrid search.
Args:
query: Retrieval query parameters
use_cache: Whether to use cached results
Returns:
Retrieval result with scored items
"""
start_time = _utcnow()
# Check cache
cache_key = query.to_cache_key()
if use_cache and self._cache:
cached = self._cache.get(cache_key)
if cached:
latency = (_utcnow() - start_time).total_seconds() * 1000
return RetrievalResult(
items=cached,
total_count=len(cached),
query=query.query_text or "",
retrieval_type="cached",
latency_ms=latency,
metadata={"cache_hit": True},
)
# Collect candidates from each index
candidates: dict[UUID, dict[str, Any]] = {}
# Vector search
if query.use_vector and query.query_embedding:
vector_results = await self._indexer.vector_index.search(
query=query.query_embedding,
limit=query.limit * 3, # Get more for filtering
min_similarity=query.min_relevance,
memory_type=query.memory_types[0] if query.memory_types else None,
)
for entry in vector_results:
if entry.memory_id not in candidates:
candidates[entry.memory_id] = {
"memory_type": entry.memory_type,
"sources": [],
}
candidates[entry.memory_id]["vector_similarity"] = entry.metadata.get(
"similarity", 0.5
)
candidates[entry.memory_id]["sources"].append("vector")
# Temporal search
if query.use_temporal and (
query.start_time or query.end_time or query.recent_seconds
):
temporal_results = await self._indexer.temporal_index.search(
query=None,
limit=query.limit * 3,
start_time=query.start_time,
end_time=query.end_time,
recent_seconds=query.recent_seconds,
memory_type=query.memory_types[0] if query.memory_types else None,
)
for temporal_entry in temporal_results:
if temporal_entry.memory_id not in candidates:
candidates[temporal_entry.memory_id] = {
"memory_type": temporal_entry.memory_type,
"sources": [],
}
candidates[temporal_entry.memory_id]["timestamp"] = (
temporal_entry.timestamp
)
candidates[temporal_entry.memory_id]["sources"].append("temporal")
# Entity search
if query.use_entity and query.entities:
entity_results = await self._indexer.entity_index.search(
query=None,
limit=query.limit * 3,
entities=query.entities,
match_all=query.entity_match_all,
memory_type=query.memory_types[0] if query.memory_types else None,
)
for entity_entry in entity_results:
if entity_entry.memory_id not in candidates:
candidates[entity_entry.memory_id] = {
"memory_type": entity_entry.memory_type,
"sources": [],
}
# Count entity matches
entity_count = candidates[entity_entry.memory_id].get(
"entity_match_count", 0
)
candidates[entity_entry.memory_id]["entity_match_count"] = (
entity_count + 1
)
candidates[entity_entry.memory_id]["sources"].append("entity")
# Outcome search
if query.use_outcome and query.outcomes:
outcome_results = await self._indexer.outcome_index.search(
query=None,
limit=query.limit * 3,
outcomes=query.outcomes,
memory_type=query.memory_types[0] if query.memory_types else None,
)
for outcome_entry in outcome_results:
if outcome_entry.memory_id not in candidates:
candidates[outcome_entry.memory_id] = {
"memory_type": outcome_entry.memory_type,
"sources": [],
}
candidates[outcome_entry.memory_id]["outcome"] = outcome_entry.outcome
candidates[outcome_entry.memory_id]["sources"].append("outcome")
# Score and rank candidates
scored_results: list[ScoredResult] = []
entity_total = len(query.entities) if query.entities else 1
for memory_id, data in candidates.items():
scored = self._scorer.score(
memory_id=memory_id,
memory_type=data["memory_type"],
vector_similarity=data.get("vector_similarity"),
timestamp=data.get("timestamp"),
entity_match_count=data.get("entity_match_count", 0),
entity_total=entity_total,
outcome=data.get("outcome"),
preferred_outcomes=query.outcomes,
)
scored.metadata["sources"] = data.get("sources", [])
# Filter by minimum relevance
if scored.relevance_score >= query.min_relevance:
scored_results.append(scored)
# Sort by relevance score
scored_results.sort(key=lambda x: x.relevance_score, reverse=True)
# Apply limit
final_results = scored_results[: query.limit]
# Cache results
if use_cache and self._cache and final_results:
self._cache.put(cache_key, final_results)
latency = (_utcnow() - start_time).total_seconds() * 1000
logger.info(
f"Retrieved {len(final_results)} results from {len(candidates)} candidates "
f"in {latency:.2f}ms"
)
return RetrievalResult(
items=final_results,
total_count=len(candidates),
query=query.query_text or "",
retrieval_type="hybrid",
latency_ms=latency,
metadata={
"cache_hit": False,
"candidates_count": len(candidates),
"filtered_count": len(scored_results),
},
)
async def retrieve_similar(
self,
embedding: list[float],
limit: int = 10,
min_similarity: float = 0.5,
memory_types: list[MemoryType] | None = None,
) -> RetrievalResult[ScoredResult]:
"""
Retrieve memories similar to a given embedding.
Args:
embedding: Query embedding
limit: Maximum results
min_similarity: Minimum similarity threshold
memory_types: Filter by memory types
Returns:
Retrieval result with scored items
"""
query = RetrievalQuery(
query_embedding=embedding,
limit=limit,
min_relevance=min_similarity,
memory_types=memory_types,
use_temporal=False,
use_entity=False,
use_outcome=False,
)
return await self.retrieve(query)
async def retrieve_recent(
self,
hours: float = 24,
limit: int = 10,
memory_types: list[MemoryType] | None = None,
) -> RetrievalResult[ScoredResult]:
"""
Retrieve recent memories.
Args:
hours: Number of hours to look back
limit: Maximum results
memory_types: Filter by memory types
Returns:
Retrieval result with scored items
"""
query = RetrievalQuery(
recent_seconds=hours * 3600,
limit=limit,
memory_types=memory_types,
use_vector=False,
use_entity=False,
use_outcome=False,
)
return await self.retrieve(query)
async def retrieve_by_entity(
self,
entity_type: str,
entity_value: str,
limit: int = 10,
memory_types: list[MemoryType] | None = None,
) -> RetrievalResult[ScoredResult]:
"""
Retrieve memories by entity.
Args:
entity_type: Type of entity
entity_value: Entity value
limit: Maximum results
memory_types: Filter by memory types
Returns:
Retrieval result with scored items
"""
query = RetrievalQuery(
entities=[(entity_type, entity_value)],
limit=limit,
memory_types=memory_types,
use_vector=False,
use_temporal=False,
use_outcome=False,
)
return await self.retrieve(query)
async def retrieve_successful(
self,
limit: int = 10,
memory_types: list[MemoryType] | None = None,
) -> RetrievalResult[ScoredResult]:
"""
Retrieve successful memories.
Args:
limit: Maximum results
memory_types: Filter by memory types
Returns:
Retrieval result with scored items
"""
query = RetrievalQuery(
outcomes=[Outcome.SUCCESS],
limit=limit,
memory_types=memory_types,
use_vector=False,
use_temporal=False,
use_entity=False,
)
return await self.retrieve(query)
def invalidate_cache(self) -> int:
"""
Invalidate all cached results.
Returns:
Number of entries invalidated
"""
if self._cache:
return self._cache.clear()
return 0
def invalidate_cache_for_memory(self, memory_id: UUID) -> int:
"""
Invalidate cache entries containing a specific memory.
Args:
memory_id: Memory ID to invalidate
Returns:
Number of entries invalidated
"""
if self._cache:
return self._cache.invalidate_by_memory(memory_id)
return 0
def get_cache_stats(self) -> dict[str, Any]:
"""Get cache statistics."""
if self._cache:
return self._cache.get_stats()
return {"enabled": False}
# Singleton retrieval engine instance
_engine: RetrievalEngine | None = None
def get_retrieval_engine() -> RetrievalEngine:
"""Get the singleton retrieval engine instance."""
global _engine
if _engine is None:
_engine = RetrievalEngine()
return _engine


@@ -0,0 +1,2 @@
# tests/unit/services/memory/indexing/__init__.py
"""Unit tests for memory indexing."""


@@ -0,0 +1,497 @@
# tests/unit/services/memory/indexing/test_index.py
"""Unit tests for memory indexing."""
from datetime import UTC, datetime, timedelta
from uuid import uuid4
import pytest
from app.services.memory.indexing.index import (
EntityIndex,
MemoryIndexer,
OutcomeIndex,
TemporalIndex,
VectorIndex,
get_memory_indexer,
)
from app.services.memory.types import Episode, Fact, MemoryType, Outcome, Procedure
def _utcnow() -> datetime:
"""Get current UTC time."""
return datetime.now(UTC)
def make_episode(
embedding: list[float] | None = None,
outcome: Outcome = Outcome.SUCCESS,
occurred_at: datetime | None = None,
) -> Episode:
"""Create a test episode."""
return Episode(
id=uuid4(),
project_id=uuid4(),
agent_instance_id=uuid4(),
agent_type_id=uuid4(),
session_id="test-session",
task_type="test_task",
task_description="Test task description",
actions=[{"action": "test"}],
context_summary="Test context",
outcome=outcome,
outcome_details="Test outcome",
duration_seconds=10.0,
tokens_used=100,
lessons_learned=["lesson1"],
importance_score=0.8,
embedding=embedding,
occurred_at=occurred_at or _utcnow(),
created_at=_utcnow(),
updated_at=_utcnow(),
)
def make_fact(
embedding: list[float] | None = None,
subject: str = "test_subject",
predicate: str = "has_property",
obj: str = "test_value",
) -> Fact:
"""Create a test fact."""
return Fact(
id=uuid4(),
project_id=uuid4(),
subject=subject,
predicate=predicate,
object=obj,
confidence=0.9,
source_episode_ids=[uuid4()],
first_learned=_utcnow(),
last_reinforced=_utcnow(),
reinforcement_count=1,
embedding=embedding,
created_at=_utcnow(),
updated_at=_utcnow(),
)
def make_procedure(
embedding: list[float] | None = None,
success_count: int = 8,
failure_count: int = 2,
) -> Procedure:
"""Create a test procedure."""
return Procedure(
id=uuid4(),
project_id=uuid4(),
agent_type_id=uuid4(),
name="test_procedure",
trigger_pattern="test.*",
steps=[{"step": 1, "action": "test"}],
success_count=success_count,
failure_count=failure_count,
last_used=_utcnow(),
embedding=embedding,
created_at=_utcnow(),
updated_at=_utcnow(),
)
class TestVectorIndex:
"""Tests for VectorIndex."""
@pytest.fixture
def index(self) -> VectorIndex[Episode]:
"""Create a vector index."""
return VectorIndex[Episode](dimension=4)
@pytest.mark.asyncio
async def test_add_item(self, index: VectorIndex[Episode]) -> None:
"""Test adding an item to the index."""
episode = make_episode(embedding=[1.0, 0.0, 0.0, 0.0])
entry = await index.add(episode)
assert entry.memory_id == episode.id
assert entry.memory_type == MemoryType.EPISODIC
assert entry.dimension == 4
assert await index.count() == 1
@pytest.mark.asyncio
async def test_remove_item(self, index: VectorIndex[Episode]) -> None:
"""Test removing an item from the index."""
episode = make_episode(embedding=[1.0, 0.0, 0.0, 0.0])
await index.add(episode)
result = await index.remove(episode.id)
assert result is True
assert await index.count() == 0
@pytest.mark.asyncio
async def test_remove_nonexistent(self, index: VectorIndex[Episode]) -> None:
"""Test removing a nonexistent item."""
result = await index.remove(uuid4())
assert result is False
@pytest.mark.asyncio
async def test_search_similar(self, index: VectorIndex[Episode]) -> None:
"""Test searching for similar items."""
# Add items with different embeddings
e1 = make_episode(embedding=[1.0, 0.0, 0.0, 0.0])
e2 = make_episode(embedding=[0.9, 0.1, 0.0, 0.0])
e3 = make_episode(embedding=[0.0, 1.0, 0.0, 0.0])
await index.add(e1)
await index.add(e2)
await index.add(e3)
# Search for similar to first
results = await index.search([1.0, 0.0, 0.0, 0.0], limit=2)
assert len(results) == 2
# First result should be most similar
assert results[0].memory_id == e1.id
@pytest.mark.asyncio
async def test_search_min_similarity(self, index: VectorIndex[Episode]) -> None:
"""Test minimum similarity threshold."""
e1 = make_episode(embedding=[1.0, 0.0, 0.0, 0.0])
e2 = make_episode(embedding=[0.0, 1.0, 0.0, 0.0]) # Orthogonal
await index.add(e1)
await index.add(e2)
# Search with high threshold
results = await index.search([1.0, 0.0, 0.0, 0.0], min_similarity=0.9)
assert len(results) == 1
assert results[0].memory_id == e1.id
@pytest.mark.asyncio
async def test_search_empty_query(self, index: VectorIndex[Episode]) -> None:
"""Test search with empty query."""
e1 = make_episode(embedding=[1.0, 0.0, 0.0, 0.0])
await index.add(e1)
results = await index.search([], limit=10)
assert len(results) == 0
@pytest.mark.asyncio
async def test_clear(self, index: VectorIndex[Episode]) -> None:
"""Test clearing the index."""
await index.add(make_episode(embedding=[1.0, 0.0, 0.0, 0.0]))
await index.add(make_episode(embedding=[0.0, 1.0, 0.0, 0.0]))
count = await index.clear()
assert count == 2
assert await index.count() == 0
class TestTemporalIndex:
"""Tests for TemporalIndex."""
@pytest.fixture
def index(self) -> TemporalIndex[Episode]:
"""Create a temporal index."""
return TemporalIndex[Episode]()
@pytest.mark.asyncio
async def test_add_item(self, index: TemporalIndex[Episode]) -> None:
"""Test adding an item."""
episode = make_episode()
entry = await index.add(episode)
assert entry.memory_id == episode.id
assert await index.count() == 1
@pytest.mark.asyncio
async def test_search_by_time_range(self, index: TemporalIndex[Episode]) -> None:
"""Test searching by time range."""
now = _utcnow()
old = make_episode(occurred_at=now - timedelta(hours=2))
recent = make_episode(occurred_at=now - timedelta(hours=1))
newest = make_episode(occurred_at=now)
await index.add(old)
await index.add(recent)
await index.add(newest)
# Search last hour
results = await index.search(
query=None,
start_time=now - timedelta(hours=1, minutes=30),
end_time=now,
)
assert len(results) == 2
@pytest.mark.asyncio
async def test_search_recent(self, index: TemporalIndex[Episode]) -> None:
"""Test searching for recent items."""
now = _utcnow()
old = make_episode(occurred_at=now - timedelta(hours=2))
recent = make_episode(occurred_at=now - timedelta(minutes=30))
await index.add(old)
await index.add(recent)
# Search last hour (3600 seconds)
results = await index.search(query=None, recent_seconds=3600)
assert len(results) == 1
assert results[0].memory_id == recent.id
@pytest.mark.asyncio
async def test_search_order(self, index: TemporalIndex[Episode]) -> None:
"""Test result ordering."""
now = _utcnow()
e1 = make_episode(occurred_at=now - timedelta(hours=2))
e2 = make_episode(occurred_at=now - timedelta(hours=1))
e3 = make_episode(occurred_at=now)
await index.add(e1)
await index.add(e2)
await index.add(e3)
# Descending order (newest first)
results_desc = await index.search(query=None, order="desc", limit=10)
assert results_desc[0].memory_id == e3.id
# Ascending order (oldest first)
results_asc = await index.search(query=None, order="asc", limit=10)
assert results_asc[0].memory_id == e1.id
class TestEntityIndex:
"""Tests for EntityIndex."""
@pytest.fixture
def index(self) -> EntityIndex[Fact]:
"""Create an entity index."""
return EntityIndex[Fact]()
@pytest.mark.asyncio
async def test_add_item(self, index: EntityIndex[Fact]) -> None:
"""Test adding an item."""
fact = make_fact(subject="user", obj="admin")
entry = await index.add(fact)
assert entry.memory_id == fact.id
assert await index.count() == 1
@pytest.mark.asyncio
async def test_search_by_entity(self, index: EntityIndex[Fact]) -> None:
"""Test searching by entity."""
f1 = make_fact(subject="user", obj="admin")
f2 = make_fact(subject="system", obj="config")
await index.add(f1)
await index.add(f2)
results = await index.search(
query=None,
entity_type="subject",
entity_value="user",
)
assert len(results) == 1
assert results[0].memory_id == f1.id
@pytest.mark.asyncio
async def test_search_multiple_entities(self, index: EntityIndex[Fact]) -> None:
"""Test searching with multiple entities."""
f1 = make_fact(subject="user", obj="admin")
f2 = make_fact(subject="user", obj="guest")
await index.add(f1)
await index.add(f2)
# Search for facts about "user" subject
results = await index.search(
query=None,
entities=[("subject", "user")],
)
assert len(results) == 2
@pytest.mark.asyncio
async def test_search_match_all(self, index: EntityIndex[Fact]) -> None:
"""Test matching all entities."""
f1 = make_fact(subject="user", obj="admin")
f2 = make_fact(subject="user", obj="guest")
await index.add(f1)
await index.add(f2)
# Search for user+admin (match all)
results = await index.search(
query=None,
entities=[("subject", "user"), ("object", "admin")],
match_all=True,
)
assert len(results) == 1
assert results[0].memory_id == f1.id
@pytest.mark.asyncio
async def test_get_entities(self, index: EntityIndex[Fact]) -> None:
"""Test getting entities for a memory."""
fact = make_fact(subject="user", obj="admin")
await index.add(fact)
entities = await index.get_entities(fact.id)
assert ("subject", "user") in entities
assert ("object", "admin") in entities
class TestOutcomeIndex:
"""Tests for OutcomeIndex."""
@pytest.fixture
def index(self) -> OutcomeIndex[Episode]:
"""Create an outcome index."""
return OutcomeIndex[Episode]()
@pytest.mark.asyncio
async def test_add_item(self, index: OutcomeIndex[Episode]) -> None:
"""Test adding an item."""
episode = make_episode(outcome=Outcome.SUCCESS)
entry = await index.add(episode)
assert entry.memory_id == episode.id
assert entry.outcome == Outcome.SUCCESS
assert await index.count() == 1
@pytest.mark.asyncio
async def test_search_by_outcome(self, index: OutcomeIndex[Episode]) -> None:
"""Test searching by outcome."""
success = make_episode(outcome=Outcome.SUCCESS)
failure = make_episode(outcome=Outcome.FAILURE)
await index.add(success)
await index.add(failure)
results = await index.search(query=None, outcome=Outcome.SUCCESS)
assert len(results) == 1
assert results[0].memory_id == success.id
@pytest.mark.asyncio
async def test_search_multiple_outcomes(self, index: OutcomeIndex[Episode]) -> None:
"""Test searching with multiple outcomes."""
success = make_episode(outcome=Outcome.SUCCESS)
partial = make_episode(outcome=Outcome.PARTIAL)
failure = make_episode(outcome=Outcome.FAILURE)
await index.add(success)
await index.add(partial)
await index.add(failure)
results = await index.search(
query=None,
outcomes=[Outcome.SUCCESS, Outcome.PARTIAL],
)
assert len(results) == 2
@pytest.mark.asyncio
async def test_get_outcome_stats(self, index: OutcomeIndex[Episode]) -> None:
"""Test getting outcome statistics."""
await index.add(make_episode(outcome=Outcome.SUCCESS))
await index.add(make_episode(outcome=Outcome.SUCCESS))
await index.add(make_episode(outcome=Outcome.FAILURE))
stats = await index.get_outcome_stats()
assert stats[Outcome.SUCCESS] == 2
assert stats[Outcome.FAILURE] == 1
assert stats[Outcome.PARTIAL] == 0
class TestMemoryIndexer:
"""Tests for MemoryIndexer."""
@pytest.fixture
def indexer(self) -> MemoryIndexer:
"""Create a memory indexer."""
return MemoryIndexer()
@pytest.mark.asyncio
async def test_index_episode(self, indexer: MemoryIndexer) -> None:
"""Test indexing an episode."""
episode = make_episode(embedding=[1.0, 0.0, 0.0, 0.0])
results = await indexer.index(episode)
assert "vector" in results
assert "temporal" in results
assert "entity" in results
assert "outcome" in results
@pytest.mark.asyncio
async def test_index_fact(self, indexer: MemoryIndexer) -> None:
"""Test indexing a fact."""
fact = make_fact(embedding=[1.0, 0.0, 0.0, 0.0])
results = await indexer.index(fact)
# Facts don't have outcomes
assert "vector" in results
assert "temporal" in results
assert "entity" in results
assert "outcome" not in results
@pytest.mark.asyncio
async def test_remove_from_all(self, indexer: MemoryIndexer) -> None:
"""Test removing from all indices."""
episode = make_episode(embedding=[1.0, 0.0, 0.0, 0.0])
await indexer.index(episode)
results = await indexer.remove(episode.id)
assert results["vector"] is True
assert results["temporal"] is True
assert results["entity"] is True
assert results["outcome"] is True
@pytest.mark.asyncio
async def test_clear_all(self, indexer: MemoryIndexer) -> None:
"""Test clearing all indices."""
await indexer.index(make_episode(embedding=[1.0, 0.0, 0.0, 0.0]))
await indexer.index(make_episode(embedding=[0.0, 1.0, 0.0, 0.0]))
counts = await indexer.clear_all()
assert counts["vector"] == 2
assert counts["temporal"] == 2
@pytest.mark.asyncio
async def test_get_stats(self, indexer: MemoryIndexer) -> None:
"""Test getting index statistics."""
await indexer.index(make_episode(embedding=[1.0, 0.0, 0.0, 0.0]))
stats = await indexer.get_stats()
assert stats["vector"] == 1
assert stats["temporal"] == 1
assert stats["entity"] == 1
assert stats["outcome"] == 1
class TestGetMemoryIndexer:
"""Tests for singleton getter."""
def test_returns_instance(self) -> None:
"""Test that getter returns instance."""
indexer = get_memory_indexer()
assert indexer is not None
assert isinstance(indexer, MemoryIndexer)
def test_returns_same_instance(self) -> None:
"""Test that getter returns same instance."""
indexer1 = get_memory_indexer()
indexer2 = get_memory_indexer()
assert indexer1 is indexer2


@@ -0,0 +1,450 @@
# tests/unit/services/memory/indexing/test_retrieval.py
"""Unit tests for memory retrieval."""
from datetime import UTC, datetime, timedelta
from uuid import uuid4
import pytest
from app.services.memory.indexing.index import MemoryIndexer
from app.services.memory.indexing.retrieval import (
RelevanceScorer,
RetrievalCache,
RetrievalEngine,
RetrievalQuery,
ScoredResult,
get_retrieval_engine,
)
from app.services.memory.types import Episode, MemoryType, Outcome
def _utcnow() -> datetime:
"""Get current UTC time."""
return datetime.now(UTC)
def make_episode(
embedding: list[float] | None = None,
outcome: Outcome = Outcome.SUCCESS,
occurred_at: datetime | None = None,
task_type: str = "test_task",
) -> Episode:
"""Create a test episode."""
return Episode(
id=uuid4(),
project_id=uuid4(),
agent_instance_id=uuid4(),
agent_type_id=uuid4(),
session_id="test-session",
task_type=task_type,
task_description="Test task description",
actions=[{"action": "test"}],
context_summary="Test context",
outcome=outcome,
outcome_details="Test outcome",
duration_seconds=10.0,
tokens_used=100,
lessons_learned=["lesson1"],
importance_score=0.8,
embedding=embedding,
occurred_at=occurred_at or _utcnow(),
created_at=_utcnow(),
updated_at=_utcnow(),
)
class TestRetrievalQuery:
"""Tests for RetrievalQuery."""
def test_default_values(self) -> None:
"""Test default query values."""
query = RetrievalQuery()
assert query.query_text is None
assert query.limit == 10
assert query.min_relevance == 0.0
assert query.use_vector is True
assert query.use_temporal is True
def test_cache_key_generation(self) -> None:
"""Test cache key generation."""
query1 = RetrievalQuery(query_text="test", limit=10)
query2 = RetrievalQuery(query_text="test", limit=10)
query3 = RetrievalQuery(query_text="different", limit=10)
# Same queries should have same key
assert query1.to_cache_key() == query2.to_cache_key()
# Different queries should have different keys
assert query1.to_cache_key() != query3.to_cache_key()
class TestScoredResult:
"""Tests for ScoredResult."""
def test_creation(self) -> None:
"""Test creating a scored result."""
result = ScoredResult(
memory_id=uuid4(),
memory_type=MemoryType.EPISODIC,
relevance_score=0.85,
score_breakdown={"vector": 0.9, "recency": 0.8},
)
assert result.relevance_score == 0.85
assert result.score_breakdown["vector"] == 0.9
class TestRelevanceScorer:
"""Tests for RelevanceScorer."""
@pytest.fixture
def scorer(self) -> RelevanceScorer:
"""Create a relevance scorer."""
return RelevanceScorer()
def test_score_with_vector(self, scorer: RelevanceScorer) -> None:
"""Test scoring with vector similarity."""
result = scorer.score(
memory_id=uuid4(),
memory_type=MemoryType.EPISODIC,
vector_similarity=0.9,
)
assert result.relevance_score > 0
assert result.score_breakdown["vector"] == 0.9
def test_score_with_recency(self, scorer: RelevanceScorer) -> None:
"""Test scoring with recency."""
recent_result = scorer.score(
memory_id=uuid4(),
memory_type=MemoryType.EPISODIC,
timestamp=_utcnow(),
)
old_result = scorer.score(
memory_id=uuid4(),
memory_type=MemoryType.EPISODIC,
timestamp=_utcnow() - timedelta(days=7),
)
# Recent should have higher recency score
assert (
recent_result.score_breakdown["recency"]
> old_result.score_breakdown["recency"]
)
def test_score_with_outcome_preference(self, scorer: RelevanceScorer) -> None:
"""Test scoring with outcome preference."""
success_result = scorer.score(
memory_id=uuid4(),
memory_type=MemoryType.EPISODIC,
outcome=Outcome.SUCCESS,
preferred_outcomes=[Outcome.SUCCESS],
)
failure_result = scorer.score(
memory_id=uuid4(),
memory_type=MemoryType.EPISODIC,
outcome=Outcome.FAILURE,
preferred_outcomes=[Outcome.SUCCESS],
)
assert success_result.score_breakdown["outcome"] == 1.0
assert failure_result.score_breakdown["outcome"] == 0.0
def test_score_with_entity_match(self, scorer: RelevanceScorer) -> None:
"""Test scoring with entity matches."""
full_match = scorer.score(
memory_id=uuid4(),
memory_type=MemoryType.EPISODIC,
entity_match_count=3,
entity_total=3,
)
partial_match = scorer.score(
memory_id=uuid4(),
memory_type=MemoryType.EPISODIC,
entity_match_count=1,
entity_total=3,
)
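# Full overlap (3 of 3 entities) should outscore partial overlap (1 of 3).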
assert (
full_match.score_breakdown["entity"]
> partial_match.score_breakdown["entity"]
)


class TestRetrievalCache:
"""Tests for RetrievalCache."""
@pytest.fixture
def cache(self) -> RetrievalCache:
"""Create a retrieval cache."""
return RetrievalCache(max_entries=10, default_ttl_seconds=60)
def test_put_and_get(self, cache: RetrievalCache) -> None:
"""Test putting and getting from cache."""
results = [
ScoredResult(
memory_id=uuid4(),
memory_type=MemoryType.EPISODIC,
relevance_score=0.8,
)
]
cache.put("test_key", results)
cached = cache.get("test_key")
assert cached is not None
assert len(cached) == 1
def test_get_nonexistent(self, cache: RetrievalCache) -> None:
"""Test getting nonexistent entry."""
result = cache.get("nonexistent")
assert result is None
def test_lru_eviction(self) -> None:
"""Test LRU eviction when at capacity."""
cache = RetrievalCache(max_entries=2, default_ttl_seconds=60)
results = [
ScoredResult(
memory_id=uuid4(),
memory_type=MemoryType.EPISODIC,
relevance_score=0.8,
)
]
cache.put("key1", results)
cache.put("key2", results)
cache.put("key3", results) # Should evict key1
assert cache.get("key1") is None
assert cache.get("key2") is not None
assert cache.get("key3") is not None
def test_invalidate(self, cache: RetrievalCache) -> None:
"""Test invalidating a cache entry."""
results = [
ScoredResult(
memory_id=uuid4(),
memory_type=MemoryType.EPISODIC,
relevance_score=0.8,
)
]
cache.put("test_key", results)
removed = cache.invalidate("test_key")
assert removed is True
assert cache.get("test_key") is None
def test_invalidate_by_memory(self, cache: RetrievalCache) -> None:
"""Test invalidating by memory ID."""
memory_id = uuid4()
results = [
ScoredResult(
memory_id=memory_id,
memory_type=MemoryType.EPISODIC,
relevance_score=0.8,
)
]
cache.put("key1", results)
cache.put("key2", results)
count = cache.invalidate_by_memory(memory_id)
assert count == 2
assert cache.get("key1") is None
assert cache.get("key2") is None
def test_clear(self, cache: RetrievalCache) -> None:
"""Test clearing the cache."""
results = [
ScoredResult(
memory_id=uuid4(),
memory_type=MemoryType.EPISODIC,
relevance_score=0.8,
)
]
cache.put("key1", results)
cache.put("key2", results)
count = cache.clear()
assert count == 2
assert cache.get("key1") is None
def test_get_stats(self, cache: RetrievalCache) -> None:
"""Test getting cache statistics."""
stats = cache.get_stats()
assert "total_entries" in stats
assert "max_entries" in stats
assert stats["max_entries"] == 10


class TestRetrievalEngine:
"""Tests for RetrievalEngine."""
@pytest.fixture
def indexer(self) -> MemoryIndexer:
"""Create a memory indexer."""
return MemoryIndexer()
@pytest.fixture
def engine(self, indexer: MemoryIndexer) -> RetrievalEngine:
"""Create a retrieval engine."""
return RetrievalEngine(indexer=indexer, enable_cache=True)
@pytest.mark.asyncio
async def test_retrieve_by_vector(
self, engine: RetrievalEngine, indexer: MemoryIndexer
) -> None:
"""Test retrieval by vector similarity."""
e1 = make_episode(embedding=[1.0, 0.0, 0.0, 0.0])
e2 = make_episode(embedding=[0.9, 0.1, 0.0, 0.0])
e3 = make_episode(embedding=[0.0, 1.0, 0.0, 0.0])
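# e1 and e2 lie close to the query vector; e3 is orthogonal and should rank lower.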
await indexer.index(e1)
await indexer.index(e2)
await indexer.index(e3)
query = RetrievalQuery(
query_embedding=[1.0, 0.0, 0.0, 0.0],
limit=2,
use_temporal=False,
use_entity=False,
use_outcome=False,
)
result = await engine.retrieve(query)
assert len(result.items) > 0
assert result.retrieval_type == "hybrid"
@pytest.mark.asyncio
async def test_retrieve_recent(
self, engine: RetrievalEngine, indexer: MemoryIndexer
) -> None:
"""Test retrieval of recent items."""
now = _utcnow()
old = make_episode(occurred_at=now - timedelta(hours=2))
recent = make_episode(occurred_at=now - timedelta(minutes=30))
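# Only the 30-minute-old episode falls within the one-hour retrieval window.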
await indexer.index(old)
await indexer.index(recent)
result = await engine.retrieve_recent(hours=1)
assert len(result.items) == 1
@pytest.mark.asyncio
async def test_retrieve_by_entity(
self, engine: RetrievalEngine, indexer: MemoryIndexer
) -> None:
"""Test retrieval by entity."""
e1 = make_episode(task_type="deploy")
e2 = make_episode(task_type="test")
await indexer.index(e1)
await indexer.index(e2)
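# Only the episode indexed under task_type "deploy" should match the entity lookup.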
result = await engine.retrieve_by_entity("task_type", "deploy")
assert len(result.items) == 1
@pytest.mark.asyncio
async def test_retrieve_successful(
self, engine: RetrievalEngine, indexer: MemoryIndexer
) -> None:
"""Test retrieval of successful items."""
success = make_episode(outcome=Outcome.SUCCESS)
failure = make_episode(outcome=Outcome.FAILURE)
await indexer.index(success)
await indexer.index(failure)
result = await engine.retrieve_successful()
assert len(result.items) == 1
# Only the successful episode should be returned
assert result.items[0].memory_id == success.id
@pytest.mark.asyncio
async def test_retrieve_with_cache(
self, engine: RetrievalEngine, indexer: MemoryIndexer
) -> None:
"""Test that retrieval uses cache."""
episode = make_episode(embedding=[1.0, 0.0, 0.0, 0.0])
await indexer.index(episode)
query = RetrievalQuery(
query_embedding=[1.0, 0.0, 0.0, 0.0],
limit=10,
)
# First retrieval
result1 = await engine.retrieve(query)
assert result1.metadata.get("cache_hit") is False
# Second retrieval should be cached
result2 = await engine.retrieve(query)
assert result2.metadata.get("cache_hit") is True
@pytest.mark.asyncio
async def test_invalidate_cache(
self, engine: RetrievalEngine, indexer: MemoryIndexer
) -> None:
"""Test cache invalidation."""
episode = make_episode(embedding=[1.0, 0.0, 0.0, 0.0])
await indexer.index(episode)
query = RetrievalQuery(query_embedding=[1.0, 0.0, 0.0, 0.0])
await engine.retrieve(query)
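# The retrieval above populated the cache, so invalidation should remove at least one entry.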
count = engine.invalidate_cache()
assert count > 0
@pytest.mark.asyncio
async def test_retrieve_similar(
self, engine: RetrievalEngine, indexer: MemoryIndexer
) -> None:
"""Test retrieve_similar convenience method."""
e1 = make_episode(embedding=[1.0, 0.0, 0.0, 0.0])
e2 = make_episode(embedding=[0.0, 1.0, 0.0, 0.0])
await indexer.index(e1)
await indexer.index(e2)
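# e2 is orthogonal to the query embedding, so with limit=1 only e1 should be returned.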
result = await engine.retrieve_similar(
embedding=[1.0, 0.0, 0.0, 0.0],
limit=1,
)
assert len(result.items) == 1
def test_get_cache_stats(self, engine: RetrievalEngine) -> None:
"""Test getting cache statistics."""
stats = engine.get_cache_stats()
assert "total_entries" in stats


class TestGetRetrievalEngine:
"""Tests for singleton getter."""
def test_returns_instance(self) -> None:
"""Test that getter returns instance."""
engine = get_retrieval_engine()
assert engine is not None
assert isinstance(engine, RetrievalEngine)
def test_returns_same_instance(self) -> None:
"""Test that getter returns same instance."""
engine1 = get_retrieval_engine()
engine2 = get_retrieval_engine()
assert engine1 is engine2