forked from cardosofelipe/fast-next-template
- Create shallow copy of VectorIndexEntry when adding similarity score - Prevents mutation of cached entries that could corrupt shared state 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
859 lines
27 KiB
Python
859 lines
27 KiB
Python
# app/services/memory/indexing/index.py
|
|
"""
|
|
Memory Indexing.
|
|
|
|
Provides multiple indexing strategies for efficient memory retrieval:
|
|
- Vector embeddings for semantic search
|
|
- Temporal index for time-based queries
|
|
- Entity index for entity-based lookups
|
|
- Outcome index for success/failure filtering
|
|
"""
|
|
|
|
import logging
|
|
from abc import ABC, abstractmethod
|
|
from dataclasses import dataclass, field
|
|
from datetime import UTC, datetime, timedelta
|
|
from typing import Any, TypeVar
|
|
from uuid import UUID
|
|
|
|
from app.services.memory.types import Episode, Fact, MemoryType, Outcome, Procedure
|
|
|
|
# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)
|
|
|
|
# Type variable constrained to the three memory item types imported above.
T = TypeVar("T", Episode, Fact, Procedure)
|
|
|
|
|
|
def _utcnow() -> datetime:
|
|
"""Get current UTC time as timezone-aware datetime."""
|
|
return datetime.now(UTC)
|
|
|
|
|
|
@dataclass
class IndexEntry:
    """Base record describing one memory item stored in an index.

    Attributes:
        memory_id: Identifier of the indexed memory item.
        memory_type: Kind of memory the item belongs to.
        indexed_at: When this entry was created; defaults to the current
            UTC time at construction.
        metadata: Free-form extra data attached to the entry; each instance
            gets its own empty dict by default.
    """

    memory_id: UUID
    memory_type: MemoryType
    indexed_at: datetime = field(default_factory=_utcnow)
    metadata: dict[str, Any] = field(default_factory=dict)
|
|
|
|
|
|
@dataclass
class VectorIndexEntry(IndexEntry):
    """Index entry that additionally carries an embedding vector.

    Attributes:
        embedding: The embedding vector; defaults to an empty list.
        dimension: Length of the embedding. Overwritten from ``embedding``
            after construction whenever the embedding is non-empty.
    """

    embedding: list[float] = field(default_factory=list)
    dimension: int = 0

    def __post_init__(self) -> None:
        """Derive ``dimension`` from the embedding when one is present."""
        if not self.embedding:
            # Keep whatever dimension was passed in (0 by default).
            return
        self.dimension = len(self.embedding)
|
|
|
|
|
|
@dataclass
class TemporalIndexEntry(IndexEntry):
    """Index entry keyed by a point in time.

    Attributes:
        timestamp: The time this entry is indexed under; defaults to the
            current UTC time at construction.
    """

    timestamp: datetime = field(default_factory=_utcnow)
|
|
|
|
|
|
@dataclass
class EntityIndexEntry(IndexEntry):
    """Index entry keyed by an entity, expressed as a (type, value) pair.

    Attributes:
        entity_type: Category of the entity; empty string by default.
        entity_value: Value of the entity; empty string by default.
    """

    entity_type: str = ""
    entity_value: str = ""
|
|
|
|
|
|
@dataclass
class OutcomeIndexEntry(IndexEntry):
    """Index entry keyed by an outcome.

    Attributes:
        outcome: Outcome recorded for the memory item; defaults to
            ``Outcome.SUCCESS``.
    """

    outcome: Outcome = Outcome.SUCCESS
