# app/services/memory/consolidation/service.py """ Memory Consolidation Service. Transfers and extracts knowledge between memory tiers: - Working -> Episodic (session end) - Episodic -> Semantic (learn facts) - Episodic -> Procedural (learn procedures) Also handles memory pruning and importance-based retention. """ import logging from dataclasses import dataclass, field from datetime import UTC, datetime, timedelta from typing import Any from uuid import UUID from sqlalchemy.ext.asyncio import AsyncSession from app.services.memory.episodic.memory import EpisodicMemory from app.services.memory.procedural.memory import ProceduralMemory from app.services.memory.semantic.extraction import FactExtractor, get_fact_extractor from app.services.memory.semantic.memory import SemanticMemory from app.services.memory.types import ( Episode, EpisodeCreate, Outcome, ProcedureCreate, TaskState, ) from app.services.memory.working.memory import WorkingMemory logger = logging.getLogger(__name__) @dataclass class ConsolidationConfig: """Configuration for memory consolidation.""" # Working -> Episodic thresholds min_steps_for_episode: int = 2 min_duration_seconds: float = 5.0 # Episodic -> Semantic thresholds min_confidence_for_fact: float = 0.6 max_facts_per_episode: int = 10 reinforce_existing_facts: bool = True # Episodic -> Procedural thresholds min_episodes_for_procedure: int = 3 min_success_rate_for_procedure: float = 0.7 min_steps_for_procedure: int = 2 # Pruning thresholds max_episode_age_days: int = 90 min_importance_to_keep: float = 0.2 keep_all_failures: bool = True keep_all_with_lessons: bool = True # Batch sizes batch_size: int = 100 @dataclass class ConsolidationResult: """Result of a consolidation operation.""" source_type: str target_type: str items_processed: int = 0 items_created: int = 0 items_updated: int = 0 items_skipped: int = 0 items_pruned: int = 0 errors: list[str] = field(default_factory=list) duration_seconds: float = 0.0 def to_dict(self) -> dict[str, Any]: """Convert to dictionary.""" return { "source_type": self.source_type, "target_type": self.target_type, "items_processed": self.items_processed, "items_created": self.items_created, "items_updated": self.items_updated, "items_skipped": self.items_skipped, "items_pruned": self.items_pruned, "errors": self.errors, "duration_seconds": self.duration_seconds, } @dataclass class SessionConsolidationResult: """Result of consolidating a session's working memory to episodic.""" session_id: str episode_created: bool = False episode_id: UUID | None = None scratchpad_entries: int = 0 variables_captured: int = 0 error: str | None = None @dataclass class NightlyConsolidationResult: """Result of nightly consolidation run.""" started_at: datetime completed_at: datetime | None = None episodic_to_semantic: ConsolidationResult | None = None episodic_to_procedural: ConsolidationResult | None = None pruning: ConsolidationResult | None = None total_episodes_processed: int = 0 total_facts_created: int = 0 total_procedures_created: int = 0 total_pruned: int = 0 errors: list[str] = field(default_factory=list) def to_dict(self) -> dict[str, Any]: """Convert to dictionary.""" return { "started_at": self.started_at.isoformat(), "completed_at": self.completed_at.isoformat() if self.completed_at else None, "episodic_to_semantic": ( self.episodic_to_semantic.to_dict() if self.episodic_to_semantic else None ), "episodic_to_procedural": ( self.episodic_to_procedural.to_dict() if self.episodic_to_procedural else None ), "pruning": self.pruning.to_dict() if self.pruning else None, "total_episodes_processed": self.total_episodes_processed, "total_facts_created": self.total_facts_created, "total_procedures_created": self.total_procedures_created, "total_pruned": self.total_pruned, "errors": self.errors, } class MemoryConsolidationService: """ Service for consolidating memories between tiers. Responsibilities: - Transfer working memory to episodic at session end - Extract facts from episodes to semantic memory - Learn procedures from successful episode patterns - Prune old/low-value memories """ def __init__( self, session: AsyncSession, config: ConsolidationConfig | None = None, embedding_generator: Any | None = None, ) -> None: """ Initialize consolidation service. Args: session: Database session config: Consolidation configuration embedding_generator: Optional embedding generator """ self._session = session self._config = config or ConsolidationConfig() self._embedding_generator = embedding_generator self._fact_extractor: FactExtractor = get_fact_extractor() # Memory services (lazy initialized) self._episodic: EpisodicMemory | None = None self._semantic: SemanticMemory | None = None self._procedural: ProceduralMemory | None = None async def _get_episodic(self) -> EpisodicMemory: """Get or create episodic memory service.""" if self._episodic is None: self._episodic = await EpisodicMemory.create( self._session, self._embedding_generator ) return self._episodic async def _get_semantic(self) -> SemanticMemory: """Get or create semantic memory service.""" if self._semantic is None: self._semantic = await SemanticMemory.create( self._session, self._embedding_generator ) return self._semantic async def _get_procedural(self) -> ProceduralMemory: """Get or create procedural memory service.""" if self._procedural is None: self._procedural = await ProceduralMemory.create( self._session, self._embedding_generator ) return self._procedural # ========================================================================= # Working -> Episodic Consolidation # ========================================================================= async def consolidate_session( self, working_memory: WorkingMemory, project_id: UUID, session_id: str, task_type: str = "session_task", agent_instance_id: UUID | None = None, agent_type_id: UUID | None = None, ) -> SessionConsolidationResult: """ Consolidate a session's working memory to episodic memory. Called at session end to transfer relevant session data into a persistent episode. Args: working_memory: The session's working memory project_id: Project ID session_id: Session ID task_type: Type of task performed agent_instance_id: Optional agent instance agent_type_id: Optional agent type Returns: SessionConsolidationResult with outcome details """ result = SessionConsolidationResult(session_id=session_id) try: # Get task state task_state = await working_memory.get_task_state() # Check if there's enough content to consolidate if not self._should_consolidate_session(task_state): logger.debug( f"Skipping consolidation for session {session_id}: insufficient content" ) return result # Gather scratchpad entries scratchpad = await working_memory.get_scratchpad() result.scratchpad_entries = len(scratchpad) # Gather user variables all_data = await working_memory.get_all() result.variables_captured = len(all_data) # Determine outcome outcome = self._determine_session_outcome(task_state) # Build actions from scratchpad and variables actions = self._build_actions_from_session(scratchpad, all_data, task_state) # Build context summary context_summary = self._build_context_summary(task_state, all_data) # Extract lessons learned lessons = self._extract_session_lessons(task_state, outcome) # Calculate importance importance = self._calculate_session_importance( task_state, outcome, actions ) # Create episode episode_data = EpisodeCreate( project_id=project_id, session_id=session_id, task_type=task_type, task_description=task_state.description if task_state else "Session task", actions=actions, context_summary=context_summary, outcome=outcome, outcome_details=task_state.status if task_state else "", duration_seconds=self._calculate_duration(task_state), tokens_used=0, # Would need to track this in working memory lessons_learned=lessons, importance_score=importance, agent_instance_id=agent_instance_id, agent_type_id=agent_type_id, ) episodic = await self._get_episodic() episode = await episodic.record_episode(episode_data) result.episode_created = True result.episode_id = episode.id logger.info( f"Consolidated session {session_id} to episode {episode.id} " f"({len(actions)} actions, outcome={outcome.value})" ) except Exception as e: result.error = str(e) logger.exception(f"Failed to consolidate session {session_id}") return result def _should_consolidate_session(self, task_state: TaskState | None) -> bool: """Check if session has enough content to consolidate.""" if task_state is None: return False # Check minimum steps if task_state.current_step < self._config.min_steps_for_episode: return False return True def _determine_session_outcome(self, task_state: TaskState | None) -> Outcome: """Determine outcome from task state.""" if task_state is None: return Outcome.PARTIAL status = task_state.status.lower() if task_state.status else "" progress = task_state.progress_percent if "success" in status or "complete" in status or progress >= 100: return Outcome.SUCCESS if "fail" in status or "error" in status: return Outcome.FAILURE if progress >= 50: return Outcome.PARTIAL return Outcome.FAILURE def _build_actions_from_session( self, scratchpad: list[str], variables: dict[str, Any], task_state: TaskState | None, ) -> list[dict[str, Any]]: """Build action list from session data.""" actions: list[dict[str, Any]] = [] # Add scratchpad entries as actions for i, entry in enumerate(scratchpad): actions.append( { "step": i + 1, "type": "reasoning", "content": entry[:500], # Truncate long entries } ) # Add final state if task_state: actions.append( { "step": len(scratchpad) + 1, "type": "final_state", "current_step": task_state.current_step, "total_steps": task_state.total_steps, "progress": task_state.progress_percent, "status": task_state.status, } ) return actions def _build_context_summary( self, task_state: TaskState | None, variables: dict[str, Any], ) -> str: """Build context summary from session data.""" parts = [] if task_state: parts.append(f"Task: {task_state.description}") parts.append(f"Progress: {task_state.progress_percent:.1f}%") parts.append(f"Steps: {task_state.current_step}/{task_state.total_steps}") # Include key variables key_vars = {k: v for k, v in variables.items() if len(str(v)) < 100} if key_vars: var_str = ", ".join(f"{k}={v}" for k, v in list(key_vars.items())[:5]) parts.append(f"Variables: {var_str}") return "; ".join(parts) if parts else "Session completed" def _extract_session_lessons( self, task_state: TaskState | None, outcome: Outcome, ) -> list[str]: """Extract lessons from session.""" lessons: list[str] = [] if task_state and task_state.status: if outcome == Outcome.FAILURE: lessons.append( f"Task failed at step {task_state.current_step}: {task_state.status}" ) elif outcome == Outcome.SUCCESS: lessons.append( f"Successfully completed in {task_state.current_step} steps" ) return lessons def _calculate_session_importance( self, task_state: TaskState | None, outcome: Outcome, actions: list[dict[str, Any]], ) -> float: """Calculate importance score for session.""" score = 0.5 # Base score # Failures are important to learn from if outcome == Outcome.FAILURE: score += 0.3 # Many steps means complex task if task_state and task_state.total_steps >= 5: score += 0.1 # Many actions means detailed reasoning if len(actions) >= 5: score += 0.1 return min(1.0, score) def _calculate_duration(self, task_state: TaskState | None) -> float: """Calculate session duration.""" if task_state is None: return 0.0 if task_state.started_at and task_state.updated_at: delta = task_state.updated_at - task_state.started_at return delta.total_seconds() return 0.0 # ========================================================================= # Episodic -> Semantic Consolidation # ========================================================================= async def consolidate_episodes_to_facts( self, project_id: UUID, since: datetime | None = None, limit: int | None = None, ) -> ConsolidationResult: """ Extract facts from episodic memories to semantic memory. Args: project_id: Project to consolidate since: Only process episodes since this time limit: Maximum episodes to process Returns: ConsolidationResult with extraction statistics """ start_time = datetime.now(UTC) result = ConsolidationResult( source_type="episodic", target_type="semantic", ) try: episodic = await self._get_episodic() semantic = await self._get_semantic() # Get episodes to process since_time = since or datetime.now(UTC) - timedelta(days=1) episodes = await episodic.get_recent( project_id, limit=limit or self._config.batch_size, since=since_time, ) for episode in episodes: result.items_processed += 1 try: # Extract facts using the extractor extracted_facts = self._fact_extractor.extract_from_episode(episode) for extracted_fact in extracted_facts: if ( extracted_fact.confidence < self._config.min_confidence_for_fact ): result.items_skipped += 1 continue # Create fact (store_fact handles deduplication/reinforcement) fact_create = extracted_fact.to_fact_create( project_id=project_id, source_episode_ids=[episode.id], ) # store_fact automatically reinforces if fact already exists fact = await semantic.store_fact(fact_create) # Check if this was a new fact or reinforced existing if fact.reinforcement_count == 1: result.items_created += 1 else: result.items_updated += 1 except Exception as e: result.errors.append(f"Episode {episode.id}: {e}") logger.warning( f"Failed to extract facts from episode {episode.id}: {e}" ) except Exception as e: result.errors.append(f"Consolidation failed: {e}") logger.exception("Failed episodic -> semantic consolidation") result.duration_seconds = (datetime.now(UTC) - start_time).total_seconds() logger.info( f"Episodic -> Semantic consolidation: " f"{result.items_processed} processed, " f"{result.items_created} created, " f"{result.items_updated} reinforced" ) return result # ========================================================================= # Episodic -> Procedural Consolidation # ========================================================================= async def consolidate_episodes_to_procedures( self, project_id: UUID, agent_type_id: UUID | None = None, since: datetime | None = None, ) -> ConsolidationResult: """ Learn procedures from patterns in episodic memories. Identifies recurring successful patterns and creates/updates procedures to capture them. Args: project_id: Project to consolidate agent_type_id: Optional filter by agent type since: Only process episodes since this time Returns: ConsolidationResult with procedure statistics """ start_time = datetime.now(UTC) result = ConsolidationResult( source_type="episodic", target_type="procedural", ) try: episodic = await self._get_episodic() procedural = await self._get_procedural() # Get successful episodes since_time = since or datetime.now(UTC) - timedelta(days=7) episodes = await episodic.get_by_outcome( project_id, outcome=Outcome.SUCCESS, limit=self._config.batch_size, agent_instance_id=None, # Get all agent instances ) # Group by task type task_groups: dict[str, list[Episode]] = {} for episode in episodes: if episode.occurred_at >= since_time: if episode.task_type not in task_groups: task_groups[episode.task_type] = [] task_groups[episode.task_type].append(episode) result.items_processed = len(episodes) # Process each task type group for task_type, group in task_groups.items(): if len(group) < self._config.min_episodes_for_procedure: result.items_skipped += len(group) continue try: procedure_result = await self._learn_procedure_from_episodes( procedural, project_id, agent_type_id, task_type, group, ) if procedure_result == "created": result.items_created += 1 elif procedure_result == "updated": result.items_updated += 1 else: result.items_skipped += 1 except Exception as e: result.errors.append(f"Task type '{task_type}': {e}") logger.warning(f"Failed to learn procedure for '{task_type}': {e}") except Exception as e: result.errors.append(f"Consolidation failed: {e}") logger.exception("Failed episodic -> procedural consolidation") result.duration_seconds = (datetime.now(UTC) - start_time).total_seconds() logger.info( f"Episodic -> Procedural consolidation: " f"{result.items_processed} processed, " f"{result.items_created} created, " f"{result.items_updated} updated" ) return result async def _learn_procedure_from_episodes( self, procedural: ProceduralMemory, project_id: UUID, agent_type_id: UUID | None, task_type: str, episodes: list[Episode], ) -> str: """Learn or update a procedure from a set of episodes.""" # Calculate success rate for this pattern success_count = sum(1 for e in episodes if e.outcome == Outcome.SUCCESS) total_count = len(episodes) success_rate = success_count / total_count if total_count > 0 else 0 if success_rate < self._config.min_success_rate_for_procedure: return "skipped" # Extract common steps from episodes steps = self._extract_common_steps(episodes) if len(steps) < self._config.min_steps_for_procedure: return "skipped" # Check for existing procedure matching = await procedural.find_matching( context=task_type, project_id=project_id, agent_type_id=agent_type_id, limit=1, ) if matching: # Update existing procedure with new success await procedural.record_outcome( matching[0].id, success=True, ) return "updated" else: # Create new procedure # Note: success_count starts at 1 in record_procedure procedure_data = ProcedureCreate( project_id=project_id, agent_type_id=agent_type_id, name=f"Procedure for {task_type}", trigger_pattern=task_type, steps=steps, ) await procedural.record_procedure(procedure_data) return "created" def _extract_common_steps(self, episodes: list[Episode]) -> list[dict[str, Any]]: """Extract common action steps from multiple episodes.""" # Simple heuristic: take the steps from the most successful episode # with the most detailed actions best_episode = max( episodes, key=lambda e: ( e.outcome == Outcome.SUCCESS, len(e.actions), e.importance_score, ), ) steps: list[dict[str, Any]] = [] for i, action in enumerate(best_episode.actions): step = { "order": i + 1, "action": action.get("type", "action"), "description": action.get("content", str(action))[:500], "parameters": action, } steps.append(step) return steps # ========================================================================= # Memory Pruning # ========================================================================= async def prune_old_episodes( self, project_id: UUID, max_age_days: int | None = None, min_importance: float | None = None, ) -> ConsolidationResult: """ Prune old, low-value episodes. Args: project_id: Project to prune max_age_days: Maximum age in days (default from config) min_importance: Minimum importance to keep (default from config) Returns: ConsolidationResult with pruning statistics """ start_time = datetime.now(UTC) result = ConsolidationResult( source_type="episodic", target_type="pruned", ) max_age = max_age_days or self._config.max_episode_age_days min_imp = min_importance or self._config.min_importance_to_keep cutoff_date = datetime.now(UTC) - timedelta(days=max_age) try: episodic = await self._get_episodic() # Get old episodes # Note: In production, this would use a more efficient query all_episodes = await episodic.get_recent( project_id, limit=self._config.batch_size * 10, since=cutoff_date - timedelta(days=365), # Search past year ) for episode in all_episodes: result.items_processed += 1 # Check if should be pruned if not self._should_prune_episode(episode, cutoff_date, min_imp): result.items_skipped += 1 continue try: deleted = await episodic.delete(episode.id) if deleted: result.items_pruned += 1 else: result.items_skipped += 1 except Exception as e: result.errors.append(f"Episode {episode.id}: {e}") except Exception as e: result.errors.append(f"Pruning failed: {e}") logger.exception("Failed episode pruning") result.duration_seconds = (datetime.now(UTC) - start_time).total_seconds() logger.info( f"Episode pruning: {result.items_processed} processed, " f"{result.items_pruned} pruned" ) return result def _should_prune_episode( self, episode: Episode, cutoff_date: datetime, min_importance: float, ) -> bool: """Determine if an episode should be pruned.""" # Keep recent episodes if episode.occurred_at >= cutoff_date: return False # Keep failures if configured if self._config.keep_all_failures and episode.outcome == Outcome.FAILURE: return False # Keep episodes with lessons if configured if self._config.keep_all_with_lessons and episode.lessons_learned: return False # Keep high-importance episodes if episode.importance_score >= min_importance: return False return True # ========================================================================= # Nightly Consolidation # ========================================================================= async def run_nightly_consolidation( self, project_id: UUID, agent_type_id: UUID | None = None, ) -> NightlyConsolidationResult: """ Run full nightly consolidation workflow. This includes: 1. Extract facts from recent episodes 2. Learn procedures from successful patterns 3. Prune old, low-value memories Args: project_id: Project to consolidate agent_type_id: Optional agent type filter Returns: NightlyConsolidationResult with all outcomes """ result = NightlyConsolidationResult(started_at=datetime.now(UTC)) logger.info(f"Starting nightly consolidation for project {project_id}") try: # Step 1: Episodic -> Semantic (last 24 hours) since_yesterday = datetime.now(UTC) - timedelta(days=1) result.episodic_to_semantic = await self.consolidate_episodes_to_facts( project_id=project_id, since=since_yesterday, ) result.total_facts_created = result.episodic_to_semantic.items_created # Step 2: Episodic -> Procedural (last 7 days) since_week = datetime.now(UTC) - timedelta(days=7) result.episodic_to_procedural = ( await self.consolidate_episodes_to_procedures( project_id=project_id, agent_type_id=agent_type_id, since=since_week, ) ) result.total_procedures_created = ( result.episodic_to_procedural.items_created ) # Step 3: Prune old memories result.pruning = await self.prune_old_episodes(project_id=project_id) result.total_pruned = result.pruning.items_pruned # Calculate totals result.total_episodes_processed = ( result.episodic_to_semantic.items_processed if result.episodic_to_semantic else 0 ) + ( result.episodic_to_procedural.items_processed if result.episodic_to_procedural else 0 ) # Collect all errors if result.episodic_to_semantic and result.episodic_to_semantic.errors: result.errors.extend(result.episodic_to_semantic.errors) if result.episodic_to_procedural and result.episodic_to_procedural.errors: result.errors.extend(result.episodic_to_procedural.errors) if result.pruning and result.pruning.errors: result.errors.extend(result.pruning.errors) except Exception as e: result.errors.append(f"Nightly consolidation failed: {e}") logger.exception("Nightly consolidation failed") result.completed_at = datetime.now(UTC) duration = (result.completed_at - result.started_at).total_seconds() logger.info( f"Nightly consolidation completed in {duration:.1f}s: " f"{result.total_facts_created} facts, " f"{result.total_procedures_created} procedures, " f"{result.total_pruned} pruned" ) return result # Singleton instance _consolidation_service: MemoryConsolidationService | None = None async def get_consolidation_service( session: AsyncSession, config: ConsolidationConfig | None = None, ) -> MemoryConsolidationService: """ Get or create the memory consolidation service. Args: session: Database session config: Optional configuration Returns: MemoryConsolidationService instance """ global _consolidation_service if _consolidation_service is None: _consolidation_service = MemoryConsolidationService( session=session, config=config ) return _consolidation_service