# app/services/memory/episodic/retrieval.py
"""
Episode Retrieval Strategies.

Provides different retrieval strategies for finding relevant episodes:
- Semantic similarity (vector search). NOTE: vector search is not wired in
  yet; semantic requests currently fall back to recency ordering.
- Recency-based
- Outcome-based filtering
- Task-type filtering
- Importance-based ranking
"""

import logging
import time
from abc import ABC, abstractmethod
from datetime import datetime
from enum import Enum
from typing import Any
from uuid import UUID

from sqlalchemy import and_, desc, select
from sqlalchemy.ext.asyncio import AsyncSession

from app.models.memory.enums import EpisodeOutcome
from app.models.memory.episode import Episode as EpisodeModel
from app.services.memory.types import Episode, Outcome, RetrievalResult

logger = logging.getLogger(__name__)


class RetrievalStrategy(str, Enum):
    """Retrieval strategy types.

    Retrievers in this module report ``<member>.value`` as
    ``RetrievalResult.retrieval_type``. ``HYBRID`` is declared here, but no
    retriever in this module is registered for it.
    """

    SEMANTIC = "semantic"
    RECENCY = "recency"
    OUTCOME = "outcome"
    IMPORTANCE = "importance"
    HYBRID = "hybrid"


def _model_to_episode(model: EpisodeModel) -> Episode:
    """Convert an ``EpisodeModel`` ORM row into the service-layer ``Episode``.

    Conversion details:
    - ``outcome`` is mapped from the DB enum to ``Outcome`` by its ``.value``.
    - Falsy ``actions`` / ``lessons_learned`` become ``[]`` and a falsy
      ``outcome_details`` becomes ``""``.
    - ``embedding`` is always ``None``; the stored vector is never exposed.
    """
    # SQLAlchemy Column types are inferred as Column[T] by mypy, but at runtime
    # they return actual values. We use type: ignore to handle this mismatch.
    return Episode(
        id=model.id,  # type: ignore[arg-type]
        project_id=model.project_id,  # type: ignore[arg-type]
        agent_instance_id=model.agent_instance_id,  # type: ignore[arg-type]
        agent_type_id=model.agent_type_id,  # type: ignore[arg-type]
        session_id=model.session_id,  # type: ignore[arg-type]
        task_type=model.task_type,  # type: ignore[arg-type]
        task_description=model.task_description,  # type: ignore[arg-type]
        actions=model.actions or [],  # type: ignore[arg-type]
        context_summary=model.context_summary,  # type: ignore[arg-type]
        outcome=Outcome(model.outcome.value),
        outcome_details=model.outcome_details or "",  # type: ignore[arg-type]
        duration_seconds=model.duration_seconds,  # type: ignore[arg-type]
        tokens_used=model.tokens_used,  # type: ignore[arg-type]
        lessons_learned=model.lessons_learned or [],  # type: ignore[arg-type]
        importance_score=model.importance_score,  # type: ignore[arg-type]
        embedding=None,  # Don't expose raw embedding
        occurred_at=model.occurred_at,  # type: ignore[arg-type]
        created_at=model.created_at,  # type: ignore[arg-type]
        updated_at=model.updated_at,  # type: ignore[arg-type]
    )


class BaseRetriever(ABC):
    """Abstract base class for episode retrieval strategies.

    Subclasses implement ``retrieve``. They accept strategy-specific
    keyword-only filters and swallow any others via ``**kwargs``.
    """

    @abstractmethod
    async def retrieve(
        self,
        session: AsyncSession,
        project_id: UUID,
        limit: int = 10,
        **kwargs: Any,
    ) -> RetrievalResult[Episode]:
        """Retrieve episodes based on the strategy.

        Args:
            session: Async database session to query with.
            project_id: Project whose episodes are searched.
            limit: Maximum number of episodes to return.
            **kwargs: Strategy-specific filters.

        Returns:
            A ``RetrievalResult`` holding the matching episodes.
        """
        ...