|
|
|
|
|
|
class MemoryIndex[T](ABC):
|
|
"""Abstract base class for memory indices."""
|
|
|
|
@abstractmethod
|
|
async def add(self, item: T) -> IndexEntry:
|
|
"""Add an item to the index."""
|
|
...
|
|
|
|
@abstractmethod
|
|
async def remove(self, memory_id: UUID) -> bool:
|
|
"""Remove an item from the index."""
|
|
...
|
|
|
|
@abstractmethod
|
|
async def search(
|
|
self,
|
|
query: Any,
|
|
limit: int = 10,
|
|
**kwargs: Any,
|
|
) -> list[IndexEntry]:
|
|
"""Search the index."""
|
|
...
|
|
|
|
@abstractmethod
|
|
async def clear(self) -> int:
|
|
"""Clear all entries from the index."""
|
|
...
|
|
|
|
@abstractmethod
|
|
async def count(self) -> int:
|
|
"""Get the number of entries in the index."""
|
|
...
|
|
|
|
|
|
class VectorIndex(MemoryIndex[T]):
    """
    Vector-based index using embeddings for semantic similarity search.

    Entries are held in an in-memory dict keyed by memory id and compared
    against the query with cosine similarity.
    """

    def __init__(self, dimension: int = 1536) -> None:
        """
        Initialize the vector index.

        Args:
            dimension: Expected embedding dimension (default 1536). The value
                is stored but not enforced when items are added.
        """
        self._dimension = dimension
        self._entries: dict[UUID, VectorIndexEntry] = {}
        logger.info(f"Initialized VectorIndex with dimension={dimension}")

    async def add(self, item: T) -> VectorIndexEntry:
        """
        Add an item to the vector index.

        An item already present under the same id is replaced.

        Args:
            item: Memory item; its ``embedding`` attribute is used when
                present and non-empty, otherwise an empty embedding is stored.

        Returns:
            The created index entry
        """
        embedding = getattr(item, "embedding", None) or []

        entry = VectorIndexEntry(
            memory_id=item.id,
            memory_type=self._get_memory_type(item),
            embedding=embedding,
            dimension=len(embedding),
        )

        self._entries[item.id] = entry
        logger.debug(f"Added {item.id} to vector index")
        return entry

    async def remove(self, memory_id: UUID) -> bool:
        """Remove an item from the vector index; return True if it existed."""
        if memory_id in self._entries:
            del self._entries[memory_id]
            logger.debug(f"Removed {memory_id} from vector index")
            return True
        return False

    async def search(  # type: ignore[override]
        self,
        query: Any,
        limit: int = 10,
        min_similarity: float = 0.0,
        **kwargs: Any,
    ) -> list[VectorIndexEntry]:
        """
        Search for similar items using vector similarity.

        Args:
            query: Query embedding vector (a non-empty list); anything else
                yields an empty result
            limit: Maximum results to return
            min_similarity: Minimum similarity threshold; entries scoring
                below it are dropped
            **kwargs: Additional filter parameters. ``memory_type`` restricts
                results to entries of that memory type.

        Returns:
            Copies of the matching entries, sorted by similarity descending,
            each with the score stored under ``metadata["similarity"]``.
            The copies keep the original ``indexed_at`` and own their
            ``metadata`` dict and ``embedding`` list, so callers cannot
            mutate the cached entries through them.
        """
        if not isinstance(query, list) or not query:
            return []

        memory_type = kwargs.get("memory_type")
        scored: list[tuple[float, VectorIndexEntry]] = []

        for entry in self._entries.values():
            if not entry.embedding:
                continue
            # Filter by type first so no similarity is computed for entries
            # that would be discarded anyway.
            if memory_type and entry.memory_type != memory_type:
                continue

            similarity = self._cosine_similarity(query, entry.embedding)
            if similarity >= min_similarity:
                scored.append((similarity, entry))

        # Sort by similarity descending (stable for equal scores)
        scored.sort(key=lambda pair: pair[0], reverse=True)

        # Return copies so that the cached entries are never exposed for
        # mutation; preserve indexed_at instead of re-stamping it.
        output = [
            VectorIndexEntry(
                memory_id=entry.memory_id,
                memory_type=entry.memory_type,
                indexed_at=entry.indexed_at,
                metadata={**entry.metadata, "similarity": similarity},
                embedding=list(entry.embedding),
            )
            for similarity, entry in scored[:limit]
        ]

        logger.debug(f"Vector search returned {len(output)} results")
        return output

    async def clear(self) -> int:
        """Clear all entries from the index and return how many were removed."""
        count = len(self._entries)
        self._entries.clear()
        logger.info(f"Cleared {count} entries from vector index")
        return count

    async def count(self) -> int:
        """Get the number of entries in the index."""
        return len(self._entries)

    def _cosine_similarity(self, a: list[float], b: list[float]) -> float:
        """
        Calculate cosine similarity between two vectors.

        Returns 0.0 when the vectors differ in length, are empty, or either
        has zero magnitude.
        """
        if len(a) != len(b) or len(a) == 0:
            return 0.0

        dot_product = sum(x * y for x, y in zip(a, b, strict=True))
        norm_a = sum(x * x for x in a) ** 0.5
        norm_b = sum(x * x for x in b) ** 0.5

        if norm_a == 0 or norm_b == 0:
            return 0.0

        return dot_product / (norm_a * norm_b)

    def _get_memory_type(self, item: T) -> MemoryType:
        """Map an item's class to its memory type (WORKING as the fallback)."""
        if isinstance(item, Episode):
            return MemoryType.EPISODIC
        elif isinstance(item, Fact):
            return MemoryType.SEMANTIC
        elif isinstance(item, Procedure):
            return MemoryType.PROCEDURAL
        return MemoryType.WORKING
