feat(memory): implement memory indexing and retrieval engine (#94)
Add a comprehensive indexing and retrieval system for memory search:

- VectorIndex for semantic similarity search using cosine similarity
- TemporalIndex for time-based queries with range and recency support
- EntityIndex for entity-based lookups with multi-entity intersection
- OutcomeIndex for success/failure filtering on episodes
- MemoryIndexer as unified interface for all index types
- RetrievalEngine with hybrid search combining all indices
- RelevanceScorer for multi-signal relevance scoring
- RetrievalCache for LRU caching of search results

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
851
backend/app/services/memory/indexing/index.py
Normal file
851
backend/app/services/memory/indexing/index.py
Normal file
@@ -0,0 +1,851 @@
|
||||
# app/services/memory/indexing/index.py
|
||||
"""
|
||||
Memory Indexing.
|
||||
|
||||
Provides multiple indexing strategies for efficient memory retrieval:
|
||||
- Vector embeddings for semantic search
|
||||
- Temporal index for time-based queries
|
||||
- Entity index for entity-based lookups
|
||||
- Outcome index for success/failure filtering
|
||||
"""
|
||||
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from typing import Any, TypeVar
|
||||
from uuid import UUID
|
||||
|
||||
from app.services.memory.types import Episode, Fact, MemoryType, Outcome, Procedure
|
||||
|
||||
# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)
|
||||
|
||||
# Type variable constrained to the three memory item types: Episode, Fact, Procedure.
T = TypeVar("T", Episode, Fact, Procedure)
|
||||
|
||||
|
||||
def _utcnow() -> datetime:
|
||||
"""Get current UTC time as timezone-aware datetime."""
|
||||
return datetime.now(UTC)
|
||||
|
||||
|
||||
@dataclass
class IndexEntry:
    """
    Base record describing one memory item inside an index.

    Attributes:
        memory_id: Identifier of the indexed memory item.
        memory_type: Kind of memory the item belongs to.
        indexed_at: When the entry was created; defaults to the current
            timezone-aware UTC time.
        metadata: Free-form per-entry data; each entry gets its own dict.
    """

    memory_id: UUID
    memory_type: MemoryType
    indexed_at: datetime = field(default_factory=_utcnow)
    metadata: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
|
||||
@dataclass
class VectorIndexEntry(IndexEntry):
    """
    Index entry that additionally carries an embedding vector.

    Attributes:
        embedding: The embedding values; empty when the item had none.
        dimension: Length of ``embedding``; overwritten after init whenever
            the embedding is non-empty.
    """

    embedding: list[float] = field(default_factory=list)
    dimension: int = 0

    def __post_init__(self) -> None:
        """Derive ``dimension`` from the embedding when one is present."""
        # An empty embedding keeps whatever dimension the caller passed in.
        if not self.embedding:
            return
        self.dimension = len(self.embedding)
|
||||
|
||||
|
||||
@dataclass
class TemporalIndexEntry(IndexEntry):
    """
    Index entry keyed by a point in time.

    Attributes:
        timestamp: The time the entry is ordered by; defaults to the
            current timezone-aware UTC time.
    """

    timestamp: datetime = field(default_factory=_utcnow)
|
||||
|
||||
|
||||
@dataclass
class EntityIndexEntry(IndexEntry):
    """
    Index entry keyed by an entity.

    Attributes:
        entity_type: Category of the entity (empty string by default).
        entity_value: Value of the entity (empty string by default).
    """

    entity_type: str = ""
    entity_value: str = ""
|
||||
|
||||
|
||||
@dataclass
class OutcomeIndexEntry(IndexEntry):
    """
    Index entry keyed by an outcome.

    Attributes:
        outcome: Outcome recorded for the item; ``Outcome.SUCCESS`` by default.
    """

    outcome: Outcome = Outcome.SUCCESS
|
||||
|
||||
|
||||
class MemoryIndex[T](ABC):
|
||||
"""Abstract base class for memory indices."""
|
||||
|
||||
@abstractmethod
|
||||
async def add(self, item: T) -> IndexEntry:
|
||||
"""Add an item to the index."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def remove(self, memory_id: UUID) -> bool:
|
||||
"""Remove an item from the index."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def search(
|
||||
self,
|
||||
query: Any,
|
||||
limit: int = 10,
|
||||
**kwargs: Any,
|
||||
) -> list[IndexEntry]:
|
||||
"""Search the index."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def clear(self) -> int:
|
||||
"""Clear all entries from the index."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def count(self) -> int:
|
||||
"""Get the number of entries in the index."""
|
||||
...
|
||||
|
||||
|
||||
class VectorIndex(MemoryIndex[T]):
    """
    Vector-based index using embeddings for semantic similarity search.

    Entries are kept in a dict keyed by memory ID. Every search is a linear
    scan that scores each stored embedding against the query with cosine
    similarity.
    """

    def __init__(self, dimension: int = 1536) -> None:
        """
        Initialize the vector index.

        Args:
            dimension: Embedding dimension (default 1536 for OpenAI)
        """
        self._dimension = dimension
        self._entries: dict[UUID, VectorIndexEntry] = {}
        logger.info(f"Initialized VectorIndex with dimension={dimension}")

    async def add(self, item: T) -> VectorIndexEntry:
        """
        Add an item to the vector index.

        An existing entry with the same ID is replaced.

        Args:
            item: Memory item; its ``embedding`` attribute is used when it
                is present and non-empty, otherwise an empty embedding is
                stored

        Returns:
            The created index entry
        """
        embedding = getattr(item, "embedding", None) or []

        entry = VectorIndexEntry(
            memory_id=item.id,
            memory_type=self._get_memory_type(item),
            embedding=embedding,
            dimension=len(embedding),
        )

        self._entries[item.id] = entry
        logger.debug(f"Added {item.id} to vector index")
        return entry

    async def remove(self, memory_id: UUID) -> bool:
        """
        Remove an item from the vector index.

        Args:
            memory_id: ID of the memory to remove

        Returns:
            True if an entry existed and was removed, False otherwise
        """
        if self._entries.pop(memory_id, None) is None:
            return False
        logger.debug(f"Removed {memory_id} from vector index")
        return True

    async def search(  # type: ignore[override]
        self,
        query: Any,
        limit: int = 10,
        min_similarity: float = 0.0,
        **kwargs: Any,
    ) -> list[VectorIndexEntry]:
        """
        Search for similar items using vector similarity.

        Args:
            query: Query embedding vector (a non-empty list or tuple of
                floats); any other value yields no results
            limit: Maximum results to return; values below 1 yield no results
            min_similarity: Minimum cosine similarity for a match
            **kwargs: Additional filter parameters; ``memory_type`` restricts
                results to entries of that memory type

        Returns:
            List of matching entries sorted by similarity, highest first.
            Each returned entry has ``metadata["similarity"]`` set to its
            score for this query (the stored entry object is updated).
        """
        if not isinstance(query, (list, tuple)) or not query:
            return []

        memory_type = kwargs.get("memory_type")
        query_dim = len(query)

        scored: list[tuple[float, VectorIndexEntry]] = []
        for entry in self._entries.values():
            # Filter on type first so excluded entries are never scored.
            if memory_type and entry.memory_type != memory_type:
                continue
            # Embeddings of a different length (including empty ones) are
            # not comparable with the query; skip them instead of reporting
            # a meaningless 0.0 similarity as a match.
            if len(entry.embedding) != query_dim:
                continue

            similarity = self._cosine_similarity(query, entry.embedding)
            if similarity >= min_similarity:
                scored.append((similarity, entry))

        # Stable sort: ties keep insertion order.
        scored.sort(key=lambda pair: pair[0], reverse=True)

        # Store similarity in metadata for the returned entries.
        # max(limit, 0) stops a negative limit from slicing off the tail.
        output: list[VectorIndexEntry] = []
        for similarity, entry in scored[: max(limit, 0)]:
            entry.metadata["similarity"] = similarity
            output.append(entry)

        logger.debug(f"Vector search returned {len(output)} results")
        return output

    async def clear(self) -> int:
        """
        Clear all entries from the index.

        Returns:
            Number of entries that were cleared
        """
        count = len(self._entries)
        self._entries.clear()
        logger.info(f"Cleared {count} entries from vector index")
        return count

    async def count(self) -> int:
        """Get the number of entries in the index."""
        return len(self._entries)

    def _cosine_similarity(self, a: list[float], b: list[float]) -> float:
        """
        Calculate cosine similarity between two vectors.

        Returns 0.0 when the vectors differ in length, are empty, or either
        has zero magnitude.
        """
        if len(a) != len(b) or len(a) == 0:
            return 0.0

        dot_product = sum(x * y for x, y in zip(a, b, strict=True))
        norm_a = sum(x * x for x in a) ** 0.5
        norm_b = sum(x * x for x in b) ** 0.5

        # A zero vector has no direction; avoid dividing by zero.
        if norm_a == 0 or norm_b == 0:
            return 0.0

        return dot_product / (norm_a * norm_b)

    def _get_memory_type(self, item: T) -> MemoryType:
        """Get the memory type for an item (WORKING for unrecognized types)."""
        if isinstance(item, Episode):
            return MemoryType.EPISODIC
        if isinstance(item, Fact):
            return MemoryType.SEMANTIC
        if isinstance(item, Procedure):
            return MemoryType.PROCEDURAL
        return MemoryType.WORKING