def _scope_conditions(project_id: UUID, agent_instance_id: UUID | None) -> list[Any]:
    """Build the project (and optional agent-instance) filter shared by retrievers."""
    conditions: list[Any] = [EpisodeModel.project_id == project_id]
    if agent_instance_id is not None:
        # Used for both the page query and the count query, so total_count
        # reflects the same agent scoping as the returned items.
        conditions.append(EpisodeModel.agent_instance_id == agent_instance_id)
    return conditions


async def _fetch_episodes(
    session: AsyncSession,
    conditions: list[Any],
    order_by: Any,
    limit: int,
) -> tuple[list[Episode], int]:
    """Fetch one ordered page of episodes plus the total number of matches.

    Args:
        session: Database session used for both queries.
        conditions: WHERE conditions, combined with AND (must be non-empty).
        order_by: ORDER BY expression for the page query.
        limit: Maximum number of episodes in the page.

    Returns:
        ``(episodes, total_count)``, where ``total_count`` counts every row
        matching ``conditions``, not just the returned page.
    """
    # Local import: ``func`` is only needed here.
    from sqlalchemy import func

    where_clause = and_(*conditions)

    page_query = (
        select(EpisodeModel).where(where_clause).order_by(order_by).limit(limit)
    )
    page_result = await session.execute(page_query)
    models = page_result.scalars().all()

    # COUNT(*) in SQL rather than loading and hydrating every matching row
    # just to take len() of it.
    count_query = select(func.count()).select_from(EpisodeModel).where(where_clause)
    count_result = await session.execute(count_query)
    total_count = int(count_result.scalar_one())

    return [_model_to_episode(m) for m in models], total_count


class RecencyRetriever(BaseRetriever):
    """Retrieves episodes by recency (most recent first)."""

    async def retrieve(
        self,
        session: AsyncSession,
        project_id: UUID,
        limit: int = 10,
        *,
        since: datetime | None = None,
        agent_instance_id: UUID | None = None,
        **kwargs: Any,
    ) -> RetrievalResult[Episode]:
        """Retrieve most recent episodes.

        Args:
            session: Database session.
            project_id: Project whose episodes are searched.
            limit: Maximum number of episodes to return.
            since: If given, only episodes with ``occurred_at >= since``.
            agent_instance_id: If given, only episodes from this agent instance.
            **kwargs: Ignored; accepted to match ``BaseRetriever.retrieve``.

        Returns:
            Episodes ordered by ``occurred_at`` descending. ``total_count`` is
            the number of episodes matching the same filters.
        """
        start_time = time.perf_counter()

        conditions = _scope_conditions(project_id, agent_instance_id)
        if since is not None:
            conditions.append(EpisodeModel.occurred_at >= since)

        items, total_count = await _fetch_episodes(
            session, conditions, desc(EpisodeModel.occurred_at), limit
        )

        latency_ms = (time.perf_counter() - start_time) * 1000
        return RetrievalResult(
            items=items,
            total_count=total_count,
            query="recency",
            retrieval_type=RetrievalStrategy.RECENCY.value,
            latency_ms=latency_ms,
            metadata={"since": since.isoformat() if since else None},
        )