|
|
|
|
|
|
class TemporalIndex(MemoryIndex[T]):
    """
    Time-based index for temporal queries.

    Supports:
    - Range queries (between timestamps)
    - Recent items (within last N seconds)
    - Oldest/newest sorting
    """

    def __init__(self) -> None:
        """Initialize the temporal index."""
        self._entries: dict[UUID, TemporalIndexEntry] = {}
        # (timestamp, memory_id) pairs kept in ascending timestamp order
        self._sorted_entries: list[tuple[datetime, UUID]] = []
        logger.info("Initialized TemporalIndex")

    async def add(self, item: T) -> TemporalIndexEntry:
        """
        Add an item to the temporal index.

        Re-adding an id that is already indexed replaces the previous entry,
        including its slot in the sorted list.

        Args:
            item: Memory item with timestamp

        Returns:
            The created index entry
        """
        # Get timestamp from various possible fields
        timestamp = self._get_timestamp(item)

        entry = TemporalIndexEntry(
            memory_id=item.id,
            memory_type=self._get_memory_type(item),
            timestamp=timestamp,
        )

        if item.id in self._entries:
            # Without this the old (timestamp, id) pair would linger in the
            # sorted list next to the new one.
            self._discard_sorted(item.id)

        self._entries[item.id] = entry
        self._insert_sorted(timestamp, item.id)

        logger.debug(f"Added {item.id} to temporal index at {timestamp}")
        return entry

    async def remove(self, memory_id: UUID) -> bool:
        """Remove an item from the temporal index; return True if it existed."""
        if memory_id not in self._entries:
            return False

        self._entries.pop(memory_id)
        self._discard_sorted(memory_id)

        logger.debug(f"Removed {memory_id} from temporal index")
        return True

    async def search(  # type: ignore[override]
        self,
        query: Any,
        limit: int = 10,
        start_time: datetime | None = None,
        end_time: datetime | None = None,
        recent_seconds: float | None = None,
        order: str = "desc",
        **kwargs: Any,
    ) -> list[TemporalIndexEntry]:
        """
        Search for items by time.

        Args:
            query: Ignored for temporal search
            limit: Maximum results to return
            start_time: Start of time range (inclusive)
            end_time: End of time range (inclusive)
            recent_seconds: Get items from last N seconds; when given it
                overrides ``start_time`` and ``end_time``
            order: Sort order; "desc" sorts newest first, any other value
                sorts oldest first
            **kwargs: Additional filter parameters. ``memory_type`` restricts
                results to entries of that memory type.

        Returns:
            List of matching entries sorted by time
        """
        if recent_seconds is not None:
            # Read the clock once so both bounds describe the same instant.
            now = _utcnow()
            start_time = now - timedelta(seconds=recent_seconds)
            end_time = now

        memory_type = kwargs.get("memory_type")

        results: list[TemporalIndexEntry] = []
        for entry in self._entries.values():
            if start_time is not None and entry.timestamp < start_time:
                continue
            if end_time is not None and entry.timestamp > end_time:
                continue
            if memory_type and entry.memory_type != memory_type:
                continue
            results.append(entry)

        # Sort by timestamp
        results.sort(key=lambda e: e.timestamp, reverse=(order == "desc"))

        logger.debug(f"Temporal search returned {min(len(results), limit)} results")
        return results[:limit]

    async def clear(self) -> int:
        """Clear all entries from the index and return how many were removed."""
        count = len(self._entries)
        self._entries.clear()
        self._sorted_entries.clear()
        logger.info(f"Cleared {count} entries from temporal index")
        return count

    async def count(self) -> int:
        """Get the number of entries in the index."""
        return len(self._entries)

    def _discard_sorted(self, memory_id: UUID) -> None:
        """Drop every sorted-list pair that belongs to ``memory_id``."""
        self._sorted_entries = [
            (ts, mid) for ts, mid in self._sorted_entries if mid != memory_id
        ]

    def _insert_sorted(self, timestamp: datetime, memory_id: UUID) -> None:
        """Insert entry maintaining ascending timestamp order."""
        # Binary search for the leftmost position whose timestamp is not
        # smaller than the new one.
        low, high = 0, len(self._sorted_entries)
        while low < high:
            mid = (low + high) // 2
            if self._sorted_entries[mid][0] < timestamp:
                low = mid + 1
            else:
                high = mid
        self._sorted_entries.insert(low, (timestamp, memory_id))

    def _get_timestamp(self, item: T) -> datetime:
        """
        Get the relevant timestamp for an item.

        The first attribute among ``occurred_at``, ``first_learned``,
        ``last_used`` and ``created_at`` that exists and is set wins; the
        current UTC time is used when none is available.
        """
        for attr in ("occurred_at", "first_learned", "last_used", "created_at"):
            value = getattr(item, attr, None)
            # Skip attributes that exist but are unset, so a None never
            # ends up as an entry's timestamp.
            if value:
                return value
        return _utcnow()

    def _get_memory_type(self, item: T) -> MemoryType:
        """Map an item's class to its memory type (WORKING as the fallback)."""
        if isinstance(item, Episode):
            return MemoryType.EPISODIC
        elif isinstance(item, Fact):
            return MemoryType.SEMANTIC
        elif isinstance(item, Procedure):
            return MemoryType.PROCEDURAL
        return MemoryType.WORKING