|
||||
|
||||
|
||||
class TemporalIndex(MemoryIndex[T]):
    """
    Time-based index for temporal queries.

    Supports:
    - Range queries (between timestamps)
    - Recent items (within the last N seconds)
    - Oldest/newest sorting

    Entries are stored in a dict keyed by memory ID. A list of
    ``(timestamp, memory_id)`` tuples is kept in timestamp order alongside
    it; ``search`` itself scans the dict.
    """

    def __init__(self) -> None:
        """Initialize the temporal index."""
        self._entries: dict[UUID, TemporalIndexEntry] = {}
        # Sorted list for efficient range queries
        self._sorted_entries: list[tuple[datetime, UUID]] = []
        logger.info("Initialized TemporalIndex")

    async def add(self, item: T) -> TemporalIndexEntry:
        """
        Add an item to the temporal index.

        Re-adding an ID that is already indexed replaces its entry.

        Args:
            item: Memory item with timestamp

        Returns:
            The created index entry
        """
        # Get timestamp from various possible fields
        timestamp = self._get_timestamp(item)

        entry = TemporalIndexEntry(
            memory_id=item.id,
            memory_type=self._get_memory_type(item),
            timestamp=timestamp,
        )

        # On re-index, drop the old sorted tuple first; otherwise the sorted
        # list would keep a stale duplicate for this ID.
        if item.id in self._entries:
            self._discard_sorted(item.id)

        self._entries[item.id] = entry
        self._insert_sorted(timestamp, item.id)

        logger.debug(f"Added {item.id} to temporal index at {timestamp}")
        return entry

    async def remove(self, memory_id: UUID) -> bool:
        """
        Remove an item from the temporal index.

        Args:
            memory_id: ID of the memory to remove

        Returns:
            True if an entry existed and was removed, False otherwise
        """
        if memory_id not in self._entries:
            return False

        del self._entries[memory_id]
        self._discard_sorted(memory_id)

        logger.debug(f"Removed {memory_id} from temporal index")
        return True

    async def search(  # type: ignore[override]
        self,
        query: Any,
        limit: int = 10,
        start_time: datetime | None = None,
        end_time: datetime | None = None,
        recent_seconds: float | None = None,
        order: str = "desc",
        **kwargs: Any,
    ) -> list[TemporalIndexEntry]:
        """
        Search for items by time.

        Args:
            query: Ignored for temporal search
            limit: Maximum results to return; values below 1 yield no results
            start_time: Start of time range (inclusive)
            end_time: End of time range (inclusive)
            recent_seconds: Get items from last N seconds; when given it
                overrides ``start_time`` and ``end_time``
            order: Sort order; "desc" is newest first, any other value is
                oldest first
            **kwargs: Additional filter parameters; ``memory_type`` restricts
                results to entries of that memory type

        Returns:
            List of matching entries sorted by time
        """
        if recent_seconds is not None:
            # Take "now" once so both window bounds derive from one instant.
            end_time = _utcnow()
            start_time = end_time - timedelta(seconds=recent_seconds)

        memory_type = kwargs.get("memory_type")

        # Filter by time range and memory type
        results: list[TemporalIndexEntry] = []
        for entry in self._entries.values():
            if start_time is not None and entry.timestamp < start_time:
                continue
            if end_time is not None and entry.timestamp > end_time:
                continue
            if memory_type and entry.memory_type != memory_type:
                continue
            results.append(entry)

        # Sort by timestamp
        results.sort(key=lambda e: e.timestamp, reverse=(order == "desc"))

        # max(limit, 0) stops a negative limit from slicing off the tail.
        page = results[: max(limit, 0)]
        logger.debug(f"Temporal search returned {len(page)} results")
        return page

    async def clear(self) -> int:
        """
        Clear all entries from the index.

        Returns:
            Number of entries that were cleared
        """
        count = len(self._entries)
        self._entries.clear()
        self._sorted_entries.clear()
        logger.info(f"Cleared {count} entries from temporal index")
        return count

    async def count(self) -> int:
        """Get the number of entries in the index."""
        return len(self._entries)

    def _insert_sorted(self, timestamp: datetime, memory_id: UUID) -> None:
        """Insert entry maintaining sorted order."""
        # Binary search for the leftmost slot whose timestamp is not
        # earlier than the new one.
        low, high = 0, len(self._sorted_entries)
        while low < high:
            mid = (low + high) // 2
            if self._sorted_entries[mid][0] < timestamp:
                low = mid + 1
            else:
                high = mid
        self._sorted_entries.insert(low, (timestamp, memory_id))

    def _discard_sorted(self, memory_id: UUID) -> None:
        """Drop every sorted tuple that belongs to ``memory_id``."""
        self._sorted_entries = [
            (ts, mid) for ts, mid in self._sorted_entries if mid != memory_id
        ]

    def _get_timestamp(self, item: T) -> datetime:
        """
        Get the relevant timestamp for an item.

        Checks ``occurred_at``, ``first_learned``, ``last_used`` and
        ``created_at`` in that order and returns the first one that is set.
        Falls back to the current UTC time when none is.
        """
        for attr in ("occurred_at", "first_learned", "last_used", "created_at"):
            value = getattr(item, attr, None)
            # Skip attributes that exist but are unset so a None timestamp
            # never reaches the index, where it would break comparisons.
            if value is not None:
                return value
        return _utcnow()

    def _get_memory_type(self, item: T) -> MemoryType:
        """Get the memory type for an item (WORKING for unrecognized types)."""
        if isinstance(item, Episode):
            return MemoryType.EPISODIC
        if isinstance(item, Fact):
            return MemoryType.SEMANTIC
        if isinstance(item, Procedure):
            return MemoryType.PROCEDURAL
        return MemoryType.WORKING
