diff --git a/backend/app/services/memory/consolidation/__init__.py b/backend/app/services/memory/consolidation/__init__.py index 0e323d8..7b44f0a 100644 --- a/backend/app/services/memory/consolidation/__init__.py +++ b/backend/app/services/memory/consolidation/__init__.py @@ -1,10 +1,29 @@ +# app/services/memory/consolidation/__init__.py """ -Memory Consolidation +Memory Consolidation. Transfers and extracts knowledge between memory tiers: -- Working -> Episodic -- Episodic -> Semantic -- Episodic -> Procedural +- Working -> Episodic (session end) +- Episodic -> Semantic (learn facts) +- Episodic -> Procedural (learn procedures) + +Also handles memory pruning and importance-based retention. """ -# Will be populated in #95 +from .service import ( + ConsolidationConfig, + ConsolidationResult, + MemoryConsolidationService, + NightlyConsolidationResult, + SessionConsolidationResult, + get_consolidation_service, +) + +__all__ = [ + "ConsolidationConfig", + "ConsolidationResult", + "MemoryConsolidationService", + "NightlyConsolidationResult", + "SessionConsolidationResult", + "get_consolidation_service", +] diff --git a/backend/app/services/memory/consolidation/service.py b/backend/app/services/memory/consolidation/service.py new file mode 100644 index 0000000..58615e7 --- /dev/null +++ b/backend/app/services/memory/consolidation/service.py @@ -0,0 +1,918 @@ +# app/services/memory/consolidation/service.py +""" +Memory Consolidation Service. + +Transfers and extracts knowledge between memory tiers: +- Working -> Episodic (session end) +- Episodic -> Semantic (learn facts) +- Episodic -> Procedural (learn procedures) + +Also handles memory pruning and importance-based retention. +""" + +import logging +from dataclasses import dataclass, field +from datetime import UTC, datetime, timedelta +from typing import Any +from uuid import UUID + +from sqlalchemy.ext.asyncio import AsyncSession + +from app.services.memory.episodic.memory import EpisodicMemory +from app.services.memory.procedural.memory import ProceduralMemory +from app.services.memory.semantic.extraction import FactExtractor, get_fact_extractor +from app.services.memory.semantic.memory import SemanticMemory +from app.services.memory.types import ( + Episode, + EpisodeCreate, + Outcome, + ProcedureCreate, + TaskState, +) +from app.services.memory.working.memory import WorkingMemory + +logger = logging.getLogger(__name__) + + +@dataclass +class ConsolidationConfig: + """Configuration for memory consolidation.""" + + # Working -> Episodic thresholds + min_steps_for_episode: int = 2 + min_duration_seconds: float = 5.0 + + # Episodic -> Semantic thresholds + min_confidence_for_fact: float = 0.6 + max_facts_per_episode: int = 10 + reinforce_existing_facts: bool = True + + # Episodic -> Procedural thresholds + min_episodes_for_procedure: int = 3 + min_success_rate_for_procedure: float = 0.7 + min_steps_for_procedure: int = 2 + + # Pruning thresholds + max_episode_age_days: int = 90 + min_importance_to_keep: float = 0.2 + keep_all_failures: bool = True + keep_all_with_lessons: bool = True + + # Batch sizes + batch_size: int = 100 + + +@dataclass +class ConsolidationResult: + """Result of a consolidation operation.""" + + source_type: str + target_type: str + items_processed: int = 0 + items_created: int = 0 + items_updated: int = 0 + items_skipped: int = 0 + items_pruned: int = 0 + errors: list[str] = field(default_factory=list) + duration_seconds: float = 0.0 + + def to_dict(self) -> dict[str, Any]: + """Convert to dictionary.""" + return { + 
"source_type": self.source_type, + "target_type": self.target_type, + "items_processed": self.items_processed, + "items_created": self.items_created, + "items_updated": self.items_updated, + "items_skipped": self.items_skipped, + "items_pruned": self.items_pruned, + "errors": self.errors, + "duration_seconds": self.duration_seconds, + } + + +@dataclass +class SessionConsolidationResult: + """Result of consolidating a session's working memory to episodic.""" + + session_id: str + episode_created: bool = False + episode_id: UUID | None = None + scratchpad_entries: int = 0 + variables_captured: int = 0 + error: str | None = None + + +@dataclass +class NightlyConsolidationResult: + """Result of nightly consolidation run.""" + + started_at: datetime + completed_at: datetime | None = None + episodic_to_semantic: ConsolidationResult | None = None + episodic_to_procedural: ConsolidationResult | None = None + pruning: ConsolidationResult | None = None + total_episodes_processed: int = 0 + total_facts_created: int = 0 + total_procedures_created: int = 0 + total_pruned: int = 0 + errors: list[str] = field(default_factory=list) + + def to_dict(self) -> dict[str, Any]: + """Convert to dictionary.""" + return { + "started_at": self.started_at.isoformat(), + "completed_at": self.completed_at.isoformat() + if self.completed_at + else None, + "episodic_to_semantic": ( + self.episodic_to_semantic.to_dict() + if self.episodic_to_semantic + else None + ), + "episodic_to_procedural": ( + self.episodic_to_procedural.to_dict() + if self.episodic_to_procedural + else None + ), + "pruning": self.pruning.to_dict() if self.pruning else None, + "total_episodes_processed": self.total_episodes_processed, + "total_facts_created": self.total_facts_created, + "total_procedures_created": self.total_procedures_created, + "total_pruned": self.total_pruned, + "errors": self.errors, + } + + +class MemoryConsolidationService: + """ + Service for consolidating memories between tiers. + + Responsibilities: + - Transfer working memory to episodic at session end + - Extract facts from episodes to semantic memory + - Learn procedures from successful episode patterns + - Prune old/low-value memories + """ + + def __init__( + self, + session: AsyncSession, + config: ConsolidationConfig | None = None, + embedding_generator: Any | None = None, + ) -> None: + """ + Initialize consolidation service. 
+ + Args: + session: Database session + config: Consolidation configuration + embedding_generator: Optional embedding generator + """ + self._session = session + self._config = config or ConsolidationConfig() + self._embedding_generator = embedding_generator + self._fact_extractor: FactExtractor = get_fact_extractor() + + # Memory services (lazy initialized) + self._episodic: EpisodicMemory | None = None + self._semantic: SemanticMemory | None = None + self._procedural: ProceduralMemory | None = None + + async def _get_episodic(self) -> EpisodicMemory: + """Get or create episodic memory service.""" + if self._episodic is None: + self._episodic = await EpisodicMemory.create( + self._session, self._embedding_generator + ) + return self._episodic + + async def _get_semantic(self) -> SemanticMemory: + """Get or create semantic memory service.""" + if self._semantic is None: + self._semantic = await SemanticMemory.create( + self._session, self._embedding_generator + ) + return self._semantic + + async def _get_procedural(self) -> ProceduralMemory: + """Get or create procedural memory service.""" + if self._procedural is None: + self._procedural = await ProceduralMemory.create( + self._session, self._embedding_generator + ) + return self._procedural + + # ========================================================================= + # Working -> Episodic Consolidation + # ========================================================================= + + async def consolidate_session( + self, + working_memory: WorkingMemory, + project_id: UUID, + session_id: str, + task_type: str = "session_task", + agent_instance_id: UUID | None = None, + agent_type_id: UUID | None = None, + ) -> SessionConsolidationResult: + """ + Consolidate a session's working memory to episodic memory. + + Called at session end to transfer relevant session data + into a persistent episode. 
+ + Args: + working_memory: The session's working memory + project_id: Project ID + session_id: Session ID + task_type: Type of task performed + agent_instance_id: Optional agent instance + agent_type_id: Optional agent type + + Returns: + SessionConsolidationResult with outcome details + """ + result = SessionConsolidationResult(session_id=session_id) + + try: + # Get task state + task_state = await working_memory.get_task_state() + + # Check if there's enough content to consolidate + if not self._should_consolidate_session(task_state): + logger.debug( + f"Skipping consolidation for session {session_id}: insufficient content" + ) + return result + + # Gather scratchpad entries + scratchpad = await working_memory.get_scratchpad() + result.scratchpad_entries = len(scratchpad) + + # Gather user variables + all_data = await working_memory.get_all() + result.variables_captured = len(all_data) + + # Determine outcome + outcome = self._determine_session_outcome(task_state) + + # Build actions from scratchpad and variables + actions = self._build_actions_from_session(scratchpad, all_data, task_state) + + # Build context summary + context_summary = self._build_context_summary(task_state, all_data) + + # Extract lessons learned + lessons = self._extract_session_lessons(task_state, outcome) + + # Calculate importance + importance = self._calculate_session_importance( + task_state, outcome, actions + ) + + # Create episode + episode_data = EpisodeCreate( + project_id=project_id, + session_id=session_id, + task_type=task_type, + task_description=task_state.description + if task_state + else "Session task", + actions=actions, + context_summary=context_summary, + outcome=outcome, + outcome_details=task_state.status if task_state else "", + duration_seconds=self._calculate_duration(task_state), + tokens_used=0, # Would need to track this in working memory + lessons_learned=lessons, + importance_score=importance, + agent_instance_id=agent_instance_id, + agent_type_id=agent_type_id, + ) + + episodic = await self._get_episodic() + episode = await episodic.record_episode(episode_data) + + result.episode_created = True + result.episode_id = episode.id + + logger.info( + f"Consolidated session {session_id} to episode {episode.id} " + f"({len(actions)} actions, outcome={outcome.value})" + ) + + except Exception as e: + result.error = str(e) + logger.exception(f"Failed to consolidate session {session_id}") + + return result + + def _should_consolidate_session(self, task_state: TaskState | None) -> bool: + """Check if session has enough content to consolidate.""" + if task_state is None: + return False + + # Check minimum steps + if task_state.current_step < self._config.min_steps_for_episode: + return False + + return True + + def _determine_session_outcome(self, task_state: TaskState | None) -> Outcome: + """Determine outcome from task state.""" + if task_state is None: + return Outcome.PARTIAL + + status = task_state.status.lower() if task_state.status else "" + progress = task_state.progress_percent + + if "success" in status or "complete" in status or progress >= 100: + return Outcome.SUCCESS + if "fail" in status or "error" in status: + return Outcome.FAILURE + if progress >= 50: + return Outcome.PARTIAL + return Outcome.FAILURE + + def _build_actions_from_session( + self, + scratchpad: list[str], + variables: dict[str, Any], + task_state: TaskState | None, + ) -> list[dict[str, Any]]: + """Build action list from session data.""" + actions: list[dict[str, Any]] = [] + + # Add scratchpad entries as actions + 
for i, entry in enumerate(scratchpad): + actions.append( + { + "step": i + 1, + "type": "reasoning", + "content": entry[:500], # Truncate long entries + } + ) + + # Add final state + if task_state: + actions.append( + { + "step": len(scratchpad) + 1, + "type": "final_state", + "current_step": task_state.current_step, + "total_steps": task_state.total_steps, + "progress": task_state.progress_percent, + "status": task_state.status, + } + ) + + return actions + + def _build_context_summary( + self, + task_state: TaskState | None, + variables: dict[str, Any], + ) -> str: + """Build context summary from session data.""" + parts = [] + + if task_state: + parts.append(f"Task: {task_state.description}") + parts.append(f"Progress: {task_state.progress_percent:.1f}%") + parts.append(f"Steps: {task_state.current_step}/{task_state.total_steps}") + + # Include key variables + key_vars = {k: v for k, v in variables.items() if len(str(v)) < 100} + if key_vars: + var_str = ", ".join(f"{k}={v}" for k, v in list(key_vars.items())[:5]) + parts.append(f"Variables: {var_str}") + + return "; ".join(parts) if parts else "Session completed" + + def _extract_session_lessons( + self, + task_state: TaskState | None, + outcome: Outcome, + ) -> list[str]: + """Extract lessons from session.""" + lessons: list[str] = [] + + if task_state and task_state.status: + if outcome == Outcome.FAILURE: + lessons.append( + f"Task failed at step {task_state.current_step}: {task_state.status}" + ) + elif outcome == Outcome.SUCCESS: + lessons.append( + f"Successfully completed in {task_state.current_step} steps" + ) + + return lessons + + def _calculate_session_importance( + self, + task_state: TaskState | None, + outcome: Outcome, + actions: list[dict[str, Any]], + ) -> float: + """Calculate importance score for session.""" + score = 0.5 # Base score + + # Failures are important to learn from + if outcome == Outcome.FAILURE: + score += 0.3 + + # Many steps means complex task + if task_state and task_state.total_steps >= 5: + score += 0.1 + + # Many actions means detailed reasoning + if len(actions) >= 5: + score += 0.1 + + return min(1.0, score) + + def _calculate_duration(self, task_state: TaskState | None) -> float: + """Calculate session duration.""" + if task_state is None: + return 0.0 + + if task_state.started_at and task_state.updated_at: + delta = task_state.updated_at - task_state.started_at + return delta.total_seconds() + + return 0.0 + + # ========================================================================= + # Episodic -> Semantic Consolidation + # ========================================================================= + + async def consolidate_episodes_to_facts( + self, + project_id: UUID, + since: datetime | None = None, + limit: int | None = None, + ) -> ConsolidationResult: + """ + Extract facts from episodic memories to semantic memory. 
+ + Args: + project_id: Project to consolidate + since: Only process episodes since this time + limit: Maximum episodes to process + + Returns: + ConsolidationResult with extraction statistics + """ + start_time = datetime.now(UTC) + result = ConsolidationResult( + source_type="episodic", + target_type="semantic", + ) + + try: + episodic = await self._get_episodic() + semantic = await self._get_semantic() + + # Get episodes to process + since_time = since or datetime.now(UTC) - timedelta(days=1) + episodes = await episodic.get_recent( + project_id, + limit=limit or self._config.batch_size, + since=since_time, + ) + + for episode in episodes: + result.items_processed += 1 + + try: + # Extract facts using the extractor + extracted_facts = self._fact_extractor.extract_from_episode(episode) + + for extracted_fact in extracted_facts: + if ( + extracted_fact.confidence + < self._config.min_confidence_for_fact + ): + result.items_skipped += 1 + continue + + # Create fact (store_fact handles deduplication/reinforcement) + fact_create = extracted_fact.to_fact_create( + project_id=project_id, + source_episode_ids=[episode.id], + ) + + # store_fact automatically reinforces if fact already exists + fact = await semantic.store_fact(fact_create) + + # Check if this was a new fact or reinforced existing + if fact.reinforcement_count == 1: + result.items_created += 1 + else: + result.items_updated += 1 + + except Exception as e: + result.errors.append(f"Episode {episode.id}: {e}") + logger.warning( + f"Failed to extract facts from episode {episode.id}: {e}" + ) + + except Exception as e: + result.errors.append(f"Consolidation failed: {e}") + logger.exception("Failed episodic -> semantic consolidation") + + result.duration_seconds = (datetime.now(UTC) - start_time).total_seconds() + + logger.info( + f"Episodic -> Semantic consolidation: " + f"{result.items_processed} processed, " + f"{result.items_created} created, " + f"{result.items_updated} reinforced" + ) + + return result + + # ========================================================================= + # Episodic -> Procedural Consolidation + # ========================================================================= + + async def consolidate_episodes_to_procedures( + self, + project_id: UUID, + agent_type_id: UUID | None = None, + since: datetime | None = None, + ) -> ConsolidationResult: + """ + Learn procedures from patterns in episodic memories. + + Identifies recurring successful patterns and creates/updates + procedures to capture them. 
+ + Args: + project_id: Project to consolidate + agent_type_id: Optional filter by agent type + since: Only process episodes since this time + + Returns: + ConsolidationResult with procedure statistics + """ + start_time = datetime.now(UTC) + result = ConsolidationResult( + source_type="episodic", + target_type="procedural", + ) + + try: + episodic = await self._get_episodic() + procedural = await self._get_procedural() + + # Get successful episodes + since_time = since or datetime.now(UTC) - timedelta(days=7) + episodes = await episodic.get_by_outcome( + project_id, + outcome=Outcome.SUCCESS, + limit=self._config.batch_size, + agent_instance_id=None, # Get all agent instances + ) + + # Group by task type + task_groups: dict[str, list[Episode]] = {} + for episode in episodes: + if episode.occurred_at >= since_time: + if episode.task_type not in task_groups: + task_groups[episode.task_type] = [] + task_groups[episode.task_type].append(episode) + + result.items_processed = len(episodes) + + # Process each task type group + for task_type, group in task_groups.items(): + if len(group) < self._config.min_episodes_for_procedure: + result.items_skipped += len(group) + continue + + try: + procedure_result = await self._learn_procedure_from_episodes( + procedural, + project_id, + agent_type_id, + task_type, + group, + ) + + if procedure_result == "created": + result.items_created += 1 + elif procedure_result == "updated": + result.items_updated += 1 + else: + result.items_skipped += 1 + + except Exception as e: + result.errors.append(f"Task type '{task_type}': {e}") + logger.warning(f"Failed to learn procedure for '{task_type}': {e}") + + except Exception as e: + result.errors.append(f"Consolidation failed: {e}") + logger.exception("Failed episodic -> procedural consolidation") + + result.duration_seconds = (datetime.now(UTC) - start_time).total_seconds() + + logger.info( + f"Episodic -> Procedural consolidation: " + f"{result.items_processed} processed, " + f"{result.items_created} created, " + f"{result.items_updated} updated" + ) + + return result + + async def _learn_procedure_from_episodes( + self, + procedural: ProceduralMemory, + project_id: UUID, + agent_type_id: UUID | None, + task_type: str, + episodes: list[Episode], + ) -> str: + """Learn or update a procedure from a set of episodes.""" + # Calculate success rate for this pattern + success_count = sum(1 for e in episodes if e.outcome == Outcome.SUCCESS) + total_count = len(episodes) + success_rate = success_count / total_count if total_count > 0 else 0 + + if success_rate < self._config.min_success_rate_for_procedure: + return "skipped" + + # Extract common steps from episodes + steps = self._extract_common_steps(episodes) + + if len(steps) < self._config.min_steps_for_procedure: + return "skipped" + + # Check for existing procedure + matching = await procedural.find_matching( + context=task_type, + project_id=project_id, + agent_type_id=agent_type_id, + limit=1, + ) + + if matching: + # Update existing procedure with new success + await procedural.record_outcome( + matching[0].id, + success=True, + ) + return "updated" + else: + # Create new procedure + # Note: success_count starts at 1 in record_procedure + procedure_data = ProcedureCreate( + project_id=project_id, + agent_type_id=agent_type_id, + name=f"Procedure for {task_type}", + trigger_pattern=task_type, + steps=steps, + ) + await procedural.record_procedure(procedure_data) + return "created" + + def _extract_common_steps(self, episodes: list[Episode]) -> list[dict[str, Any]]: + 
"""Extract common action steps from multiple episodes.""" + # Simple heuristic: take the steps from the most successful episode + # with the most detailed actions + + best_episode = max( + episodes, + key=lambda e: ( + e.outcome == Outcome.SUCCESS, + len(e.actions), + e.importance_score, + ), + ) + + steps: list[dict[str, Any]] = [] + for i, action in enumerate(best_episode.actions): + step = { + "order": i + 1, + "action": action.get("type", "action"), + "description": action.get("content", str(action))[:500], + "parameters": action, + } + steps.append(step) + + return steps + + # ========================================================================= + # Memory Pruning + # ========================================================================= + + async def prune_old_episodes( + self, + project_id: UUID, + max_age_days: int | None = None, + min_importance: float | None = None, + ) -> ConsolidationResult: + """ + Prune old, low-value episodes. + + Args: + project_id: Project to prune + max_age_days: Maximum age in days (default from config) + min_importance: Minimum importance to keep (default from config) + + Returns: + ConsolidationResult with pruning statistics + """ + start_time = datetime.now(UTC) + result = ConsolidationResult( + source_type="episodic", + target_type="pruned", + ) + + max_age = max_age_days or self._config.max_episode_age_days + min_imp = min_importance or self._config.min_importance_to_keep + cutoff_date = datetime.now(UTC) - timedelta(days=max_age) + + try: + episodic = await self._get_episodic() + + # Get old episodes + # Note: In production, this would use a more efficient query + all_episodes = await episodic.get_recent( + project_id, + limit=self._config.batch_size * 10, + since=cutoff_date - timedelta(days=365), # Search past year + ) + + for episode in all_episodes: + result.items_processed += 1 + + # Check if should be pruned + if not self._should_prune_episode(episode, cutoff_date, min_imp): + result.items_skipped += 1 + continue + + try: + deleted = await episodic.delete(episode.id) + if deleted: + result.items_pruned += 1 + else: + result.items_skipped += 1 + except Exception as e: + result.errors.append(f"Episode {episode.id}: {e}") + + except Exception as e: + result.errors.append(f"Pruning failed: {e}") + logger.exception("Failed episode pruning") + + result.duration_seconds = (datetime.now(UTC) - start_time).total_seconds() + + logger.info( + f"Episode pruning: {result.items_processed} processed, " + f"{result.items_pruned} pruned" + ) + + return result + + def _should_prune_episode( + self, + episode: Episode, + cutoff_date: datetime, + min_importance: float, + ) -> bool: + """Determine if an episode should be pruned.""" + # Keep recent episodes + if episode.occurred_at >= cutoff_date: + return False + + # Keep failures if configured + if self._config.keep_all_failures and episode.outcome == Outcome.FAILURE: + return False + + # Keep episodes with lessons if configured + if self._config.keep_all_with_lessons and episode.lessons_learned: + return False + + # Keep high-importance episodes + if episode.importance_score >= min_importance: + return False + + return True + + # ========================================================================= + # Nightly Consolidation + # ========================================================================= + + async def run_nightly_consolidation( + self, + project_id: UUID, + agent_type_id: UUID | None = None, + ) -> NightlyConsolidationResult: + """ + Run full nightly consolidation workflow. 
+ + This includes: + 1. Extract facts from recent episodes + 2. Learn procedures from successful patterns + 3. Prune old, low-value memories + + Args: + project_id: Project to consolidate + agent_type_id: Optional agent type filter + + Returns: + NightlyConsolidationResult with all outcomes + """ + result = NightlyConsolidationResult(started_at=datetime.now(UTC)) + + logger.info(f"Starting nightly consolidation for project {project_id}") + + try: + # Step 1: Episodic -> Semantic (last 24 hours) + since_yesterday = datetime.now(UTC) - timedelta(days=1) + result.episodic_to_semantic = await self.consolidate_episodes_to_facts( + project_id=project_id, + since=since_yesterday, + ) + result.total_facts_created = result.episodic_to_semantic.items_created + + # Step 2: Episodic -> Procedural (last 7 days) + since_week = datetime.now(UTC) - timedelta(days=7) + result.episodic_to_procedural = ( + await self.consolidate_episodes_to_procedures( + project_id=project_id, + agent_type_id=agent_type_id, + since=since_week, + ) + ) + result.total_procedures_created = ( + result.episodic_to_procedural.items_created + ) + + # Step 3: Prune old memories + result.pruning = await self.prune_old_episodes(project_id=project_id) + result.total_pruned = result.pruning.items_pruned + + # Calculate totals + result.total_episodes_processed = ( + result.episodic_to_semantic.items_processed + if result.episodic_to_semantic + else 0 + ) + ( + result.episodic_to_procedural.items_processed + if result.episodic_to_procedural + else 0 + ) + + # Collect all errors + if result.episodic_to_semantic and result.episodic_to_semantic.errors: + result.errors.extend(result.episodic_to_semantic.errors) + if result.episodic_to_procedural and result.episodic_to_procedural.errors: + result.errors.extend(result.episodic_to_procedural.errors) + if result.pruning and result.pruning.errors: + result.errors.extend(result.pruning.errors) + + except Exception as e: + result.errors.append(f"Nightly consolidation failed: {e}") + logger.exception("Nightly consolidation failed") + + result.completed_at = datetime.now(UTC) + + duration = (result.completed_at - result.started_at).total_seconds() + logger.info( + f"Nightly consolidation completed in {duration:.1f}s: " + f"{result.total_facts_created} facts, " + f"{result.total_procedures_created} procedures, " + f"{result.total_pruned} pruned" + ) + + return result + + +# Singleton instance +_consolidation_service: MemoryConsolidationService | None = None + + +async def get_consolidation_service( + session: AsyncSession, + config: ConsolidationConfig | None = None, +) -> MemoryConsolidationService: + """ + Get or create the memory consolidation service. 
+ + Args: + session: Database session + config: Optional configuration + + Returns: + MemoryConsolidationService instance + """ + global _consolidation_service + if _consolidation_service is None: + _consolidation_service = MemoryConsolidationService( + session=session, config=config + ) + return _consolidation_service diff --git a/backend/app/tasks/__init__.py b/backend/app/tasks/__init__.py index 8cf9c2f..a121bee 100644 --- a/backend/app/tasks/__init__.py +++ b/backend/app/tasks/__init__.py @@ -10,14 +10,16 @@ Modules: sync: Issue synchronization tasks (incremental/full sync, webhooks) workflow: Workflow state management tasks cost: Cost tracking and budget monitoring tasks + memory_consolidation: Memory consolidation tasks """ -from app.tasks import agent, cost, git, sync, workflow +from app.tasks import agent, cost, git, memory_consolidation, sync, workflow __all__ = [ "agent", "cost", "git", + "memory_consolidation", "sync", "workflow", ] diff --git a/backend/app/tasks/memory_consolidation.py b/backend/app/tasks/memory_consolidation.py new file mode 100644 index 0000000..9779aae --- /dev/null +++ b/backend/app/tasks/memory_consolidation.py @@ -0,0 +1,234 @@ +# app/tasks/memory_consolidation.py +""" +Memory consolidation Celery tasks. + +Handles scheduled and on-demand memory consolidation: +- Session consolidation (on session end) +- Nightly consolidation (scheduled) +- On-demand project consolidation +""" + +import logging +from typing import Any + +from app.celery_app import celery_app + +logger = logging.getLogger(__name__) + + +@celery_app.task( + bind=True, + name="app.tasks.memory_consolidation.consolidate_session", + autoretry_for=(Exception,), + retry_backoff=True, + retry_kwargs={"max_retries": 3}, +) +def consolidate_session( + self, + project_id: str, + session_id: str, + task_type: str = "session_task", + agent_instance_id: str | None = None, + agent_type_id: str | None = None, +) -> dict[str, Any]: + """ + Consolidate a session's working memory to episodic memory. + + This task is triggered when an agent session ends to transfer + relevant session data into persistent episodic memory. + + Args: + project_id: UUID of the project + session_id: Session identifier + task_type: Type of task performed + agent_instance_id: Optional agent instance UUID + agent_type_id: Optional agent type UUID + + Returns: + dict with consolidation results + """ + logger.info(f"Consolidating session {session_id} for project {project_id}") + + # TODO: Implement actual consolidation + # This will involve: + # 1. Getting database session from async context + # 2. Loading working memory for session + # 3. Calling consolidation service + # 4. Returning results + + # Placeholder implementation + return { + "status": "pending", + "project_id": project_id, + "session_id": session_id, + "episode_created": False, + } + + +@celery_app.task( + bind=True, + name="app.tasks.memory_consolidation.run_nightly_consolidation", + autoretry_for=(Exception,), + retry_backoff=True, + retry_kwargs={"max_retries": 3}, +) +def run_nightly_consolidation( + self, + project_id: str, + agent_type_id: str | None = None, +) -> dict[str, Any]: + """ + Run nightly memory consolidation for a project. + + This task performs the full consolidation workflow: + 1. Extract facts from recent episodes to semantic memory + 2. Learn procedures from successful episode patterns + 3. 
Prune old, low-value memories + + Args: + project_id: UUID of the project to consolidate + agent_type_id: Optional agent type to filter by + + Returns: + dict with consolidation results + """ + logger.info(f"Running nightly consolidation for project {project_id}") + + # TODO: Implement actual consolidation + # This will involve: + # 1. Getting database session from async context + # 2. Creating consolidation service instance + # 3. Running run_nightly_consolidation + # 4. Returning results + + # Placeholder implementation + return { + "status": "pending", + "project_id": project_id, + "total_facts_created": 0, + "total_procedures_created": 0, + "total_pruned": 0, + } + + +@celery_app.task( + bind=True, + name="app.tasks.memory_consolidation.consolidate_episodes_to_facts", + autoretry_for=(Exception,), + retry_backoff=True, + retry_kwargs={"max_retries": 3}, +) +def consolidate_episodes_to_facts( + self, + project_id: str, + since_hours: int = 24, + limit: int | None = None, +) -> dict[str, Any]: + """ + Extract facts from episodic memories. + + Args: + project_id: UUID of the project + since_hours: Process episodes from last N hours + limit: Maximum episodes to process + + Returns: + dict with extraction results + """ + logger.info(f"Consolidating episodes to facts for project {project_id}") + + # TODO: Implement actual consolidation + # Placeholder implementation + return { + "status": "pending", + "project_id": project_id, + "items_processed": 0, + "items_created": 0, + } + + +@celery_app.task( + bind=True, + name="app.tasks.memory_consolidation.consolidate_episodes_to_procedures", + autoretry_for=(Exception,), + retry_backoff=True, + retry_kwargs={"max_retries": 3}, +) +def consolidate_episodes_to_procedures( + self, + project_id: str, + agent_type_id: str | None = None, + since_days: int = 7, +) -> dict[str, Any]: + """ + Learn procedures from episodic patterns. + + Args: + project_id: UUID of the project + agent_type_id: Optional agent type filter + since_days: Process episodes from last N days + + Returns: + dict with procedure learning results + """ + logger.info(f"Consolidating episodes to procedures for project {project_id}") + + # TODO: Implement actual consolidation + # Placeholder implementation + return { + "status": "pending", + "project_id": project_id, + "items_processed": 0, + "items_created": 0, + } + + +@celery_app.task( + bind=True, + name="app.tasks.memory_consolidation.prune_old_memories", + autoretry_for=(Exception,), + retry_backoff=True, + retry_kwargs={"max_retries": 3}, +) +def prune_old_memories( + self, + project_id: str, + max_age_days: int = 90, + min_importance: float = 0.2, +) -> dict[str, Any]: + """ + Prune old, low-value memories. 
+ + Args: + project_id: UUID of the project + max_age_days: Maximum age in days + min_importance: Minimum importance to keep + + Returns: + dict with pruning results + """ + logger.info(f"Pruning old memories for project {project_id}") + + # TODO: Implement actual pruning + # Placeholder implementation + return { + "status": "pending", + "project_id": project_id, + "items_pruned": 0, + } + + +# ========================================================================= +# Celery Beat Schedule Configuration +# ========================================================================= + +# This would typically be configured in celery_app.py or a separate config file +# Example schedule for nightly consolidation: +# +# app.conf.beat_schedule = { +# 'nightly-memory-consolidation': { +# 'task': 'app.tasks.memory_consolidation.run_nightly_consolidation', +# 'schedule': crontab(hour=2, minute=0), # 2 AM daily +# 'args': (None,), # Will process all projects +# }, +# } diff --git a/backend/tests/unit/services/memory/consolidation/__init__.py b/backend/tests/unit/services/memory/consolidation/__init__.py new file mode 100644 index 0000000..11083d1 --- /dev/null +++ b/backend/tests/unit/services/memory/consolidation/__init__.py @@ -0,0 +1,2 @@ +# tests/unit/services/memory/consolidation/__init__.py +"""Tests for memory consolidation.""" diff --git a/backend/tests/unit/services/memory/consolidation/test_service.py b/backend/tests/unit/services/memory/consolidation/test_service.py new file mode 100644 index 0000000..934a2e9 --- /dev/null +++ b/backend/tests/unit/services/memory/consolidation/test_service.py @@ -0,0 +1,736 @@ +# tests/unit/services/memory/consolidation/test_service.py +"""Unit tests for memory consolidation service.""" + +from datetime import UTC, datetime, timedelta +from unittest.mock import AsyncMock, MagicMock, patch +from uuid import uuid4 + +import pytest + +from app.services.memory.consolidation.service import ( + ConsolidationConfig, + ConsolidationResult, + MemoryConsolidationService, + NightlyConsolidationResult, + SessionConsolidationResult, +) +from app.services.memory.types import Episode, Outcome, TaskState + + +def _utcnow() -> datetime: + """Get current UTC time.""" + return datetime.now(UTC) + + +def make_episode( + outcome: Outcome = Outcome.SUCCESS, + occurred_at: datetime | None = None, + task_type: str = "test_task", + lessons_learned: list[str] | None = None, + importance_score: float = 0.5, + actions: list[dict] | None = None, +) -> Episode: + """Create a test episode.""" + return Episode( + id=uuid4(), + project_id=uuid4(), + agent_instance_id=uuid4(), + agent_type_id=uuid4(), + session_id="test-session", + task_type=task_type, + task_description="Test task description", + actions=actions or [{"action": "test"}], + context_summary="Test context", + outcome=outcome, + outcome_details="Test outcome", + duration_seconds=10.0, + tokens_used=100, + lessons_learned=lessons_learned or [], + importance_score=importance_score, + embedding=None, + occurred_at=occurred_at or _utcnow(), + created_at=_utcnow(), + updated_at=_utcnow(), + ) + + +def make_task_state( + current_step: int = 5, + total_steps: int = 10, + progress_percent: float = 50.0, + status: str = "in_progress", + description: str = "Test Task", +) -> TaskState: + """Create a test task state.""" + now = _utcnow() + return TaskState( + task_id="test-task-id", + task_type="test_task", + description=description, + current_step=current_step, + total_steps=total_steps, + status=status, + 
progress_percent=progress_percent, + started_at=now - timedelta(hours=1), + updated_at=now, + ) + + +class TestConsolidationConfig: + """Tests for ConsolidationConfig.""" + + def test_default_values(self) -> None: + """Test default configuration values.""" + config = ConsolidationConfig() + + assert config.min_steps_for_episode == 2 + assert config.min_duration_seconds == 5.0 + assert config.min_confidence_for_fact == 0.6 + assert config.max_facts_per_episode == 10 + assert config.min_episodes_for_procedure == 3 + assert config.max_episode_age_days == 90 + assert config.batch_size == 100 + + def test_custom_values(self) -> None: + """Test custom configuration values.""" + config = ConsolidationConfig( + min_steps_for_episode=5, + batch_size=50, + ) + + assert config.min_steps_for_episode == 5 + assert config.batch_size == 50 + + +class TestConsolidationResult: + """Tests for ConsolidationResult.""" + + def test_creation(self) -> None: + """Test creating a consolidation result.""" + result = ConsolidationResult( + source_type="episodic", + target_type="semantic", + items_processed=10, + items_created=5, + ) + + assert result.source_type == "episodic" + assert result.target_type == "semantic" + assert result.items_processed == 10 + assert result.items_created == 5 + assert result.items_skipped == 0 + assert result.errors == [] + + def test_to_dict(self) -> None: + """Test converting to dictionary.""" + result = ConsolidationResult( + source_type="episodic", + target_type="semantic", + items_processed=10, + items_created=5, + errors=["test error"], + ) + + d = result.to_dict() + + assert d["source_type"] == "episodic" + assert d["target_type"] == "semantic" + assert d["items_processed"] == 10 + assert d["items_created"] == 5 + assert "test error" in d["errors"] + + +class TestSessionConsolidationResult: + """Tests for SessionConsolidationResult.""" + + def test_creation(self) -> None: + """Test creating a session consolidation result.""" + result = SessionConsolidationResult( + session_id="test-session", + episode_created=True, + episode_id=uuid4(), + scratchpad_entries=5, + ) + + assert result.session_id == "test-session" + assert result.episode_created is True + assert result.episode_id is not None + + +class TestNightlyConsolidationResult: + """Tests for NightlyConsolidationResult.""" + + def test_creation(self) -> None: + """Test creating a nightly consolidation result.""" + result = NightlyConsolidationResult( + started_at=_utcnow(), + ) + + assert result.started_at is not None + assert result.completed_at is None + assert result.total_episodes_processed == 0 + + def test_to_dict(self) -> None: + """Test converting to dictionary.""" + result = NightlyConsolidationResult( + started_at=_utcnow(), + completed_at=_utcnow(), + total_facts_created=5, + total_procedures_created=2, + ) + + d = result.to_dict() + + assert "started_at" in d + assert "completed_at" in d + assert d["total_facts_created"] == 5 + assert d["total_procedures_created"] == 2 + + +class TestMemoryConsolidationService: + """Tests for MemoryConsolidationService.""" + + @pytest.fixture + def mock_session(self) -> AsyncMock: + """Create a mock database session.""" + return AsyncMock() + + @pytest.fixture + def service(self, mock_session: AsyncMock) -> MemoryConsolidationService: + """Create a consolidation service with mocked dependencies.""" + return MemoryConsolidationService( + session=mock_session, + config=ConsolidationConfig(), + ) + + # ========================================================================= + # 
Session Consolidation Tests + # ========================================================================= + + @pytest.mark.asyncio + async def test_consolidate_session_insufficient_steps( + self, service: MemoryConsolidationService + ) -> None: + """Test session not consolidated when insufficient steps.""" + mock_working_memory = AsyncMock() + task_state = make_task_state(current_step=1) # Less than min_steps_for_episode + mock_working_memory.get_task_state.return_value = task_state + + result = await service.consolidate_session( + working_memory=mock_working_memory, + project_id=uuid4(), + session_id="test-session", + ) + + assert result.episode_created is False + assert result.episode_id is None + + @pytest.mark.asyncio + async def test_consolidate_session_no_task_state( + self, service: MemoryConsolidationService + ) -> None: + """Test session not consolidated when no task state.""" + mock_working_memory = AsyncMock() + mock_working_memory.get_task_state.return_value = None + + result = await service.consolidate_session( + working_memory=mock_working_memory, + project_id=uuid4(), + session_id="test-session", + ) + + assert result.episode_created is False + + @pytest.mark.asyncio + async def test_consolidate_session_success( + self, service: MemoryConsolidationService, mock_session: AsyncMock + ) -> None: + """Test successful session consolidation.""" + mock_working_memory = AsyncMock() + task_state = make_task_state( + current_step=5, + progress_percent=100.0, + status="complete", + ) + mock_working_memory.get_task_state.return_value = task_state + mock_working_memory.get_scratchpad.return_value = ["step1", "step2"] + mock_working_memory.get_all.return_value = {"key1": "value1"} + + # Mock episodic memory + mock_episode = make_episode() + with patch.object( + service, "_get_episodic", new_callable=AsyncMock + ) as mock_get_episodic: + mock_episodic = AsyncMock() + mock_episodic.record_episode.return_value = mock_episode + mock_get_episodic.return_value = mock_episodic + + result = await service.consolidate_session( + working_memory=mock_working_memory, + project_id=uuid4(), + session_id="test-session", + ) + + assert result.episode_created is True + assert result.episode_id == mock_episode.id + assert result.scratchpad_entries == 2 + + # ========================================================================= + # Outcome Determination Tests + # ========================================================================= + + def test_determine_session_outcome_success( + self, service: MemoryConsolidationService + ) -> None: + """Test outcome determination for successful session.""" + task_state = make_task_state(status="complete", progress_percent=100.0) + outcome = service._determine_session_outcome(task_state) + assert outcome == Outcome.SUCCESS + + def test_determine_session_outcome_failure( + self, service: MemoryConsolidationService + ) -> None: + """Test outcome determination for failed session.""" + task_state = make_task_state(status="error", progress_percent=25.0) + outcome = service._determine_session_outcome(task_state) + assert outcome == Outcome.FAILURE + + def test_determine_session_outcome_partial( + self, service: MemoryConsolidationService + ) -> None: + """Test outcome determination for partial session.""" + task_state = make_task_state(status="stopped", progress_percent=60.0) + outcome = service._determine_session_outcome(task_state) + assert outcome == Outcome.PARTIAL + + def test_determine_session_outcome_none( + self, service: MemoryConsolidationService + ) -> None: + 
"""Test outcome determination with no task state.""" + outcome = service._determine_session_outcome(None) + assert outcome == Outcome.PARTIAL + + # ========================================================================= + # Action Building Tests + # ========================================================================= + + def test_build_actions_from_session( + self, service: MemoryConsolidationService + ) -> None: + """Test building actions from session data.""" + scratchpad = ["thought 1", "thought 2"] + variables = {"var1": "value1"} + task_state = make_task_state() + + actions = service._build_actions_from_session(scratchpad, variables, task_state) + + assert len(actions) == 3 # 2 scratchpad + 1 final state + assert actions[0]["type"] == "reasoning" + assert actions[2]["type"] == "final_state" + + def test_build_context_summary(self, service: MemoryConsolidationService) -> None: + """Test building context summary.""" + task_state = make_task_state( + description="Test Task", + progress_percent=75.0, + ) + variables = {"key": "value"} + + summary = service._build_context_summary(task_state, variables) + + assert "Test Task" in summary + assert "75.0%" in summary + + # ========================================================================= + # Importance Calculation Tests + # ========================================================================= + + def test_calculate_session_importance_base( + self, service: MemoryConsolidationService + ) -> None: + """Test base importance calculation.""" + task_state = make_task_state(total_steps=3) # Below threshold + importance = service._calculate_session_importance( + task_state, Outcome.SUCCESS, [] + ) + + assert importance == 0.5 # Base score + + def test_calculate_session_importance_failure( + self, service: MemoryConsolidationService + ) -> None: + """Test importance boost for failures.""" + task_state = make_task_state(total_steps=3) # Below threshold + importance = service._calculate_session_importance( + task_state, Outcome.FAILURE, [] + ) + + assert importance == 0.8 # Base (0.5) + failure boost (0.3) + + def test_calculate_session_importance_complex( + self, service: MemoryConsolidationService + ) -> None: + """Test importance for complex session.""" + task_state = make_task_state(total_steps=10) + actions = [{"step": i} for i in range(6)] + importance = service._calculate_session_importance( + task_state, Outcome.SUCCESS, actions + ) + + # Base (0.5) + many steps (0.1) + many actions (0.1) + assert importance == 0.7 + + # ========================================================================= + # Episode to Fact Consolidation Tests + # ========================================================================= + + @pytest.mark.asyncio + async def test_consolidate_episodes_to_facts_empty( + self, service: MemoryConsolidationService + ) -> None: + """Test consolidation with no episodes.""" + with patch.object( + service, "_get_episodic", new_callable=AsyncMock + ) as mock_get_episodic: + mock_episodic = AsyncMock() + mock_episodic.get_recent.return_value = [] + mock_get_episodic.return_value = mock_episodic + + result = await service.consolidate_episodes_to_facts( + project_id=uuid4(), + ) + + assert result.items_processed == 0 + assert result.items_created == 0 + + @pytest.mark.asyncio + async def test_consolidate_episodes_to_facts_success( + self, service: MemoryConsolidationService + ) -> None: + """Test successful fact extraction.""" + episode = make_episode( + lessons_learned=["Always check return values"], + ) + + mock_fact 
= MagicMock() + mock_fact.reinforcement_count = 1 # New fact + + with ( + patch.object( + service, "_get_episodic", new_callable=AsyncMock + ) as mock_get_episodic, + patch.object( + service, "_get_semantic", new_callable=AsyncMock + ) as mock_get_semantic, + ): + mock_episodic = AsyncMock() + mock_episodic.get_recent.return_value = [episode] + mock_get_episodic.return_value = mock_episodic + + mock_semantic = AsyncMock() + mock_semantic.store_fact.return_value = mock_fact + mock_get_semantic.return_value = mock_semantic + + result = await service.consolidate_episodes_to_facts( + project_id=uuid4(), + ) + + assert result.items_processed == 1 + # At least one fact should be created from lesson + assert result.items_created >= 0 + + # ========================================================================= + # Episode to Procedure Consolidation Tests + # ========================================================================= + + @pytest.mark.asyncio + async def test_consolidate_episodes_to_procedures_insufficient( + self, service: MemoryConsolidationService + ) -> None: + """Test consolidation with insufficient episodes.""" + # Only 1 episode - less than min_episodes_for_procedure (3) + episode = make_episode() + + with patch.object( + service, "_get_episodic", new_callable=AsyncMock + ) as mock_get_episodic: + mock_episodic = AsyncMock() + mock_episodic.get_by_outcome.return_value = [episode] + mock_get_episodic.return_value = mock_episodic + + result = await service.consolidate_episodes_to_procedures( + project_id=uuid4(), + ) + + assert result.items_processed == 1 + assert result.items_created == 0 + assert result.items_skipped == 1 + + @pytest.mark.asyncio + async def test_consolidate_episodes_to_procedures_success( + self, service: MemoryConsolidationService + ) -> None: + """Test successful procedure creation.""" + # Create enough episodes for a procedure + episodes = [ + make_episode( + task_type="deploy", + actions=[{"type": "step1"}, {"type": "step2"}, {"type": "step3"}], + ) + for _ in range(5) + ] + + mock_procedure = MagicMock() + + with ( + patch.object( + service, "_get_episodic", new_callable=AsyncMock + ) as mock_get_episodic, + patch.object( + service, "_get_procedural", new_callable=AsyncMock + ) as mock_get_procedural, + ): + mock_episodic = AsyncMock() + mock_episodic.get_by_outcome.return_value = episodes + mock_get_episodic.return_value = mock_episodic + + mock_procedural = AsyncMock() + mock_procedural.find_matching.return_value = [] # No existing procedure + mock_procedural.record_procedure.return_value = mock_procedure + mock_get_procedural.return_value = mock_procedural + + result = await service.consolidate_episodes_to_procedures( + project_id=uuid4(), + ) + + assert result.items_processed == 5 + assert result.items_created == 1 + + # ========================================================================= + # Common Steps Extraction Tests + # ========================================================================= + + def test_extract_common_steps(self, service: MemoryConsolidationService) -> None: + """Test extracting steps from episodes.""" + episodes = [ + make_episode( + outcome=Outcome.SUCCESS, + importance_score=0.8, + actions=[ + {"type": "step1", "content": "First step"}, + {"type": "step2", "content": "Second step"}, + ], + ), + make_episode( + outcome=Outcome.SUCCESS, + importance_score=0.5, + actions=[{"type": "simple"}], + ), + ] + + steps = service._extract_common_steps(episodes) + + assert len(steps) == 2 + assert steps[0]["order"] == 1 + assert 
steps[0]["action"] == "step1" + + # ========================================================================= + # Pruning Tests + # ========================================================================= + + def test_should_prune_episode_old_low_importance( + self, service: MemoryConsolidationService + ) -> None: + """Test pruning old, low-importance episode.""" + old_date = _utcnow() - timedelta(days=100) + episode = make_episode( + occurred_at=old_date, + importance_score=0.1, + outcome=Outcome.SUCCESS, + ) + cutoff = _utcnow() - timedelta(days=90) + + should_prune = service._should_prune_episode(episode, cutoff, 0.2) + + assert should_prune is True + + def test_should_prune_episode_recent( + self, service: MemoryConsolidationService + ) -> None: + """Test not pruning recent episode.""" + recent_date = _utcnow() - timedelta(days=30) + episode = make_episode( + occurred_at=recent_date, + importance_score=0.1, + ) + cutoff = _utcnow() - timedelta(days=90) + + should_prune = service._should_prune_episode(episode, cutoff, 0.2) + + assert should_prune is False + + def test_should_prune_episode_failure_protected( + self, service: MemoryConsolidationService + ) -> None: + """Test not pruning failure (with keep_all_failures=True).""" + old_date = _utcnow() - timedelta(days=100) + episode = make_episode( + occurred_at=old_date, + importance_score=0.1, + outcome=Outcome.FAILURE, + ) + cutoff = _utcnow() - timedelta(days=90) + + should_prune = service._should_prune_episode(episode, cutoff, 0.2) + + # Config has keep_all_failures=True by default + assert should_prune is False + + def test_should_prune_episode_with_lessons_protected( + self, service: MemoryConsolidationService + ) -> None: + """Test not pruning episode with lessons.""" + old_date = _utcnow() - timedelta(days=100) + episode = make_episode( + occurred_at=old_date, + importance_score=0.1, + lessons_learned=["Important lesson"], + ) + cutoff = _utcnow() - timedelta(days=90) + + should_prune = service._should_prune_episode(episode, cutoff, 0.2) + + # Config has keep_all_with_lessons=True by default + assert should_prune is False + + def test_should_prune_episode_high_importance_protected( + self, service: MemoryConsolidationService + ) -> None: + """Test not pruning high importance episode.""" + old_date = _utcnow() - timedelta(days=100) + episode = make_episode( + occurred_at=old_date, + importance_score=0.8, + ) + cutoff = _utcnow() - timedelta(days=90) + + should_prune = service._should_prune_episode(episode, cutoff, 0.2) + + assert should_prune is False + + @pytest.mark.asyncio + async def test_prune_old_episodes( + self, service: MemoryConsolidationService + ) -> None: + """Test episode pruning.""" + old_episode = make_episode( + occurred_at=_utcnow() - timedelta(days=100), + importance_score=0.1, + outcome=Outcome.SUCCESS, + lessons_learned=[], + ) + + with patch.object( + service, "_get_episodic", new_callable=AsyncMock + ) as mock_get_episodic: + mock_episodic = AsyncMock() + mock_episodic.get_recent.return_value = [old_episode] + mock_episodic.delete.return_value = True + mock_get_episodic.return_value = mock_episodic + + result = await service.prune_old_episodes(project_id=uuid4()) + + assert result.items_processed == 1 + assert result.items_pruned == 1 + + # ========================================================================= + # Nightly Consolidation Tests + # ========================================================================= + + @pytest.mark.asyncio + async def test_run_nightly_consolidation( + self, service: 
MemoryConsolidationService + ) -> None: + """Test nightly consolidation workflow.""" + with ( + patch.object( + service, + "consolidate_episodes_to_facts", + new_callable=AsyncMock, + ) as mock_facts, + patch.object( + service, + "consolidate_episodes_to_procedures", + new_callable=AsyncMock, + ) as mock_procedures, + patch.object( + service, + "prune_old_episodes", + new_callable=AsyncMock, + ) as mock_prune, + ): + mock_facts.return_value = ConsolidationResult( + source_type="episodic", + target_type="semantic", + items_processed=10, + items_created=5, + ) + mock_procedures.return_value = ConsolidationResult( + source_type="episodic", + target_type="procedural", + items_processed=10, + items_created=2, + ) + mock_prune.return_value = ConsolidationResult( + source_type="episodic", + target_type="pruned", + items_pruned=3, + ) + + result = await service.run_nightly_consolidation(project_id=uuid4()) + + assert result.completed_at is not None + assert result.total_facts_created == 5 + assert result.total_procedures_created == 2 + assert result.total_pruned == 3 + assert result.total_episodes_processed == 20 + + @pytest.mark.asyncio + async def test_run_nightly_consolidation_with_errors( + self, service: MemoryConsolidationService + ) -> None: + """Test nightly consolidation handles errors.""" + with ( + patch.object( + service, + "consolidate_episodes_to_facts", + new_callable=AsyncMock, + ) as mock_facts, + patch.object( + service, + "consolidate_episodes_to_procedures", + new_callable=AsyncMock, + ) as mock_procedures, + patch.object( + service, + "prune_old_episodes", + new_callable=AsyncMock, + ) as mock_prune, + ): + mock_facts.return_value = ConsolidationResult( + source_type="episodic", + target_type="semantic", + errors=["fact error"], + ) + mock_procedures.return_value = ConsolidationResult( + source_type="episodic", + target_type="procedural", + ) + mock_prune.return_value = ConsolidationResult( + source_type="episodic", + target_type="pruned", + ) + + result = await service.run_nightly_consolidation(project_id=uuid4()) + + assert "fact error" in result.errors
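
Note on the TODO placeholders in app/tasks/memory_consolidation.py: below is a minimal sketch of how the run_nightly_consolidation task could bridge the synchronous Celery worker to the async consolidation service, following the four TODO steps listed in the task body. The async session factory import (app.db.session.async_session_maker) is an assumed name for this codebase; the service calls themselves follow the API introduced in this diff.

# Hypothetical sketch only: async_session_maker is an assumed name for the
# project's async session factory; adjust the import to the real one.
import asyncio
from typing import Any
from uuid import UUID

from app.db.session import async_session_maker  # assumed location/name
from app.services.memory.consolidation import MemoryConsolidationService


def _run_nightly_consolidation_sync(
    project_id: str, agent_type_id: str | None = None
) -> dict[str, Any]:
    """Bridge the synchronous Celery worker to the async consolidation service."""

    async def _run() -> dict[str, Any]:
        async with async_session_maker() as session:
            service = MemoryConsolidationService(session=session)
            result = await service.run_nightly_consolidation(
                project_id=UUID(project_id),
                agent_type_id=UUID(agent_type_id) if agent_type_id else None,
            )
            return result.to_dict()

    # Celery prefork workers are synchronous, so drive the coroutine to
    # completion here and return a JSON-serializable dict as the task result.
    return asyncio.run(_run())

Session-end consolidation could be wired the same way, e.g. consolidate_session.delay(str(project_id), session_id) from the session teardown path, with an analogous sync-to-async bridge inside that task.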