|
|
|
|
|
|
class EntityIndex(MemoryIndex[T]):
    """
    Entity-based index for lookups by entities mentioned in memories.

    Supports:
    - Single entity lookup
    - Multi-entity intersection
    - Entity type filtering
    """

    def __init__(self) -> None:
        """Initialize the entity index."""
        # Main storage
        self._entries: dict[UUID, EntityIndexEntry] = {}
        # Inverted index: "type:value" key -> set of memory IDs
        self._entity_to_memories: dict[str, set[UUID]] = {}
        # Memory to entity keys mapping
        self._memory_to_entities: dict[UUID, set[str]] = {}
        logger.info("Initialized EntityIndex")

    async def add(self, item: T) -> EntityIndexEntry:
        """
        Add an item to the entity index.

        Re-adding an id that is already indexed replaces the previous entry
        and drops the entity links it no longer has.

        Args:
            item: Memory item with entity information

        Returns:
            The created index entry, describing the first extracted entity
            (or ("unknown", "unknown") when none was extracted)
        """
        entities = self._extract_entities(item)

        # Create entry for the primary entity (or first one)
        primary_entity = entities[0] if entities else ("unknown", "unknown")

        entry = EntityIndexEntry(
            memory_id=item.id,
            memory_type=self._get_memory_type(item),
            entity_type=primary_entity[0],
            entity_value=primary_entity[1],
        )

        # Forget links from an earlier add of the same id, otherwise stale
        # entity keys would keep pointing at this memory.
        self._unlink_entities(item.id)

        self._entries[item.id] = entry

        # Update inverted indices
        entity_keys = {f"{etype}:{evalue}" for etype, evalue in entities}
        self._memory_to_entities[item.id] = entity_keys

        for entity_key in entity_keys:
            self._entity_to_memories.setdefault(entity_key, set()).add(item.id)

        logger.debug(f"Added {item.id} to entity index with {len(entities)} entities")
        return entry

    async def remove(self, memory_id: UUID) -> bool:
        """Remove an item from the entity index; return True if it existed."""
        if memory_id not in self._entries:
            return False

        # Remove from inverted index
        self._unlink_entities(memory_id)

        del self._entries[memory_id]
        logger.debug(f"Removed {memory_id} from entity index")
        return True

    async def search(  # type: ignore[override]
        self,
        query: Any,
        limit: int = 10,
        entity_type: str | None = None,
        entity_value: str | None = None,
        entities: list[tuple[str, str]] | None = None,
        match_all: bool = False,
        **kwargs: Any,
    ) -> list[EntityIndexEntry]:
        """
        Search for items by entity.

        Resolution order:
        1. ``entity_type`` and ``entity_value`` both given: exact lookup of
           that single entity (overrides ``entities``).
        2. ``entities`` given: exact lookup of each (type, value) pair,
           combined with AND when ``match_all`` is true, otherwise OR.
        3. Otherwise ``entity_value`` (or a string ``query`` when no value
           was given) is matched case-insensitively as a substring of the
           indexed entity values, optionally restricted to ``entity_type``.
           ``entity_type`` alone selects every entity of that type.
        4. No criteria at all: every entry is a candidate.

        Args:
            query: Entity value to search (if entity_value not specified)
            limit: Maximum results to return
            entity_type: Type of entity to filter
            entity_value: Specific entity value
            entities: List of (type, value) tuples to match
            match_all: If True, require all entities to match
            **kwargs: Additional filter parameters. ``memory_type`` restricts
                results to entries of that memory type.

        Returns:
            List of matching entries (in no particular order)
        """
        matching_ids: set[UUID] | None = None

        # Handle single entity query
        if entity_type and entity_value:
            entities = [(entity_type, entity_value)]
        elif entity_value is None and isinstance(query, str):
            # Use the query text as the value to look for
            entity_value = query

        if entities:
            # Entities that are not indexed contribute an empty set, so an
            # unknown entity yields no results instead of all entries.
            id_sets = [
                self._entity_to_memories.get(f"{etype}:{evalue}", set())
                for etype, evalue in entities
            ]
            if match_all:
                matching_ids = set.intersection(*id_sets)
            else:
                matching_ids = set().union(*id_sets)
        elif entity_value or entity_type:
            needle = entity_value.lower() if entity_value else None
            matching_ids = set()
            for entity_key, ids in self._entity_to_memories.items():
                # Keys have the form "type:value"; compare the parts
                # separately so the type prefix is not matched as a value.
                key_type, _, key_value = entity_key.partition(":")
                if entity_type and key_type != entity_type:
                    continue
                if needle is not None and needle not in key_value.lower():
                    continue
                matching_ids |= ids

        if matching_ids is None:
            matching_ids = set(self._entries.keys())

        # Apply memory type filter if provided
        memory_type = kwargs.get("memory_type")
        results = []
        for mid in matching_ids:
            if mid in self._entries:
                entry = self._entries[mid]
                if memory_type and entry.memory_type != memory_type:
                    continue
                results.append(entry)

        logger.debug(f"Entity search returned {min(len(results), limit)} results")
        return results[:limit]

    async def clear(self) -> int:
        """Clear all entries from the index and return how many were removed."""
        count = len(self._entries)
        self._entries.clear()
        self._entity_to_memories.clear()
        self._memory_to_entities.clear()
        logger.info(f"Cleared {count} entries from entity index")
        return count

    async def count(self) -> int:
        """Get the number of entries in the index."""
        return len(self._entries)

    async def get_entities(self, memory_id: UUID) -> list[tuple[str, str]]:
        """Get all (type, value) entities recorded for a memory item."""
        if memory_id not in self._memory_to_entities:
            return []

        entities = []
        for entity_key in self._memory_to_entities[memory_id]:
            if ":" in entity_key:
                # Split on the first colon only; values may contain colons.
                etype, evalue = entity_key.split(":", 1)
                entities.append((etype, evalue))
        return entities

    def _unlink_entities(self, memory_id: UUID) -> None:
        """Remove ``memory_id`` from both entity mappings, pruning empty keys."""
        for entity_key in self._memory_to_entities.pop(memory_id, set()):
            members = self._entity_to_memories.get(entity_key)
            if members is None:
                continue
            members.discard(memory_id)
            if not members:
                del self._entity_to_memories[entity_key]

    def _extract_entities(self, item: T) -> list[tuple[str, str]]:
        """Extract (type, value) entities from a memory item."""
        entities: list[tuple[str, str]] = []

        if isinstance(item, Episode):
            # Extract from task type and context
            entities.append(("task_type", item.task_type))
            if item.project_id:
                entities.append(("project", str(item.project_id)))
            if item.agent_instance_id:
                entities.append(("agent_instance", str(item.agent_instance_id)))
            if item.agent_type_id:
                entities.append(("agent_type", str(item.agent_type_id)))

        elif isinstance(item, Fact):
            # Subject and object are entities
            entities.append(("subject", item.subject))
            entities.append(("object", item.object))
            if item.project_id:
                entities.append(("project", str(item.project_id)))

        elif isinstance(item, Procedure):
            entities.append(("procedure", item.name))
            if item.project_id:
                entities.append(("project", str(item.project_id)))
            if item.agent_type_id:
                entities.append(("agent_type", str(item.agent_type_id)))

        return entities

    def _get_memory_type(self, item: T) -> MemoryType:
        """Map an item's class to its memory type (WORKING as the fallback)."""
        if isinstance(item, Episode):
            return MemoryType.EPISODIC
        elif isinstance(item, Fact):
            return MemoryType.SEMANTIC
        elif isinstance(item, Procedure):
            return MemoryType.PROCEDURAL
        return MemoryType.WORKING