|
||||
|
||||
|
||||
class EntityIndex(MemoryIndex[T]):
    """
    Entity-based index for lookups by entities mentioned in memories.

    Supports:
    - Single entity lookup
    - Multi-entity intersection
    - Entity type filtering

    Entities are stored under ``"<type>:<value>"`` keys in an inverted
    index that maps each key to the set of memory IDs mentioning it.
    """

    def __init__(self) -> None:
        """Initialize the entity index."""
        # Main storage
        self._entries: dict[UUID, EntityIndexEntry] = {}
        # Inverted index: entity -> set of memory IDs
        self._entity_to_memories: dict[str, set[UUID]] = {}
        # Memory to entities mapping
        self._memory_to_entities: dict[UUID, set[str]] = {}
        logger.info("Initialized EntityIndex")

    async def add(self, item: T) -> EntityIndexEntry:
        """
        Add an item to the entity index.

        Re-adding an ID that is already indexed replaces its entry and its
        entity links.

        Args:
            item: Memory item with entity information

        Returns:
            The created index entry, describing the first extracted entity
            (or ``("unknown", "unknown")`` when none were extracted)
        """
        entities = self._extract_entities(item)

        # Create entry for the primary entity (or first one)
        primary_type, primary_value = (
            entities[0] if entities else ("unknown", "unknown")
        )

        entry = EntityIndexEntry(
            memory_id=item.id,
            memory_type=self._get_memory_type(item),
            entity_type=primary_type,
            entity_value=primary_value,
        )

        # On re-index, remove links left by the previous version so entities
        # the item no longer mentions stop matching it.
        self._unlink(item.id)

        self._entries[item.id] = entry

        # Update inverted indices
        entity_keys = {f"{etype}:{evalue}" for etype, evalue in entities}
        self._memory_to_entities[item.id] = entity_keys
        for entity_key in entity_keys:
            self._entity_to_memories.setdefault(entity_key, set()).add(item.id)

        logger.debug(f"Added {item.id} to entity index with {len(entities)} entities")
        return entry

    async def remove(self, memory_id: UUID) -> bool:
        """
        Remove an item from the entity index.

        Args:
            memory_id: ID of the memory to remove

        Returns:
            True if an entry existed and was removed, False otherwise
        """
        if memory_id not in self._entries:
            return False

        self._unlink(memory_id)
        del self._entries[memory_id]

        logger.debug(f"Removed {memory_id} from entity index")
        return True

    async def search(  # type: ignore[override]
        self,
        query: Any,
        limit: int = 10,
        entity_type: str | None = None,
        entity_value: str | None = None,
        entities: list[tuple[str, str]] | None = None,
        match_all: bool = False,
        **kwargs: Any,
    ) -> list[EntityIndexEntry]:
        """
        Search for items by entity.

        Resolution order:
        1. ``entity_type`` and ``entity_value`` both given: exact lookup of
           that single entity (overrides ``entities``).
        2. ``entities`` given: exact lookup of each pair, combined with AND
           when ``match_all`` is true and OR otherwise.
        3. Otherwise a case-insensitive substring match on entity values,
           using ``entity_value`` or, when that is None, a string ``query``.
           ``entity_type`` restricts the match to that type.
        4. No criteria at all: every entry matches.

        Args:
            query: Entity value to search (if entity_value not specified)
            limit: Maximum results to return; values below 1 yield no results
            entity_type: Type of entity to filter
            entity_value: Specific entity value
            entities: List of (type, value) tuples to match
            match_all: If True, require all entities to match
            **kwargs: Additional filter parameters; ``memory_type`` restricts
                results to entries of that memory type

        Returns:
            List of matching entries, in the order they were first indexed
        """
        # Handle single entity query
        if entity_type and entity_value:
            entities = [(entity_type, entity_value)]
        elif entity_value is None and isinstance(query, str):
            # Search across all entity types
            entity_value = query

        matching_ids: set[UUID] | None = None

        if entities:
            # An entity that is not indexed contributes an empty set, so an
            # unmatched lookup yields no results instead of falling through
            # to "match everything".
            id_sets = [
                self._entity_to_memories.get(f"{etype}:{evalue}", set())
                for etype, evalue in entities
            ]
            if match_all:
                matching_ids = set.intersection(*id_sets)
            else:
                matching_ids = set().union(*id_sets)
        elif entity_value or entity_type:
            # Substring match on the value part only, so the "<type>:"
            # prefix of the key cannot produce accidental matches.
            needle = (entity_value or "").lower()
            matching_ids = set()
            for entity_key, ids in self._entity_to_memories.items():
                etype, _, evalue = entity_key.partition(":")
                if entity_type and etype != entity_type:
                    continue
                if needle in evalue.lower():
                    matching_ids |= ids

        # Apply memory type filter if provided
        memory_type = kwargs.get("memory_type")

        # Walk the entry dict (insertion-ordered) rather than the ID set so
        # result order and limit truncation are deterministic.
        results = [
            entry
            for mid, entry in self._entries.items()
            if (matching_ids is None or mid in matching_ids)
            and not (memory_type and entry.memory_type != memory_type)
        ]

        # max(limit, 0) stops a negative limit from slicing off the tail.
        page = results[: max(limit, 0)]
        logger.debug(f"Entity search returned {len(page)} results")
        return page

    async def clear(self) -> int:
        """
        Clear all entries from the index.

        Returns:
            Number of entries that were cleared
        """
        count = len(self._entries)
        self._entries.clear()
        self._entity_to_memories.clear()
        self._memory_to_entities.clear()
        logger.info(f"Cleared {count} entries from entity index")
        return count

    async def count(self) -> int:
        """Get the number of entries in the index."""
        return len(self._entries)

    async def get_entities(self, memory_id: UUID) -> list[tuple[str, str]]:
        """
        Get all entities for a memory item.

        Args:
            memory_id: ID of the memory to look up

        Returns:
            List of (type, value) tuples; empty when the ID is not indexed
        """
        entities: list[tuple[str, str]] = []
        for entity_key in self._memory_to_entities.get(memory_id, ()):
            # Split on the first colon only; values may contain colons.
            etype, sep, evalue = entity_key.partition(":")
            if sep:
                entities.append((etype, evalue))
        return entities

    def _unlink(self, memory_id: UUID) -> None:
        """Remove a memory's inverted-index links, pruning emptied keys."""
        for entity_key in self._memory_to_entities.pop(memory_id, ()):
            ids = self._entity_to_memories.get(entity_key)
            if ids is None:
                continue
            ids.discard(memory_id)
            if not ids:
                del self._entity_to_memories[entity_key]

    def _extract_entities(self, item: T) -> list[tuple[str, str]]:
        """
        Extract entities from a memory item.

        - Episode: task type, plus project / agent instance / agent type IDs
          when set.
        - Fact: subject and object, plus project ID when set.
        - Procedure: name, plus project / agent type IDs when set.

        Returns:
            List of (type, value) tuples; empty for other item types
        """
        entities: list[tuple[str, str]] = []

        if isinstance(item, Episode):
            # Extract from task type and context
            entities.append(("task_type", item.task_type))
            if item.project_id:
                entities.append(("project", str(item.project_id)))
            if item.agent_instance_id:
                entities.append(("agent_instance", str(item.agent_instance_id)))
            if item.agent_type_id:
                entities.append(("agent_type", str(item.agent_type_id)))

        elif isinstance(item, Fact):
            # Subject and object are entities
            entities.append(("subject", item.subject))
            entities.append(("object", item.object))
            if item.project_id:
                entities.append(("project", str(item.project_id)))

        elif isinstance(item, Procedure):
            entities.append(("procedure", item.name))
            if item.project_id:
                entities.append(("project", str(item.project_id)))
            if item.agent_type_id:
                entities.append(("agent_type", str(item.agent_type_id)))

        return entities

    def _get_memory_type(self, item: T) -> MemoryType:
        """Get the memory type for an item (WORKING for unrecognized types)."""
        if isinstance(item, Episode):
            return MemoryType.EPISODIC
        if isinstance(item, Fact):
            return MemoryType.SEMANTIC
        if isinstance(item, Procedure):
            return MemoryType.PROCEDURAL
        return MemoryType.WORKING