class OutcomeRetriever(BaseRetriever):
    """Retrieves episodes filtered by outcome."""

    async def retrieve(
        self,
        session: AsyncSession,
        project_id: UUID,
        limit: int = 10,
        *,
        outcome: Outcome | None = None,
        agent_instance_id: UUID | None = None,
        **kwargs: Any,
    ) -> RetrievalResult[Episode]:
        """Retrieve episodes by outcome.

        Args:
            session: Database session.
            project_id: Project whose episodes are searched.
            limit: Maximum number of episodes to return.
            outcome: If given, only episodes with this outcome; otherwise all.
            agent_instance_id: If given, only episodes from this agent instance.
            **kwargs: Ignored; accepted to match ``BaseRetriever.retrieve``.

        Returns:
            Matching episodes, most recent first, plus the total match count.
        """
        start_time = time.perf_counter()

        conditions = _scope_conditions(project_id, agent_instance_id)
        if outcome is not None:
            # The service-layer Outcome is mapped to the DB enum by value.
            conditions.append(EpisodeModel.outcome == EpisodeOutcome(outcome.value))

        items, total_count = await _fetch_episodes(
            session, conditions, desc(EpisodeModel.occurred_at), limit
        )

        latency_ms = (time.perf_counter() - start_time) * 1000
        return RetrievalResult(
            items=items,
            total_count=total_count,
            query=f"outcome:{outcome.value if outcome else 'all'}",
            retrieval_type=RetrievalStrategy.OUTCOME.value,
            latency_ms=latency_ms,
            metadata={"outcome": outcome.value if outcome else None},
        )


class TaskTypeRetriever(BaseRetriever):
    """Retrieves episodes filtered by task type."""

    async def retrieve(
        self,
        session: AsyncSession,
        project_id: UUID,
        limit: int = 10,
        *,
        task_type: str | None = None,
        agent_instance_id: UUID | None = None,
        **kwargs: Any,
    ) -> RetrievalResult[Episode]:
        """Retrieve episodes by task type.

        Args:
            session: Database session.
            project_id: Project whose episodes are searched.
            limit: Maximum number of episodes to return.
            task_type: If given, only episodes with exactly this task type.
            agent_instance_id: If given, only episodes from this agent instance.
            **kwargs: Ignored; accepted to match ``BaseRetriever.retrieve``.

        Returns:
            Matching episodes, most recent first, plus the total match count.
        """
        start_time = time.perf_counter()

        conditions = _scope_conditions(project_id, agent_instance_id)
        if task_type is not None:
            conditions.append(EpisodeModel.task_type == task_type)

        items, total_count = await _fetch_episodes(
            session, conditions, desc(EpisodeModel.occurred_at), limit
        )

        latency_ms = (time.perf_counter() - start_time) * 1000
        return RetrievalResult(
            items=items,
            total_count=total_count,
            query=f"task_type:{task_type or 'all'}",
            # No RetrievalStrategy member exists for task-type retrieval.
            retrieval_type="task_type",
            latency_ms=latency_ms,
            metadata={"task_type": task_type},
        )


class ImportanceRetriever(BaseRetriever):
    """Retrieves episodes ranked by importance score."""

    async def retrieve(
        self,
        session: AsyncSession,
        project_id: UUID,
        limit: int = 10,
        *,
        min_importance: float = 0.0,
        agent_instance_id: UUID | None = None,
        **kwargs: Any,
    ) -> RetrievalResult[Episode]:
        """Retrieve episodes by importance.

        Args:
            session: Database session.
            project_id: Project whose episodes are searched.
            limit: Maximum number of episodes to return.
            min_importance: Inclusive lower bound on ``importance_score``.
            agent_instance_id: If given, only episodes from this agent instance.
            **kwargs: Ignored; accepted to match ``BaseRetriever.retrieve``.

        Returns:
            Episodes ordered by ``importance_score`` descending, plus the
            total match count.
        """
        start_time = time.perf_counter()

        conditions = _scope_conditions(project_id, agent_instance_id)
        conditions.append(EpisodeModel.importance_score >= min_importance)

        items, total_count = await _fetch_episodes(
            session, conditions, desc(EpisodeModel.importance_score), limit
        )

        latency_ms = (time.perf_counter() - start_time) * 1000
        return RetrievalResult(
            items=items,
            total_count=total_count,
            query=f"importance>={min_importance}",
            retrieval_type=RetrievalStrategy.IMPORTANCE.value,
            latency_ms=latency_ms,
            metadata={"min_importance": min_importance},
        )


