# app/services/memory/consolidation/service.py
"""
Memory Consolidation Service.
Transfers and extracts knowledge between memory tiers:
- Working -> Episodic (session end)
- Episodic -> Semantic (learn facts)
- Episodic -> Procedural (learn procedures)
Also handles memory pruning and importance-based retention.
"""
import logging
from dataclasses import dataclass, field
from datetime import UTC, datetime, timedelta
from typing import Any
from uuid import UUID
from sqlalchemy.ext.asyncio import AsyncSession
from app.services.memory.episodic.memory import EpisodicMemory
from app.services.memory.procedural.memory import ProceduralMemory
from app.services.memory.semantic.extraction import FactExtractor, get_fact_extractor
from app.services.memory.semantic.memory import SemanticMemory
from app.services.memory.types import (
Episode,
EpisodeCreate,
Outcome,
ProcedureCreate,
TaskState,
)
from app.services.memory.working.memory import WorkingMemory
logger = logging.getLogger(__name__)
@dataclass
class ConsolidationConfig:
"""Configuration for memory consolidation."""
# Working -> Episodic thresholds
min_steps_for_episode: int = 2
min_duration_seconds: float = 5.0
# Episodic -> Semantic thresholds
min_confidence_for_fact: float = 0.6
max_facts_per_episode: int = 10
reinforce_existing_facts: bool = True
# Episodic -> Procedural thresholds
min_episodes_for_procedure: int = 3
min_success_rate_for_procedure: float = 0.7
min_steps_for_procedure: int = 2
# Pruning thresholds
max_episode_age_days: int = 90
min_importance_to_keep: float = 0.2
keep_all_failures: bool = True
keep_all_with_lessons: bool = True
# Batch sizes
batch_size: int = 100
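# Illustrative override of ConsolidationConfig (a sketch, not a recommended
# default): a project that wants faster pruning could pass a custom config to
# the service, e.g.
#
#   aggressive_config = ConsolidationConfig(
#       max_episode_age_days=30,
#       min_importance_to_keep=0.4,
#   )
#
# The values above are hypothetical; all other fields keep their defaults.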
@dataclass
class ConsolidationResult:
"""Result of a consolidation operation."""
source_type: str
target_type: str
items_processed: int = 0
items_created: int = 0
items_updated: int = 0
items_skipped: int = 0
items_pruned: int = 0
errors: list[str] = field(default_factory=list)
duration_seconds: float = 0.0
def to_dict(self) -> dict[str, Any]:
"""Convert to dictionary."""
return {
"source_type": self.source_type,
"target_type": self.target_type,
"items_processed": self.items_processed,
"items_created": self.items_created,
"items_updated": self.items_updated,
"items_skipped": self.items_skipped,
"items_pruned": self.items_pruned,
"errors": self.errors,
"duration_seconds": self.duration_seconds,
}
@dataclass
class SessionConsolidationResult:
"""Result of consolidating a session's working memory to episodic."""
session_id: str
episode_created: bool = False
episode_id: UUID | None = None
scratchpad_entries: int = 0
variables_captured: int = 0
error: str | None = None
@dataclass
class NightlyConsolidationResult:
"""Result of nightly consolidation run."""
started_at: datetime
completed_at: datetime | None = None
episodic_to_semantic: ConsolidationResult | None = None
episodic_to_procedural: ConsolidationResult | None = None
pruning: ConsolidationResult | None = None
total_episodes_processed: int = 0
total_facts_created: int = 0
total_procedures_created: int = 0
total_pruned: int = 0
errors: list[str] = field(default_factory=list)
def to_dict(self) -> dict[str, Any]:
"""Convert to dictionary."""
return {
"started_at": self.started_at.isoformat(),
"completed_at": self.completed_at.isoformat()
if self.completed_at
else None,
"episodic_to_semantic": (
self.episodic_to_semantic.to_dict()
if self.episodic_to_semantic
else None
),
"episodic_to_procedural": (
self.episodic_to_procedural.to_dict()
if self.episodic_to_procedural
else None
),
"pruning": self.pruning.to_dict() if self.pruning else None,
"total_episodes_processed": self.total_episodes_processed,
"total_facts_created": self.total_facts_created,
"total_procedures_created": self.total_procedures_created,
"total_pruned": self.total_pruned,
"errors": self.errors,
}
class MemoryConsolidationService:
"""
Service for consolidating memories between tiers.
Responsibilities:
- Transfer working memory to episodic at session end
- Extract facts from episodes to semantic memory
- Learn procedures from successful episode patterns
- Prune old/low-value memories
"""
def __init__(
self,
session: AsyncSession,
config: ConsolidationConfig | None = None,
embedding_generator: Any | None = None,
) -> None:
"""
Initialize consolidation service.
Args:
session: Database session
config: Consolidation configuration
embedding_generator: Optional embedding generator
"""
self._session = session
self._config = config or ConsolidationConfig()
self._embedding_generator = embedding_generator
self._fact_extractor: FactExtractor = get_fact_extractor()
# Memory services (lazy initialized)
self._episodic: EpisodicMemory | None = None
self._semantic: SemanticMemory | None = None
self._procedural: ProceduralMemory | None = None
async def _get_episodic(self) -> EpisodicMemory:
"""Get or create episodic memory service."""
if self._episodic is None:
self._episodic = await EpisodicMemory.create(
self._session, self._embedding_generator
)
return self._episodic
async def _get_semantic(self) -> SemanticMemory:
"""Get or create semantic memory service."""
if self._semantic is None:
self._semantic = await SemanticMemory.create(
self._session, self._embedding_generator
)
return self._semantic
async def _get_procedural(self) -> ProceduralMemory:
"""Get or create procedural memory service."""
if self._procedural is None:
self._procedural = await ProceduralMemory.create(
self._session, self._embedding_generator
)
return self._procedural
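# The tier services above are created lazily via their async factories on
# first use and cached on the instance, so a consolidation run only
# initializes the tiers it actually touches.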
# =========================================================================
# Working -> Episodic Consolidation
# =========================================================================
async def consolidate_session(
self,
working_memory: WorkingMemory,
project_id: UUID,
session_id: str,
task_type: str = "session_task",
agent_instance_id: UUID | None = None,
agent_type_id: UUID | None = None,
) -> SessionConsolidationResult:
"""
Consolidate a session's working memory to episodic memory.
Called at session end to transfer relevant session data
into a persistent episode.
Args:
working_memory: The session's working memory
project_id: Project ID
session_id: Session ID
task_type: Type of task performed
agent_instance_id: Optional agent instance
agent_type_id: Optional agent type
Returns:
SessionConsolidationResult with outcome details
"""
result = SessionConsolidationResult(session_id=session_id)
try:
# Get task state
task_state = await working_memory.get_task_state()
# Check if there's enough content to consolidate
if not self._should_consolidate_session(task_state):
logger.debug(
f"Skipping consolidation for session {session_id}: insufficient content"
)
return result
# Gather scratchpad entries
scratchpad = await working_memory.get_scratchpad()
result.scratchpad_entries = len(scratchpad)
# Gather user variables
all_data = await working_memory.get_all()
result.variables_captured = len(all_data)
# Determine outcome
outcome = self._determine_session_outcome(task_state)
# Build actions from scratchpad and variables
actions = self._build_actions_from_session(scratchpad, all_data, task_state)
# Build context summary
context_summary = self._build_context_summary(task_state, all_data)
# Extract lessons learned
lessons = self._extract_session_lessons(task_state, outcome)
# Calculate importance
importance = self._calculate_session_importance(
task_state, outcome, actions
)
# Create episode
episode_data = EpisodeCreate(
project_id=project_id,
session_id=session_id,
task_type=task_type,
task_description=task_state.description
if task_state
else "Session task",
actions=actions,
context_summary=context_summary,
outcome=outcome,
outcome_details=task_state.status if task_state else "",
duration_seconds=self._calculate_duration(task_state),
tokens_used=0, # Would need to track this in working memory
lessons_learned=lessons,
importance_score=importance,
agent_instance_id=agent_instance_id,
agent_type_id=agent_type_id,
)
episodic = await self._get_episodic()
episode = await episodic.record_episode(episode_data)
result.episode_created = True
result.episode_id = episode.id
logger.info(
f"Consolidated session {session_id} to episode {episode.id} "
f"({len(actions)} actions, outcome={outcome.value})"
)
except Exception as e:
result.error = str(e)
logger.exception(f"Failed to consolidate session {session_id}")
return result
def _should_consolidate_session(self, task_state: TaskState | None) -> bool:
"""Check if session has enough content to consolidate."""
if task_state is None:
return False
# Check minimum steps
if task_state.current_step < self._config.min_steps_for_episode:
return False
return True
def _determine_session_outcome(self, task_state: TaskState | None) -> Outcome:
"""Determine outcome from task state."""
if task_state is None:
return Outcome.PARTIAL
status = task_state.status.lower() if task_state.status else ""
progress = task_state.progress_percent
if "success" in status or "complete" in status or progress >= 100:
return Outcome.SUCCESS
if "fail" in status or "error" in status:
return Outcome.FAILURE
if progress >= 50:
return Outcome.PARTIAL
return Outcome.FAILURE
def _build_actions_from_session(
self,
scratchpad: list[str],
variables: dict[str, Any],
task_state: TaskState | None,
) -> list[dict[str, Any]]:
"""Build action list from session data."""
actions: list[dict[str, Any]] = []
# Add scratchpad entries as actions
for i, entry in enumerate(scratchpad):
actions.append(
{
"step": i + 1,
"type": "reasoning",
"content": entry[:500], # Truncate long entries
}
)
# Add final state
if task_state:
actions.append(
{
"step": len(scratchpad) + 1,
"type": "final_state",
"current_step": task_state.current_step,
"total_steps": task_state.total_steps,
"progress": task_state.progress_percent,
"status": task_state.status,
}
)
return actions
def _build_context_summary(
self,
task_state: TaskState | None,
variables: dict[str, Any],
) -> str:
"""Build context summary from session data."""
parts = []
if task_state:
parts.append(f"Task: {task_state.description}")
parts.append(f"Progress: {task_state.progress_percent:.1f}%")
parts.append(f"Steps: {task_state.current_step}/{task_state.total_steps}")
# Include key variables
key_vars = {k: v for k, v in variables.items() if len(str(v)) < 100}
if key_vars:
var_str = ", ".join(f"{k}={v}" for k, v in list(key_vars.items())[:5])
parts.append(f"Variables: {var_str}")
return "; ".join(parts) if parts else "Session completed"
def _extract_session_lessons(
self,
task_state: TaskState | None,
outcome: Outcome,
) -> list[str]:
"""Extract lessons from session."""
lessons: list[str] = []
if task_state and task_state.status:
if outcome == Outcome.FAILURE:
lessons.append(
f"Task failed at step {task_state.current_step}: {task_state.status}"
)
elif outcome == Outcome.SUCCESS:
lessons.append(
f"Successfully completed in {task_state.current_step} steps"
)
return lessons
def _calculate_session_importance(
self,
task_state: TaskState | None,
outcome: Outcome,
actions: list[dict[str, Any]],
) -> float:
"""Calculate importance score for session."""
score = 0.5 # Base score
# Failures are important to learn from
if outcome == Outcome.FAILURE:
score += 0.3
# Many steps means complex task
if task_state and task_state.total_steps >= 5:
score += 0.1
# Many actions means detailed reasoning
if len(actions) >= 5:
score += 0.1
return min(1.0, score)
def _calculate_duration(self, task_state: TaskState | None) -> float:
"""Calculate session duration."""
if task_state is None:
return 0.0
if task_state.started_at and task_state.updated_at:
delta = task_state.updated_at - task_state.started_at
return delta.total_seconds()
return 0.0
# =========================================================================
# Episodic -> Semantic Consolidation
# =========================================================================
async def consolidate_episodes_to_facts(
self,
project_id: UUID,
since: datetime | None = None,
limit: int | None = None,
) -> ConsolidationResult:
"""
Extract facts from episodic memories to semantic memory.
Args:
project_id: Project to consolidate
since: Only process episodes since this time
limit: Maximum episodes to process
Returns:
ConsolidationResult with extraction statistics
"""
start_time = datetime.now(UTC)
result = ConsolidationResult(
source_type="episodic",
target_type="semantic",
)
try:
episodic = await self._get_episodic()
semantic = await self._get_semantic()
# Get episodes to process
since_time = since or datetime.now(UTC) - timedelta(days=1)
episodes = await episodic.get_recent(
project_id,
limit=limit or self._config.batch_size,
since=since_time,
)
for episode in episodes:
result.items_processed += 1
try:
# Extract facts using the extractor
extracted_facts = self._fact_extractor.extract_from_episode(episode)
for extracted_fact in extracted_facts:
if (
extracted_fact.confidence
< self._config.min_confidence_for_fact
):
result.items_skipped += 1
continue
# Create fact (store_fact handles deduplication/reinforcement)
fact_create = extracted_fact.to_fact_create(
project_id=project_id,
source_episode_ids=[episode.id],
)
# store_fact automatically reinforces if fact already exists
fact = await semantic.store_fact(fact_create)
# Check if this was a new fact or reinforced existing
if fact.reinforcement_count == 1:
result.items_created += 1
else:
result.items_updated += 1
except Exception as e:
result.errors.append(f"Episode {episode.id}: {e}")
logger.warning(
f"Failed to extract facts from episode {episode.id}: {e}"
)
except Exception as e:
result.errors.append(f"Consolidation failed: {e}")
logger.exception("Failed episodic -> semantic consolidation")
result.duration_seconds = (datetime.now(UTC) - start_time).total_seconds()
logger.info(
f"Episodic -> Semantic consolidation: "
f"{result.items_processed} processed, "
f"{result.items_created} created, "
f"{result.items_updated} reinforced"
)
return result
# =========================================================================
# Episodic -> Procedural Consolidation
# =========================================================================
async def consolidate_episodes_to_procedures(
self,
project_id: UUID,
agent_type_id: UUID | None = None,
since: datetime | None = None,
) -> ConsolidationResult:
"""
Learn procedures from patterns in episodic memories.
Identifies recurring successful patterns and creates/updates
procedures to capture them.
Args:
project_id: Project to consolidate
agent_type_id: Optional filter by agent type
since: Only process episodes since this time
Returns:
ConsolidationResult with procedure statistics
"""
start_time = datetime.now(UTC)
result = ConsolidationResult(
source_type="episodic",
target_type="procedural",
)
try:
episodic = await self._get_episodic()
procedural = await self._get_procedural()
# Get successful episodes
since_time = since or datetime.now(UTC) - timedelta(days=7)
episodes = await episodic.get_by_outcome(
project_id,
outcome=Outcome.SUCCESS,
limit=self._config.batch_size,
agent_instance_id=None, # Get all agent instances
)
# Group by task type
task_groups: dict[str, list[Episode]] = {}
for episode in episodes:
if episode.occurred_at >= since_time:
if episode.task_type not in task_groups:
task_groups[episode.task_type] = []
task_groups[episode.task_type].append(episode)
result.items_processed = len(episodes)
# Process each task type group
for task_type, group in task_groups.items():
if len(group) < self._config.min_episodes_for_procedure:
result.items_skipped += len(group)
continue
try:
procedure_result = await self._learn_procedure_from_episodes(
procedural,
project_id,
agent_type_id,
task_type,
group,
)
if procedure_result == "created":
result.items_created += 1
elif procedure_result == "updated":
result.items_updated += 1
else:
result.items_skipped += 1
except Exception as e:
result.errors.append(f"Task type '{task_type}': {e}")
logger.warning(f"Failed to learn procedure for '{task_type}': {e}")
except Exception as e:
result.errors.append(f"Consolidation failed: {e}")
logger.exception("Failed episodic -> procedural consolidation")
result.duration_seconds = (datetime.now(UTC) - start_time).total_seconds()
logger.info(
f"Episodic -> Procedural consolidation: "
f"{result.items_processed} processed, "
f"{result.items_created} created, "
f"{result.items_updated} updated"
)
return result
async def _learn_procedure_from_episodes(
self,
procedural: ProceduralMemory,
project_id: UUID,
agent_type_id: UUID | None,
task_type: str,
episodes: list[Episode],
) -> str:
"""Learn or update a procedure from a set of episodes."""
# Calculate success rate for this pattern
success_count = sum(1 for e in episodes if e.outcome == Outcome.SUCCESS)
total_count = len(episodes)
success_rate = success_count / total_count if total_count > 0 else 0
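# The caller currently passes only SUCCESS episodes (via get_by_outcome), so
# this rate is 1.0 in practice; the threshold check below stays defensive in
# case mixed-outcome groups are ever passed in.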
if success_rate < self._config.min_success_rate_for_procedure:
return "skipped"
# Extract common steps from episodes
steps = self._extract_common_steps(episodes)
if len(steps) < self._config.min_steps_for_procedure:
return "skipped"
# Check for existing procedure
matching = await procedural.find_matching(
context=task_type,
project_id=project_id,
agent_type_id=agent_type_id,
limit=1,
)
if matching:
# Update existing procedure with new success
await procedural.record_outcome(
matching[0].id,
success=True,
)
return "updated"
else:
# Create new procedure
# Note: success_count starts at 1 in record_procedure
procedure_data = ProcedureCreate(
project_id=project_id,
agent_type_id=agent_type_id,
name=f"Procedure for {task_type}",
trigger_pattern=task_type,
steps=steps,
)
await procedural.record_procedure(procedure_data)
return "created"
def _extract_common_steps(self, episodes: list[Episode]) -> list[dict[str, Any]]:
"""Extract common action steps from multiple episodes."""
# Simple heuristic: take the steps from the most successful episode
# with the most detailed actions
best_episode = max(
episodes,
key=lambda e: (
e.outcome == Outcome.SUCCESS,
len(e.actions),
e.importance_score,
),
)
steps: list[dict[str, Any]] = []
for i, action in enumerate(best_episode.actions):
step = {
"order": i + 1,
"action": action.get("type", "action"),
"description": action.get("content", str(action))[:500],
"parameters": action,
}
steps.append(step)
return steps
# =========================================================================
# Memory Pruning
# =========================================================================
async def prune_old_episodes(
self,
project_id: UUID,
max_age_days: int | None = None,
min_importance: float | None = None,
) -> ConsolidationResult:
"""
Prune old, low-value episodes.
Args:
project_id: Project to prune
max_age_days: Maximum age in days (default from config)
min_importance: Minimum importance to keep (default from config)
Returns:
ConsolidationResult with pruning statistics
"""
start_time = datetime.now(UTC)
result = ConsolidationResult(
source_type="episodic",
target_type="pruned",
)
max_age = max_age_days if max_age_days is not None else self._config.max_episode_age_days
min_imp = min_importance if min_importance is not None else self._config.min_importance_to_keep
cutoff_date = datetime.now(UTC) - timedelta(days=max_age)
try:
episodic = await self._get_episodic()
# Get old episodes
# Note: In production, this would use a more efficient query
all_episodes = await episodic.get_recent(
project_id,
limit=self._config.batch_size * 10,
since=cutoff_date - timedelta(days=365), # Search past year
)
for episode in all_episodes:
result.items_processed += 1
# Check if should be pruned
if not self._should_prune_episode(episode, cutoff_date, min_imp):
result.items_skipped += 1
continue
try:
deleted = await episodic.delete(episode.id)
if deleted:
result.items_pruned += 1
else:
result.items_skipped += 1
except Exception as e:
result.errors.append(f"Episode {episode.id}: {e}")
except Exception as e:
result.errors.append(f"Pruning failed: {e}")
logger.exception("Failed episode pruning")
result.duration_seconds = (datetime.now(UTC) - start_time).total_seconds()
logger.info(
f"Episode pruning: {result.items_processed} processed, "
f"{result.items_pruned} pruned"
)
return result
def _should_prune_episode(
self,
episode: Episode,
cutoff_date: datetime,
min_importance: float,
) -> bool:
"""Determine if an episode should be pruned."""
# Keep recent episodes
if episode.occurred_at >= cutoff_date:
return False
# Keep failures if configured
if self._config.keep_all_failures and episode.outcome == Outcome.FAILURE:
return False
# Keep episodes with lessons if configured
if self._config.keep_all_with_lessons and episode.lessons_learned:
return False
# Keep high-importance episodes
if episode.importance_score >= min_importance:
return False
return True
# =========================================================================
# Nightly Consolidation
# =========================================================================
async def run_nightly_consolidation(
self,
project_id: UUID,
agent_type_id: UUID | None = None,
) -> NightlyConsolidationResult:
"""
Run full nightly consolidation workflow.
This includes:
1. Extract facts from recent episodes
2. Learn procedures from successful patterns
3. Prune old, low-value memories
Args:
project_id: Project to consolidate
agent_type_id: Optional agent type filter
Returns:
NightlyConsolidationResult with all outcomes
"""
result = NightlyConsolidationResult(started_at=datetime.now(UTC))
logger.info(f"Starting nightly consolidation for project {project_id}")
try:
# Step 1: Episodic -> Semantic (last 24 hours)
since_yesterday = datetime.now(UTC) - timedelta(days=1)
result.episodic_to_semantic = await self.consolidate_episodes_to_facts(
project_id=project_id,
since=since_yesterday,
)
result.total_facts_created = result.episodic_to_semantic.items_created
# Step 2: Episodic -> Procedural (last 7 days)
since_week = datetime.now(UTC) - timedelta(days=7)
result.episodic_to_procedural = (
await self.consolidate_episodes_to_procedures(
project_id=project_id,
agent_type_id=agent_type_id,
since=since_week,
)
)
result.total_procedures_created = (
result.episodic_to_procedural.items_created
)
# Step 3: Prune old memories
result.pruning = await self.prune_old_episodes(project_id=project_id)
result.total_pruned = result.pruning.items_pruned
# Calculate totals
result.total_episodes_processed = (
result.episodic_to_semantic.items_processed
if result.episodic_to_semantic
else 0
) + (
result.episodic_to_procedural.items_processed
if result.episodic_to_procedural
else 0
)
# Collect all errors
if result.episodic_to_semantic and result.episodic_to_semantic.errors:
result.errors.extend(result.episodic_to_semantic.errors)
if result.episodic_to_procedural and result.episodic_to_procedural.errors:
result.errors.extend(result.episodic_to_procedural.errors)
if result.pruning and result.pruning.errors:
result.errors.extend(result.pruning.errors)
except Exception as e:
result.errors.append(f"Nightly consolidation failed: {e}")
logger.exception("Nightly consolidation failed")
result.completed_at = datetime.now(UTC)
duration = (result.completed_at - result.started_at).total_seconds()
logger.info(
f"Nightly consolidation completed in {duration:.1f}s: "
f"{result.total_facts_created} facts, "
f"{result.total_procedures_created} procedures, "
f"{result.total_pruned} pruned"
)
return result
# Singleton instance
_consolidation_service: MemoryConsolidationService | None = None
async def get_consolidation_service(
session: AsyncSession,
config: ConsolidationConfig | None = None,
) -> MemoryConsolidationService:
"""
Get or create the memory consolidation service.
Args:
session: Database session
config: Optional configuration
Returns:
MemoryConsolidationService instance
"""
global _consolidation_service
if _consolidation_service is None:
_consolidation_service = MemoryConsolidationService(
session=session, config=config
)
return _consolidation_service
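# ---------------------------------------------------------------------------
# Usage sketch (illustrative only)
# ---------------------------------------------------------------------------
# A minimal sketch of how background jobs might drive this service. The
# session factory name (`async_session_factory`) is an assumption standing in
# for the project's real AsyncSession provider; everything else uses the API
# defined above.
#
#   async def nightly_consolidation(project_id: UUID) -> dict[str, Any]:
#       async with async_session_factory() as session:  # hypothetical factory
#           service = await get_consolidation_service(session)
#           result = await service.run_nightly_consolidation(project_id)
#           return result.to_dict()
#
#   async def on_session_end(
#       working_memory: WorkingMemory, project_id: UUID, session_id: str
#   ) -> SessionConsolidationResult:
#       async with async_session_factory() as session:  # hypothetical factory
#           service = MemoryConsolidationService(session)
#           return await service.consolidate_session(
#               working_memory, project_id, session_id
#           )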