|
||||
|
||||
|
||||
class OutcomeIndex(MemoryIndex[T]):
    """
    Outcome-based index for filtering by success/failure.

    Primarily used for episodes and procedures. Keeps one set of memory IDs
    per outcome next to the main entry dict.
    """

    def __init__(self) -> None:
        """Initialize the outcome index."""
        self._entries: dict[UUID, OutcomeIndexEntry] = {}
        # Inverted index by outcome
        self._outcome_to_memories: dict[Outcome, set[UUID]] = {
            Outcome.SUCCESS: set(),
            Outcome.FAILURE: set(),
            Outcome.PARTIAL: set(),
        }
        logger.info("Initialized OutcomeIndex")

    async def add(self, item: T) -> OutcomeIndexEntry:
        """
        Add an item to the outcome index.

        Re-adding an ID that is already indexed replaces its entry and moves
        it to the bucket of its new outcome.

        Args:
            item: Memory item with outcome information

        Returns:
            The created index entry
        """
        outcome = self._get_outcome(item)

        entry = OutcomeIndexEntry(
            memory_id=item.id,
            memory_type=self._get_memory_type(item),
            outcome=outcome,
        )

        # On re-index, take the ID out of its previous bucket so stats and
        # searches do not report it under two outcomes.
        previous = self._entries.get(item.id)
        if previous is not None:
            self._outcome_to_memories.get(previous.outcome, set()).discard(item.id)

        self._entries[item.id] = entry
        # setdefault tolerates outcomes beyond the three pre-seeded buckets.
        self._outcome_to_memories.setdefault(outcome, set()).add(item.id)

        logger.debug(f"Added {item.id} to outcome index with {outcome.value}")
        return entry

    async def remove(self, memory_id: UUID) -> bool:
        """
        Remove an item from the outcome index.

        Args:
            memory_id: ID of the memory to remove

        Returns:
            True if an entry existed and was removed, False otherwise
        """
        entry = self._entries.pop(memory_id, None)
        if entry is None:
            return False

        bucket = self._outcome_to_memories.get(entry.outcome)
        if bucket is not None:
            bucket.discard(memory_id)

        logger.debug(f"Removed {memory_id} from outcome index")
        return True

    async def search(  # type: ignore[override]
        self,
        query: Any,
        limit: int = 10,
        outcome: Outcome | None = None,
        outcomes: list[Outcome] | None = None,
        **kwargs: Any,
    ) -> list[OutcomeIndexEntry]:
        """
        Search for items by outcome.

        Args:
            query: Ignored for outcome search
            limit: Maximum results to return; values below 1 yield no results
            outcome: Single outcome to filter; overrides ``outcomes``
            outcomes: Multiple outcomes to filter (OR)
            **kwargs: Additional filter parameters; ``memory_type`` restricts
                results to entries of that memory type

        Returns:
            List of matching entries, in the order they were first indexed.
            Without any outcome filter every entry matches.
        """
        if outcome is not None:
            outcomes = [outcome]

        matching_ids: set[UUID] | None = None
        if outcomes:
            matching_ids = set()
            for wanted in outcomes:
                matching_ids |= self._outcome_to_memories.get(wanted, set())

        # Apply memory type filter if provided
        memory_type = kwargs.get("memory_type")

        # Walk the entry dict (insertion-ordered) rather than the ID set so
        # result order and limit truncation are deterministic.
        results = [
            entry
            for mid, entry in self._entries.items()
            if (matching_ids is None or mid in matching_ids)
            and not (memory_type and entry.memory_type != memory_type)
        ]

        # max(limit, 0) stops a negative limit from slicing off the tail.
        page = results[: max(limit, 0)]
        logger.debug(f"Outcome search returned {len(page)} results")
        return page

    async def clear(self) -> int:
        """
        Clear all entries from the index.

        The per-outcome buckets are emptied but kept.

        Returns:
            Number of entries that were cleared
        """
        count = len(self._entries)
        self._entries.clear()
        for ids in self._outcome_to_memories.values():
            ids.clear()
        logger.info(f"Cleared {count} entries from outcome index")
        return count

    async def count(self) -> int:
        """Get the number of entries in the index."""
        return len(self._entries)

    async def get_outcome_stats(self) -> dict[Outcome, int]:
        """
        Get statistics on outcomes.

        Returns:
            Mapping of each outcome bucket to the number of IDs it holds
        """
        return {outcome: len(ids) for outcome, ids in self._outcome_to_memories.items()}

    def _get_outcome(self, item: T) -> Outcome:
        """
        Get the outcome for an item.

        Episodes report their own outcome. Procedures are classified from
        ``success_rate``: 0.8 or higher is SUCCESS, 0.2 or lower is FAILURE,
        anything in between is PARTIAL. Other item types are SUCCESS.
        """
        if isinstance(item, Episode):
            return item.outcome
        if isinstance(item, Procedure):
            # Derive from success rate
            if item.success_rate >= 0.8:
                return Outcome.SUCCESS
            if item.success_rate <= 0.2:
                return Outcome.FAILURE
            return Outcome.PARTIAL
        return Outcome.SUCCESS

    def _get_memory_type(self, item: T) -> MemoryType:
        """Get the memory type for an item (WORKING for unrecognized types)."""
        if isinstance(item, Episode):
            return MemoryType.EPISODIC
        if isinstance(item, Fact):
            return MemoryType.SEMANTIC
        if isinstance(item, Procedure):
            return MemoryType.PROCEDURAL
        return MemoryType.WORKING