class SemanticRetriever(BaseRetriever):
    """Retrieves episodes by semantic similarity using vector search.

    NOTE: vector search is not implemented yet. Every call currently returns
    recency-ordered episodes, reported with ``retrieval_type="semantic"`` and
    a ``"fallback": "recency"`` metadata entry.
    """

    def __init__(self, embedding_generator: Any | None = None) -> None:
        """Initialize with optional embedding generator.

        Args:
            embedding_generator: Object with an async ``generate(text)`` method.
                It is used to embed ``query_text`` when no ``query_embedding``
                is supplied.
        """
        self._embedding_generator = embedding_generator

    async def _recency_fallback(
        self,
        session: AsyncSession,
        project_id: UUID,
        limit: int,
        *,
        agent_instance_id: UUID | None,
        start_time: float,
        query: str,
        metadata: dict[str, Any],
    ) -> RetrievalResult[Episode]:
        """Serve a semantic request from recency order, labelled as semantic."""
        fallback = await RecencyRetriever().retrieve(
            session, project_id, limit, agent_instance_id=agent_instance_id
        )
        # Latency covers the whole semantic call (including any embedding
        # generation), not just the inner recency query.
        latency_ms = (time.perf_counter() - start_time) * 1000
        return RetrievalResult(
            items=fallback.items,
            total_count=fallback.total_count,
            query=query,
            retrieval_type=RetrievalStrategy.SEMANTIC.value,
            latency_ms=latency_ms,
            metadata=metadata,
        )

    async def retrieve(
        self,
        session: AsyncSession,
        project_id: UUID,
        limit: int = 10,
        *,
        query_text: str | None = None,
        query_embedding: list[float] | None = None,
        agent_instance_id: UUID | None = None,
        **kwargs: Any,
    ) -> RetrievalResult[Episode]:
        """Retrieve episodes by semantic similarity.

        Args:
            session: Database session.
            project_id: Project whose episodes are searched.
            limit: Maximum number of episodes to return.
            query_text: Free-text query; embedded if ``query_embedding`` is None.
            query_embedding: Precomputed query vector.
            agent_instance_id: If given, only episodes from this agent instance.
            **kwargs: Ignored; accepted to match ``BaseRetriever.retrieve``.

        Returns:
            Currently always recency-ordered episodes (see class docstring).
            ``metadata["reason"]`` explains early fallbacks.
        """
        start_time = time.perf_counter()

        if query_embedding is None and query_text is None:
            logger.warning(
                "No query provided for semantic search, falling back to recency"
            )
            return await self._recency_fallback(
                session,
                project_id,
                limit,
                agent_instance_id=agent_instance_id,
                start_time=start_time,
                query="no_query",
                metadata={"fallback": "recency", "reason": "no_query"},
            )

        embedding = query_embedding
        if embedding is None and query_text is not None:
            if self._embedding_generator is None:
                logger.warning("No embedding generator, falling back to recency")
                return await self._recency_fallback(
                    session,
                    project_id,
                    limit,
                    agent_instance_id=agent_instance_id,
                    start_time=start_time,
                    query=query_text,
                    metadata={
                        "fallback": "recency",
                        "reason": "no_embedding_generator",
                    },
                )
            # NOTE(review): the embedding is generated but not used until
            # vector search lands; consider skipping this call to save cost.
            embedding = await self._embedding_generator.generate(query_text)

        # For now, use recency if vector search not available
        # TODO: Implement proper pgvector similarity search when integrated
        logger.debug("Vector search not yet implemented, using recency fallback")
        return await self._recency_fallback(
            session,
            project_id,
            limit,
            agent_instance_id=agent_instance_id,
            start_time=start_time,
            query=query_text or "embedding",
            metadata={"fallback": "recency"},
        )


