syndarix/backend/app/services/memory/indexing/index.py
Felipe Cardoso da85a8aba8 fix(memory): prevent entry metadata mutation in vector search
- Create shallow copy of VectorIndexEntry when adding similarity score
- Prevents mutation of cached entries that could corrupt shared state


# app/services/memory/indexing/index.py
"""
Memory Indexing.
Provides multiple indexing strategies for efficient memory retrieval:
- Vector embeddings for semantic search
- Temporal index for time-based queries
- Entity index for entity-based lookups
- Outcome index for success/failure filtering
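Illustrative usage (sketch; ``episode`` and ``query_embedding`` are assumed to
exist in the calling async context):
    indexer = get_memory_indexer()
    await indexer.index(episode)
    hits = await indexer.vector_index.search(query_embedding, limit=5)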
"""
import logging
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from datetime import UTC, datetime, timedelta
from typing import Any, TypeVar
from uuid import UUID
from app.services.memory.types import Episode, Fact, MemoryType, Outcome, Procedure
logger = logging.getLogger(__name__)
T = TypeVar("T", Episode, Fact, Procedure)
def _utcnow() -> datetime:
"""Get current UTC time as timezone-aware datetime."""
return datetime.now(UTC)
@dataclass
class IndexEntry:
"""A single entry in an index."""
memory_id: UUID
memory_type: MemoryType
indexed_at: datetime = field(default_factory=_utcnow)
metadata: dict[str, Any] = field(default_factory=dict)
@dataclass
class VectorIndexEntry(IndexEntry):
"""An entry with vector embedding."""
embedding: list[float] = field(default_factory=list)
dimension: int = 0
def __post_init__(self) -> None:
"""Set dimension from embedding."""
if self.embedding:
self.dimension = len(self.embedding)
@dataclass
class TemporalIndexEntry(IndexEntry):
"""An entry indexed by time."""
timestamp: datetime = field(default_factory=_utcnow)
@dataclass
class EntityIndexEntry(IndexEntry):
"""An entry indexed by entity."""
entity_type: str = ""
entity_value: str = ""
@dataclass
class OutcomeIndexEntry(IndexEntry):
"""An entry indexed by outcome."""
outcome: Outcome = Outcome.SUCCESS
class MemoryIndex[T](ABC):
"""Abstract base class for memory indices."""
@abstractmethod
async def add(self, item: T) -> IndexEntry:
"""Add an item to the index."""
...
@abstractmethod
async def remove(self, memory_id: UUID) -> bool:
"""Remove an item from the index."""
...
@abstractmethod
async def search(
self,
query: Any,
limit: int = 10,
**kwargs: Any,
) -> list[IndexEntry]:
"""Search the index."""
...
@abstractmethod
async def clear(self) -> int:
"""Clear all entries from the index."""
...
@abstractmethod
async def count(self) -> int:
"""Get the number of entries in the index."""
...
class VectorIndex(MemoryIndex[T]):
"""
Vector-based index using embeddings for semantic similarity search.
Uses cosine similarity for matching.
"""
def __init__(self, dimension: int = 1536) -> None:
"""
Initialize the vector index.
Args:
dimension: Embedding dimension (default 1536 for OpenAI)
"""
self._dimension = dimension
self._entries: dict[UUID, VectorIndexEntry] = {}
logger.info(f"Initialized VectorIndex with dimension={dimension}")
async def add(self, item: T) -> VectorIndexEntry:
"""
Add an item to the vector index.
Args:
item: Memory item with embedding
Returns:
The created index entry
"""
embedding = getattr(item, "embedding", None) or []
entry = VectorIndexEntry(
memory_id=item.id,
memory_type=self._get_memory_type(item),
embedding=embedding,
dimension=len(embedding),
)
self._entries[item.id] = entry
logger.debug(f"Added {item.id} to vector index")
return entry
async def remove(self, memory_id: UUID) -> bool:
"""Remove an item from the vector index."""
if memory_id in self._entries:
del self._entries[memory_id]
logger.debug(f"Removed {memory_id} from vector index")
return True
return False
async def search( # type: ignore[override]
self,
query: Any,
limit: int = 10,
min_similarity: float = 0.0,
**kwargs: Any,
) -> list[VectorIndexEntry]:
"""
Search for similar items using vector similarity.
Args:
query: Query embedding vector
limit: Maximum results to return
min_similarity: Minimum similarity threshold (0-1)
**kwargs: Additional filter parameters
Returns:
List of matching entries sorted by similarity
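Example (illustrative sketch; ``index`` is a populated VectorIndex and
``query_embedding`` is a list of floats): each returned entry carries its
score under ``metadata["similarity"]``.
    entries = await index.search(query_embedding, limit=5, min_similarity=0.3)
    scores = [entry.metadata["similarity"] for entry in entries]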
"""
if not isinstance(query, list) or not query:
return []
results: list[tuple[float, VectorIndexEntry]] = []
for entry in self._entries.values():
if not entry.embedding:
continue
similarity = self._cosine_similarity(query, entry.embedding)
if similarity >= min_similarity:
results.append((similarity, entry))
# Sort by similarity descending
results.sort(key=lambda x: x[0], reverse=True)
# Apply memory type filter if provided
memory_type = kwargs.get("memory_type")
if memory_type:
results = [(s, e) for s, e in results if e.memory_type == memory_type]
# Store similarity in metadata for the returned entries
# Use a copy of metadata to avoid mutating cached entries
output = []
for similarity, entry in results[:limit]:
# Create a shallow copy of the entry with the similarity score added to its
# metadata, preserving the original indexing timestamp
entry_with_score = VectorIndexEntry(
memory_id=entry.memory_id,
memory_type=entry.memory_type,
indexed_at=entry.indexed_at,
embedding=entry.embedding,
metadata={**entry.metadata, "similarity": similarity},
)
output.append(entry_with_score)
logger.debug(f"Vector search returned {len(output)} results")
return output
async def clear(self) -> int:
"""Clear all entries from the index."""
count = len(self._entries)
self._entries.clear()
logger.info(f"Cleared {count} entries from vector index")
return count
async def count(self) -> int:
"""Get the number of entries in the index."""
return len(self._entries)
def _cosine_similarity(self, a: list[float], b: list[float]) -> float:
"""Calculate cosine similarity between two vectors."""
if len(a) != len(b) or len(a) == 0:
return 0.0
dot_product = sum(x * y for x, y in zip(a, b, strict=True))
norm_a = sum(x * x for x in a) ** 0.5
norm_b = sum(x * x for x in b) ** 0.5
if norm_a == 0 or norm_b == 0:
return 0.0
return dot_product / (norm_a * norm_b)
def _get_memory_type(self, item: T) -> MemoryType:
"""Get the memory type for an item."""
if isinstance(item, Episode):
return MemoryType.EPISODIC
elif isinstance(item, Fact):
return MemoryType.SEMANTIC
elif isinstance(item, Procedure):
return MemoryType.PROCEDURAL
return MemoryType.WORKING
class TemporalIndex(MemoryIndex[T]):
"""
Time-based index for efficient temporal queries.
Supports:
- Range queries (between timestamps)
- Recent items (within last N seconds/hours/days)
- Oldest/newest sorting
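Example (illustrative sketch; ``index`` is a populated TemporalIndex and
``t0``/``t1`` are assumed datetimes):
    # Entries from the last hour, newest first
    recent = await index.search(None, limit=20, recent_seconds=3600)
    # Entries between two timestamps, oldest first
    window = await index.search(None, start_time=t0, end_time=t1, order="asc")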
"""
def __init__(self) -> None:
"""Initialize the temporal index."""
self._entries: dict[UUID, TemporalIndexEntry] = {}
# Sorted (timestamp, memory_id) list, maintained on add/remove.
# Note: search() currently filters the entries dict directly rather than using it.
self._sorted_entries: list[tuple[datetime, UUID]] = []
logger.info("Initialized TemporalIndex")
async def add(self, item: T) -> TemporalIndexEntry:
"""
Add an item to the temporal index.
Args:
item: Memory item with timestamp
Returns:
The created index entry
"""
# Get timestamp from various possible fields
timestamp = self._get_timestamp(item)
entry = TemporalIndexEntry(
memory_id=item.id,
memory_type=self._get_memory_type(item),
timestamp=timestamp,
)
self._entries[item.id] = entry
self._insert_sorted(timestamp, item.id)
logger.debug(f"Added {item.id} to temporal index at {timestamp}")
return entry
async def remove(self, memory_id: UUID) -> bool:
"""Remove an item from the temporal index."""
if memory_id not in self._entries:
return False
self._entries.pop(memory_id)
self._sorted_entries = [
(ts, mid) for ts, mid in self._sorted_entries if mid != memory_id
]
logger.debug(f"Removed {memory_id} from temporal index")
return True
async def search( # type: ignore[override]
self,
query: Any,
limit: int = 10,
start_time: datetime | None = None,
end_time: datetime | None = None,
recent_seconds: float | None = None,
order: str = "desc",
**kwargs: Any,
) -> list[TemporalIndexEntry]:
"""
Search for items by time.
Args:
query: Ignored for temporal search
limit: Maximum results to return
start_time: Start of time range
end_time: End of time range
recent_seconds: Get items from last N seconds
order: Sort order ("asc" or "desc")
**kwargs: Additional filter parameters
Returns:
List of matching entries sorted by time
"""
if recent_seconds is not None:
start_time = _utcnow() - timedelta(seconds=recent_seconds)
end_time = _utcnow()
# Filter by time range
results: list[TemporalIndexEntry] = []
for entry in self._entries.values():
if start_time and entry.timestamp < start_time:
continue
if end_time and entry.timestamp > end_time:
continue
results.append(entry)
# Apply memory type filter if provided
memory_type = kwargs.get("memory_type")
if memory_type:
results = [e for e in results if e.memory_type == memory_type]
# Sort by timestamp
results.sort(key=lambda e: e.timestamp, reverse=(order == "desc"))
logger.debug(f"Temporal search returned {min(len(results), limit)} results")
return results[:limit]
async def clear(self) -> int:
"""Clear all entries from the index."""
count = len(self._entries)
self._entries.clear()
self._sorted_entries.clear()
logger.info(f"Cleared {count} entries from temporal index")
return count
async def count(self) -> int:
"""Get the number of entries in the index."""
return len(self._entries)
def _insert_sorted(self, timestamp: datetime, memory_id: UUID) -> None:
"""Insert entry maintaining sorted order."""
# Binary search insert for efficiency
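# Finds the leftmost insertion point keyed on timestamp (equivalent to
# bisect.insort_left with a timestamp key).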
low, high = 0, len(self._sorted_entries)
while low < high:
mid = (low + high) // 2
if self._sorted_entries[mid][0] < timestamp:
low = mid + 1
else:
high = mid
self._sorted_entries.insert(low, (timestamp, memory_id))
def _get_timestamp(self, item: T) -> datetime:
"""Get the relevant timestamp for an item."""
if hasattr(item, "occurred_at"):
return item.occurred_at
if hasattr(item, "first_learned"):
return item.first_learned
if hasattr(item, "last_used") and item.last_used:
return item.last_used
if hasattr(item, "created_at"):
return item.created_at
return _utcnow()
def _get_memory_type(self, item: T) -> MemoryType:
"""Get the memory type for an item."""
if isinstance(item, Episode):
return MemoryType.EPISODIC
elif isinstance(item, Fact):
return MemoryType.SEMANTIC
elif isinstance(item, Procedure):
return MemoryType.PROCEDURAL
return MemoryType.WORKING
class EntityIndex(MemoryIndex[T]):
"""
Entity-based index for lookups by entities mentioned in memories.
Supports:
- Single entity lookup
- Multi-entity intersection
- Entity type filtering
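Example (illustrative sketch; ``index`` is a populated EntityIndex,
``project_id`` is assumed, and "code_review" is a hypothetical task type):
    # Memories that mention BOTH the project and the task type
    hits = await index.search(
        None,
        entities=[("project", str(project_id)), ("task_type", "code_review")],
        match_all=True,
    )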
"""
def __init__(self) -> None:
"""Initialize the entity index."""
# Main storage
self._entries: dict[UUID, EntityIndexEntry] = {}
# Inverted index: entity -> set of memory IDs
self._entity_to_memories: dict[str, set[UUID]] = {}
# Memory to entities mapping
self._memory_to_entities: dict[UUID, set[str]] = {}
logger.info("Initialized EntityIndex")
async def add(self, item: T) -> EntityIndexEntry:
"""
Add an item to the entity index.
Args:
item: Memory item with entity information
Returns:
The created index entry
"""
entities = self._extract_entities(item)
# Create entry for the primary entity (or first one)
primary_entity = entities[0] if entities else ("unknown", "unknown")
entry = EntityIndexEntry(
memory_id=item.id,
memory_type=self._get_memory_type(item),
entity_type=primary_entity[0],
entity_value=primary_entity[1],
)
self._entries[item.id] = entry
# Update inverted indices
entity_keys = {f"{etype}:{evalue}" for etype, evalue in entities}
self._memory_to_entities[item.id] = entity_keys
for entity_key in entity_keys:
if entity_key not in self._entity_to_memories:
self._entity_to_memories[entity_key] = set()
self._entity_to_memories[entity_key].add(item.id)
logger.debug(f"Added {item.id} to entity index with {len(entities)} entities")
return entry
async def remove(self, memory_id: UUID) -> bool:
"""Remove an item from the entity index."""
if memory_id not in self._entries:
return False
# Remove from inverted index
if memory_id in self._memory_to_entities:
for entity_key in self._memory_to_entities[memory_id]:
if entity_key in self._entity_to_memories:
self._entity_to_memories[entity_key].discard(memory_id)
if not self._entity_to_memories[entity_key]:
del self._entity_to_memories[entity_key]
del self._memory_to_entities[memory_id]
del self._entries[memory_id]
logger.debug(f"Removed {memory_id} from entity index")
return True
async def search( # type: ignore[override]
self,
query: Any,
limit: int = 10,
entity_type: str | None = None,
entity_value: str | None = None,
entities: list[tuple[str, str]] | None = None,
match_all: bool = False,
**kwargs: Any,
) -> list[EntityIndexEntry]:
"""
Search for items by entity.
Args:
query: Entity value to search (if entity_type not specified)
limit: Maximum results to return
entity_type: Type of entity to filter
entity_value: Specific entity value
entities: List of (type, value) tuples to match
match_all: If True, require all entities to match
**kwargs: Additional filter parameters
Returns:
List of matching entries
"""
matching_ids: set[UUID] | None = None
# Handle single entity query
if entity_type and entity_value:
entities = [(entity_type, entity_value)]
elif entity_value is None and isinstance(query, str):
# Search across all entity types
entity_value = query
if entities:
for etype, evalue in entities:
entity_key = f"{etype}:{evalue}"
if entity_key in self._entity_to_memories:
ids = self._entity_to_memories[entity_key]
if matching_ids is None:
matching_ids = ids.copy()
elif match_all:
matching_ids &= ids
else:
matching_ids |= ids
elif match_all:
# Required entity not found
matching_ids = set()
break
elif entity_value:
# Search for value across all types
matching_ids = set()
for entity_key, ids in self._entity_to_memories.items():
if entity_value.lower() in entity_key.lower():
matching_ids |= ids
if matching_ids is None:
# An explicit entity query that matched nothing should yield no results;
# only fall back to every entry when no entity filter was supplied.
matching_ids = set() if entities else set(self._entries.keys())
# Apply memory type filter if provided
memory_type = kwargs.get("memory_type")
results = []
for mid in matching_ids:
if mid in self._entries:
entry = self._entries[mid]
if memory_type and entry.memory_type != memory_type:
continue
results.append(entry)
logger.debug(f"Entity search returned {min(len(results), limit)} results")
return results[:limit]
async def clear(self) -> int:
"""Clear all entries from the index."""
count = len(self._entries)
self._entries.clear()
self._entity_to_memories.clear()
self._memory_to_entities.clear()
logger.info(f"Cleared {count} entries from entity index")
return count
async def count(self) -> int:
"""Get the number of entries in the index."""
return len(self._entries)
async def get_entities(self, memory_id: UUID) -> list[tuple[str, str]]:
"""Get all entities for a memory item."""
if memory_id not in self._memory_to_entities:
return []
entities = []
for entity_key in self._memory_to_entities[memory_id]:
if ":" in entity_key:
etype, evalue = entity_key.split(":", 1)
entities.append((etype, evalue))
return entities
def _extract_entities(self, item: T) -> list[tuple[str, str]]:
"""Extract entities from a memory item."""
entities: list[tuple[str, str]] = []
if isinstance(item, Episode):
# Extract from task type and context
entities.append(("task_type", item.task_type))
if item.project_id:
entities.append(("project", str(item.project_id)))
if item.agent_instance_id:
entities.append(("agent_instance", str(item.agent_instance_id)))
if item.agent_type_id:
entities.append(("agent_type", str(item.agent_type_id)))
elif isinstance(item, Fact):
# Subject and object are entities
entities.append(("subject", item.subject))
entities.append(("object", item.object))
if item.project_id:
entities.append(("project", str(item.project_id)))
elif isinstance(item, Procedure):
entities.append(("procedure", item.name))
if item.project_id:
entities.append(("project", str(item.project_id)))
if item.agent_type_id:
entities.append(("agent_type", str(item.agent_type_id)))
return entities
def _get_memory_type(self, item: T) -> MemoryType:
"""Get the memory type for an item."""
if isinstance(item, Episode):
return MemoryType.EPISODIC
elif isinstance(item, Fact):
return MemoryType.SEMANTIC
elif isinstance(item, Procedure):
return MemoryType.PROCEDURAL
return MemoryType.WORKING
class OutcomeIndex(MemoryIndex[T]):
"""
Outcome-based index for filtering by success/failure.
Primarily used for episodes and procedures.
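Example (illustrative sketch; ``index`` is a populated OutcomeIndex):
    failures = await index.search(None, outcome=Outcome.FAILURE, limit=50)
    stats = await index.get_outcome_stats()  # {Outcome.SUCCESS: n, ...}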
"""
def __init__(self) -> None:
"""Initialize the outcome index."""
self._entries: dict[UUID, OutcomeIndexEntry] = {}
# Inverted index by outcome
self._outcome_to_memories: dict[Outcome, set[UUID]] = {
Outcome.SUCCESS: set(),
Outcome.FAILURE: set(),
Outcome.PARTIAL: set(),
}
logger.info("Initialized OutcomeIndex")
async def add(self, item: T) -> OutcomeIndexEntry:
"""
Add an item to the outcome index.
Args:
item: Memory item with outcome information
Returns:
The created index entry
"""
outcome = self._get_outcome(item)
entry = OutcomeIndexEntry(
memory_id=item.id,
memory_type=self._get_memory_type(item),
outcome=outcome,
)
self._entries[item.id] = entry
self._outcome_to_memories[outcome].add(item.id)
logger.debug(f"Added {item.id} to outcome index with {outcome.value}")
return entry
async def remove(self, memory_id: UUID) -> bool:
"""Remove an item from the outcome index."""
if memory_id not in self._entries:
return False
entry = self._entries.pop(memory_id)
self._outcome_to_memories[entry.outcome].discard(memory_id)
logger.debug(f"Removed {memory_id} from outcome index")
return True
async def search( # type: ignore[override]
self,
query: Any,
limit: int = 10,
outcome: Outcome | None = None,
outcomes: list[Outcome] | None = None,
**kwargs: Any,
) -> list[OutcomeIndexEntry]:
"""
Search for items by outcome.
Args:
query: Ignored for outcome search
limit: Maximum results to return
outcome: Single outcome to filter
outcomes: Multiple outcomes to filter (OR)
**kwargs: Additional filter parameters
Returns:
List of matching entries
"""
if outcome:
outcomes = [outcome]
if outcomes:
matching_ids: set[UUID] = set()
for o in outcomes:
matching_ids |= self._outcome_to_memories.get(o, set())
else:
matching_ids = set(self._entries.keys())
# Apply memory type filter if provided
memory_type = kwargs.get("memory_type")
results = []
for mid in matching_ids:
if mid in self._entries:
entry = self._entries[mid]
if memory_type and entry.memory_type != memory_type:
continue
results.append(entry)
logger.debug(f"Outcome search returned {min(len(results), limit)} results")
return results[:limit]
async def clear(self) -> int:
"""Clear all entries from the index."""
count = len(self._entries)
self._entries.clear()
for outcome in self._outcome_to_memories:
self._outcome_to_memories[outcome].clear()
logger.info(f"Cleared {count} entries from outcome index")
return count
async def count(self) -> int:
"""Get the number of entries in the index."""
return len(self._entries)
async def get_outcome_stats(self) -> dict[Outcome, int]:
"""Get statistics on outcomes."""
return {outcome: len(ids) for outcome, ids in self._outcome_to_memories.items()}
def _get_outcome(self, item: T) -> Outcome:
"""Get the outcome for an item."""
if isinstance(item, Episode):
return item.outcome
elif isinstance(item, Procedure):
# Derive from success rate
if item.success_rate >= 0.8:
return Outcome.SUCCESS
elif item.success_rate <= 0.2:
return Outcome.FAILURE
return Outcome.PARTIAL
return Outcome.SUCCESS
def _get_memory_type(self, item: T) -> MemoryType:
"""Get the memory type for an item."""
if isinstance(item, Episode):
return MemoryType.EPISODIC
elif isinstance(item, Fact):
return MemoryType.SEMANTIC
elif isinstance(item, Procedure):
return MemoryType.PROCEDURAL
return MemoryType.WORKING
@dataclass
class MemoryIndexer:
"""
Unified indexer that manages all index types.
Provides a single interface for indexing and searching across
multiple index types.
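Example (illustrative sketch; ``fact`` is assumed to be a Fact without an
embedding):
    indexer = MemoryIndexer()
    entries = await indexer.index(fact)      # {"temporal": ..., "entity": ...}
    removed = await indexer.remove(fact.id)  # {"vector": False, "temporal": True, ...}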
"""
vector_index: VectorIndex[Any] = field(default_factory=VectorIndex)
temporal_index: TemporalIndex[Any] = field(default_factory=TemporalIndex)
entity_index: EntityIndex[Any] = field(default_factory=EntityIndex)
outcome_index: OutcomeIndex[Any] = field(default_factory=OutcomeIndex)
async def index(self, item: Episode | Fact | Procedure) -> dict[str, IndexEntry]:
"""
Index an item across all applicable indices.
Args:
item: Memory item to index
Returns:
Dictionary of index type to entry
"""
results: dict[str, IndexEntry] = {}
# Vector index (if embedding present)
if getattr(item, "embedding", None):
results["vector"] = await self.vector_index.add(item)
# Temporal index
results["temporal"] = await self.temporal_index.add(item)
# Entity index
results["entity"] = await self.entity_index.add(item)
# Outcome index (for episodes and procedures)
if isinstance(item, (Episode, Procedure)):
results["outcome"] = await self.outcome_index.add(item)
logger.info(
f"Indexed {item.id} across {len(results)} indices: {list(results.keys())}"
)
return results
async def remove(self, memory_id: UUID) -> dict[str, bool]:
"""
Remove an item from all indices.
Args:
memory_id: ID of the memory to remove
Returns:
Dictionary of index type to removal success
"""
results = {
"vector": await self.vector_index.remove(memory_id),
"temporal": await self.temporal_index.remove(memory_id),
"entity": await self.entity_index.remove(memory_id),
"outcome": await self.outcome_index.remove(memory_id),
}
removed_from = [k for k, v in results.items() if v]
if removed_from:
logger.info(f"Removed {memory_id} from indices: {removed_from}")
return results
async def clear_all(self) -> dict[str, int]:
"""
Clear all indices.
Returns:
Dictionary of index type to count cleared
"""
return {
"vector": await self.vector_index.clear(),
"temporal": await self.temporal_index.clear(),
"entity": await self.entity_index.clear(),
"outcome": await self.outcome_index.clear(),
}
async def get_stats(self) -> dict[str, int]:
"""
Get statistics for all indices.
Returns:
Dictionary of index type to entry count
"""
return {
"vector": await self.vector_index.count(),
"temporal": await self.temporal_index.count(),
"entity": await self.entity_index.count(),
"outcome": await self.outcome_index.count(),
}
# Singleton indexer instance
_indexer: MemoryIndexer | None = None
def get_memory_indexer() -> MemoryIndexer:
"""Get the singleton memory indexer instance."""
global _indexer
if _indexer is None:
_indexer = MemoryIndexer()
return _indexer