|
||||
|
||||
|
||||
@dataclass
class MemoryIndexer:
    """
    Facade that owns one index of each kind.

    Offers a single entry point for indexing an item into, removing it
    from, clearing, and counting the vector, temporal, entity and outcome
    indices.
    """

    vector_index: VectorIndex[Any] = field(default_factory=VectorIndex)
    temporal_index: TemporalIndex[Any] = field(default_factory=TemporalIndex)
    entity_index: EntityIndex[Any] = field(default_factory=EntityIndex)
    outcome_index: OutcomeIndex[Any] = field(default_factory=OutcomeIndex)

    def _named_indices(self) -> dict[str, Any]:
        """Map each index name to its index, in processing order."""
        return {
            "vector": self.vector_index,
            "temporal": self.temporal_index,
            "entity": self.entity_index,
            "outcome": self.outcome_index,
        }

    async def index(self, item: Episode | Fact | Procedure) -> dict[str, IndexEntry]:
        """
        Index an item across all applicable indices.

        The vector index is used only when the item has a non-empty
        ``embedding``; the outcome index only for episodes and procedures.
        The temporal and entity indices always receive the item.

        Args:
            item: Memory item to index

        Returns:
            Dictionary of index type to entry
        """
        created: dict[str, IndexEntry] = {}

        has_embedding = bool(getattr(item, "embedding", None))
        if has_embedding:
            created["vector"] = await self.vector_index.add(item)

        created["temporal"] = await self.temporal_index.add(item)
        created["entity"] = await self.entity_index.add(item)

        if isinstance(item, (Episode, Procedure)):
            created["outcome"] = await self.outcome_index.add(item)

        logger.info(
            f"Indexed {item.id} across {len(created)} indices: {list(created.keys())}"
        )
        return created

    async def remove(self, memory_id: UUID) -> dict[str, bool]:
        """
        Remove an item from all indices.

        Args:
            memory_id: ID of the memory to remove

        Returns:
            Dictionary of index type to removal success
        """
        results: dict[str, bool] = {}
        for name, idx in self._named_indices().items():
            results[name] = await idx.remove(memory_id)

        removed_from = [name for name, removed in results.items() if removed]
        if removed_from:
            logger.info(f"Removed {memory_id} from indices: {removed_from}")

        return results

    async def clear_all(self) -> dict[str, int]:
        """
        Clear all indices.

        Returns:
            Dictionary of index type to count cleared
        """
        cleared: dict[str, int] = {}
        for name, idx in self._named_indices().items():
            cleared[name] = await idx.clear()
        return cleared

    async def get_stats(self) -> dict[str, int]:
        """
        Get statistics for all indices.

        Returns:
            Dictionary of index type to entry count
        """
        stats: dict[str, int] = {}
        for name, idx in self._named_indices().items():
            stats[name] = await idx.count()
        return stats
|
||||
|
||||
|
||||
# Process-wide indexer, created on first use by get_memory_indexer()
_indexer: MemoryIndexer | None = None


def get_memory_indexer() -> MemoryIndexer:
    """
    Return the shared ``MemoryIndexer``, building it on the first call.

    Every later call returns the same instance.
    """
    global _indexer
    if _indexer is not None:
        return _indexer
    _indexer = MemoryIndexer()
    return _indexer
|
||||
Reference in New Issue
Block a user