class EpisodeRetriever:
    """
    Unified episode retrieval service.

    Provides a single interface for all retrieval strategies.
    """

    def __init__(
        self,
        session: AsyncSession,
        embedding_generator: Any | None = None,
    ) -> None:
        """Initialize retriever with database session.

        Args:
            session: Async database session shared by every strategy.
            embedding_generator: Forwarded to ``SemanticRetriever``.
        """
        self._session = session
        self._retrievers: dict[RetrievalStrategy, BaseRetriever] = {
            RetrievalStrategy.RECENCY: RecencyRetriever(),
            RetrievalStrategy.OUTCOME: OutcomeRetriever(),
            RetrievalStrategy.IMPORTANCE: ImportanceRetriever(),
            RetrievalStrategy.SEMANTIC: SemanticRetriever(embedding_generator),
        }
        # Task-type retrieval has no RetrievalStrategy member, so it is only
        # reachable through get_by_task_type(). The retriever holds no state,
        # so a single instance is reused.
        self._task_type_retriever = TaskTypeRetriever()

    async def retrieve(
        self,
        project_id: UUID,
        strategy: RetrievalStrategy = RetrievalStrategy.RECENCY,
        limit: int = 10,
        **kwargs: Any,
    ) -> RetrievalResult[Episode]:
        """
        Retrieve episodes using the specified strategy.

        Args:
            project_id: Project to search within
            strategy: Retrieval strategy to use
            limit: Maximum number of episodes to return
            **kwargs: Strategy-specific parameters

        Returns:
            RetrievalResult containing matching episodes

        Raises:
            ValueError: If no retriever is registered for ``strategy``
                (currently ``RetrievalStrategy.HYBRID``).
        """
        retriever = self._retrievers.get(strategy)
        if retriever is None:
            raise ValueError(f"Unknown retrieval strategy: {strategy}")
        return await retriever.retrieve(self._session, project_id, limit, **kwargs)

    async def get_recent(
        self,
        project_id: UUID,
        limit: int = 10,
        since: datetime | None = None,
        agent_instance_id: UUID | None = None,
    ) -> RetrievalResult[Episode]:
        """Get recent episodes, optionally only those at or after ``since``."""
        return await self.retrieve(
            project_id,
            RetrievalStrategy.RECENCY,
            limit,
            since=since,
            agent_instance_id=agent_instance_id,
        )

    async def get_by_outcome(
        self,
        project_id: UUID,
        outcome: Outcome,
        limit: int = 10,
        agent_instance_id: UUID | None = None,
    ) -> RetrievalResult[Episode]:
        """Get episodes with the given outcome, most recent first."""
        return await self.retrieve(
            project_id,
            RetrievalStrategy.OUTCOME,
            limit,
            outcome=outcome,
            agent_instance_id=agent_instance_id,
        )

    async def get_by_task_type(
        self,
        project_id: UUID,
        task_type: str,
        limit: int = 10,
        agent_instance_id: UUID | None = None,
    ) -> RetrievalResult[Episode]:
        """Get episodes with the given task type, most recent first."""
        return await self._task_type_retriever.retrieve(
            self._session,
            project_id,
            limit,
            task_type=task_type,
            agent_instance_id=agent_instance_id,
        )

    async def get_important(
        self,
        project_id: UUID,
        limit: int = 10,
        min_importance: float = 0.7,
        agent_instance_id: UUID | None = None,
    ) -> RetrievalResult[Episode]:
        """Get episodes with ``importance_score >= min_importance``, highest first."""
        return await self.retrieve(
            project_id,
            RetrievalStrategy.IMPORTANCE,
            limit,
            min_importance=min_importance,
            agent_instance_id=agent_instance_id,
        )

    async def search_similar(
        self,
        project_id: UUID,
        query: str,
        limit: int = 10,
        agent_instance_id: UUID | None = None,
    ) -> RetrievalResult[Episode]:
        """Search for semantically similar episodes (currently recency fallback)."""
        return await self.retrieve(
            project_id,
            RetrievalStrategy.SEMANTIC,
            limit,
            query_text=query,
            agent_instance_id=agent_instance_id,
        )