|
|
|
|
|
|
class OutcomeIndex(MemoryIndex[T]):
    """
    Outcome-based index for filtering by success/failure.

    Primarily used for episodes and procedures.
    """

    def __init__(self) -> None:
        """Initialize the outcome index."""
        self._entries: dict[UUID, OutcomeIndexEntry] = {}
        # Inverted index by outcome
        self._outcome_to_memories: dict[Outcome, set[UUID]] = {
            Outcome.SUCCESS: set(),
            Outcome.FAILURE: set(),
            Outcome.PARTIAL: set(),
        }
        logger.info("Initialized OutcomeIndex")

    async def add(self, item: T) -> OutcomeIndexEntry:
        """
        Add an item to the outcome index.

        Re-adding an id that is already indexed moves it to the bucket of
        its new outcome.

        Args:
            item: Memory item with outcome information

        Returns:
            The created index entry
        """
        outcome = self._get_outcome(item)

        entry = OutcomeIndexEntry(
            memory_id=item.id,
            memory_type=self._get_memory_type(item),
            outcome=outcome,
        )

        previous = self._entries.get(item.id)
        if previous is not None:
            # Take the id out of its old bucket so it is never listed under
            # two outcomes at once.
            self._outcome_to_memories.get(previous.outcome, set()).discard(item.id)

        self._entries[item.id] = entry
        # setdefault tolerates outcomes that were not pre-seeded in __init__.
        self._outcome_to_memories.setdefault(outcome, set()).add(item.id)

        logger.debug(f"Added {item.id} to outcome index with {outcome.value}")
        return entry

    async def remove(self, memory_id: UUID) -> bool:
        """Remove an item from the outcome index; return True if it existed."""
        if memory_id not in self._entries:
            return False

        entry = self._entries.pop(memory_id)
        self._outcome_to_memories.get(entry.outcome, set()).discard(memory_id)

        logger.debug(f"Removed {memory_id} from outcome index")
        return True

    async def search(  # type: ignore[override]
        self,
        query: Any,
        limit: int = 10,
        outcome: Outcome | None = None,
        outcomes: list[Outcome] | None = None,
        **kwargs: Any,
    ) -> list[OutcomeIndexEntry]:
        """
        Search for items by outcome.

        Args:
            query: Ignored for outcome search
            limit: Maximum results to return
            outcome: Single outcome to filter; takes precedence over
                ``outcomes`` when given
            outcomes: Multiple outcomes to filter (OR)
            **kwargs: Additional filter parameters. ``memory_type`` restricts
                results to entries of that memory type.

        Returns:
            List of matching entries (in no particular order); all entries
            are candidates when no outcome filter is given
        """
        if outcome:
            outcomes = [outcome]

        if outcomes:
            matching_ids: set[UUID] = set()
            for o in outcomes:
                matching_ids |= self._outcome_to_memories.get(o, set())
        else:
            matching_ids = set(self._entries.keys())

        # Apply memory type filter if provided
        memory_type = kwargs.get("memory_type")
        results = []
        for mid in matching_ids:
            if mid in self._entries:
                entry = self._entries[mid]
                if memory_type and entry.memory_type != memory_type:
                    continue
                results.append(entry)

        logger.debug(f"Outcome search returned {min(len(results), limit)} results")
        return results[:limit]

    async def clear(self) -> int:
        """Clear all entries from the index and return how many were removed."""
        count = len(self._entries)
        self._entries.clear()
        # Empty the buckets in place so the outcome keys stay available.
        for ids in self._outcome_to_memories.values():
            ids.clear()
        logger.info(f"Cleared {count} entries from outcome index")
        return count

    async def count(self) -> int:
        """Get the number of entries in the index."""
        return len(self._entries)

    async def get_outcome_stats(self) -> dict[Outcome, int]:
        """Get the number of indexed memories per outcome."""
        return {outcome: len(ids) for outcome, ids in self._outcome_to_memories.items()}

    def _get_outcome(self, item: T) -> Outcome:
        """
        Get the outcome for an item.

        Episodes report their own outcome. Procedures are classified from
        ``success_rate``: SUCCESS at 0.8 or above, FAILURE at 0.2 or below,
        PARTIAL in between. Anything else is treated as SUCCESS.
        """
        if isinstance(item, Episode):
            return item.outcome
        elif isinstance(item, Procedure):
            # Derive from success rate
            if item.success_rate >= 0.8:
                return Outcome.SUCCESS
            elif item.success_rate <= 0.2:
                return Outcome.FAILURE
            return Outcome.PARTIAL
        return Outcome.SUCCESS

    def _get_memory_type(self, item: T) -> MemoryType:
        """Map an item's class to its memory type (WORKING as the fallback)."""
        if isinstance(item, Episode):
            return MemoryType.EPISODIC
        elif isinstance(item, Fact):
            return MemoryType.SEMANTIC
        elif isinstance(item, Procedure):
            return MemoryType.PROCEDURAL
        return MemoryType.WORKING
|
|
|
|
|
|
@dataclass
class MemoryIndexer:
    """
    Facade over the vector, temporal, entity and outcome indices.

    Lets callers index, remove and inspect memory items through one object
    instead of talking to each index separately.
    """

    vector_index: VectorIndex[Any] = field(default_factory=VectorIndex)
    temporal_index: TemporalIndex[Any] = field(default_factory=TemporalIndex)
    entity_index: EntityIndex[Any] = field(default_factory=EntityIndex)
    outcome_index: OutcomeIndex[Any] = field(default_factory=OutcomeIndex)

    async def index(self, item: Episode | Fact | Procedure) -> dict[str, IndexEntry]:
        """
        Register an item with every index that applies to it.

        The vector index is used only when the item has a truthy
        ``embedding``; the outcome index only for episodes and procedures.
        The temporal and entity indices always receive the item.

        Args:
            item: Memory item to index

        Returns:
            Mapping from index name to the entry that index created
        """
        created: dict[str, IndexEntry] = {}

        has_embedding = bool(getattr(item, "embedding", None))
        if has_embedding:
            created["vector"] = await self.vector_index.add(item)

        created["temporal"] = await self.temporal_index.add(item)
        created["entity"] = await self.entity_index.add(item)

        tracks_outcome = isinstance(item, (Episode, Procedure))
        if tracks_outcome:
            created["outcome"] = await self.outcome_index.add(item)

        logger.info(
            f"Indexed {item.id} across {len(created)} indices: {list(created.keys())}"
        )
        return created

    async def remove(self, memory_id: UUID) -> dict[str, bool]:
        """
        Drop an item from every index.

        Args:
            memory_id: ID of the memory to remove

        Returns:
            Mapping from index name to whether that index held the item
        """
        outcome_by_index: dict[str, bool] = {}
        outcome_by_index["vector"] = await self.vector_index.remove(memory_id)
        outcome_by_index["temporal"] = await self.temporal_index.remove(memory_id)
        outcome_by_index["entity"] = await self.entity_index.remove(memory_id)
        outcome_by_index["outcome"] = await self.outcome_index.remove(memory_id)

        removed_from = [name for name, removed in outcome_by_index.items() if removed]
        if removed_from:
            logger.info(f"Removed {memory_id} from indices: {removed_from}")

        return outcome_by_index

    async def clear_all(self) -> dict[str, int]:
        """
        Empty every index.

        Returns:
            Mapping from index name to the number of entries it dropped
        """
        cleared: dict[str, int] = {}
        cleared["vector"] = await self.vector_index.clear()
        cleared["temporal"] = await self.temporal_index.clear()
        cleared["entity"] = await self.entity_index.clear()
        cleared["outcome"] = await self.outcome_index.clear()
        return cleared

    async def get_stats(self) -> dict[str, int]:
        """
        Report the size of every index.

        Returns:
            Mapping from index name to its current entry count
        """
        sizes: dict[str, int] = {}
        sizes["vector"] = await self.vector_index.count()
        sizes["temporal"] = await self.temporal_index.count()
        sizes["entity"] = await self.entity_index.count()
        sizes["outcome"] = await self.outcome_index.count()
        return sizes
|
|
|
|
|
|
# Module-wide indexer shared by all callers; created lazily on first access.
_indexer: MemoryIndexer | None = None


def get_memory_indexer() -> MemoryIndexer:
    """Return the shared MemoryIndexer, creating it on the first call."""
    global _indexer
    if _indexer is not None:
        return _indexer
    _indexer = MemoryIndexer()
    return _indexer
|