# app/services/memory/consolidation/service.py
"""
Memory Consolidation Service.

Transfers and extracts knowledge between memory tiers:
- Working -> Episodic (session end)
- Episodic -> Semantic (learn facts)
- Episodic -> Procedural (learn procedures)

Also handles memory pruning and importance-based retention.
"""

import logging
from dataclasses import dataclass, field
from datetime import UTC, datetime, timedelta
from typing import Any
from uuid import UUID

from sqlalchemy.ext.asyncio import AsyncSession

from app.services.memory.episodic.memory import EpisodicMemory
from app.services.memory.procedural.memory import ProceduralMemory
from app.services.memory.semantic.extraction import FactExtractor, get_fact_extractor
from app.services.memory.semantic.memory import SemanticMemory
from app.services.memory.types import (
    Episode,
    EpisodeCreate,
    Outcome,
    ProcedureCreate,
    TaskState,
)
from app.services.memory.working.memory import WorkingMemory

logger = logging.getLogger(__name__)


@dataclass
class ConsolidationConfig:
    """Configuration for memory consolidation."""

    # Working -> Episodic thresholds
    min_steps_for_episode: int = 2
    min_duration_seconds: float = 5.0

    # Episodic -> Semantic thresholds
    min_confidence_for_fact: float = 0.6
    max_facts_per_episode: int = 10
    reinforce_existing_facts: bool = True

    # Episodic -> Procedural thresholds
    min_episodes_for_procedure: int = 3
    min_success_rate_for_procedure: float = 0.7
    min_steps_for_procedure: int = 2

    # Pruning thresholds
    max_episode_age_days: int = 90
    min_importance_to_keep: float = 0.2
    keep_all_failures: bool = True
    keep_all_with_lessons: bool = True

    # Batch sizes
    batch_size: int = 100


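# Example (illustrative sketch, not part of the service): tightening the
# defaults above before constructing the service. The field names are the ones
# defined on ConsolidationConfig; the values shown are arbitrary.
#
#     config = ConsolidationConfig(
#         min_confidence_for_fact=0.75,
#         max_episode_age_days=30,
#         keep_all_failures=True,
#     )
#     service = MemoryConsolidationService(session, config=config)

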
@dataclass
class ConsolidationResult:
    """Result of a consolidation operation."""

    source_type: str
    target_type: str
    items_processed: int = 0
    items_created: int = 0
    items_updated: int = 0
    items_skipped: int = 0
    items_pruned: int = 0
    errors: list[str] = field(default_factory=list)
    duration_seconds: float = 0.0

    def to_dict(self) -> dict[str, Any]:
        """Convert to dictionary."""
        return {
            "source_type": self.source_type,
            "target_type": self.target_type,
            "items_processed": self.items_processed,
            "items_created": self.items_created,
            "items_updated": self.items_updated,
            "items_skipped": self.items_skipped,
            "items_pruned": self.items_pruned,
            "errors": self.errors,
            "duration_seconds": self.duration_seconds,
        }


@dataclass
class SessionConsolidationResult:
    """Result of consolidating a session's working memory to episodic."""

    session_id: str
    episode_created: bool = False
    episode_id: UUID | None = None
    scratchpad_entries: int = 0
    variables_captured: int = 0
    error: str | None = None


@dataclass
class NightlyConsolidationResult:
    """Result of a nightly consolidation run."""

    started_at: datetime
    completed_at: datetime | None = None
    episodic_to_semantic: ConsolidationResult | None = None
    episodic_to_procedural: ConsolidationResult | None = None
    pruning: ConsolidationResult | None = None
    total_episodes_processed: int = 0
    total_facts_created: int = 0
    total_procedures_created: int = 0
    total_pruned: int = 0
    errors: list[str] = field(default_factory=list)

    def to_dict(self) -> dict[str, Any]:
        """Convert to dictionary."""
        return {
            "started_at": self.started_at.isoformat(),
            "completed_at": (
                self.completed_at.isoformat() if self.completed_at else None
            ),
            "episodic_to_semantic": (
                self.episodic_to_semantic.to_dict()
                if self.episodic_to_semantic
                else None
            ),
            "episodic_to_procedural": (
                self.episodic_to_procedural.to_dict()
                if self.episodic_to_procedural
                else None
            ),
            "pruning": self.pruning.to_dict() if self.pruning else None,
            "total_episodes_processed": self.total_episodes_processed,
            "total_facts_created": self.total_facts_created,
            "total_procedures_created": self.total_procedures_created,
            "total_pruned": self.total_pruned,
            "errors": self.errors,
        }


class MemoryConsolidationService:
    """
    Service for consolidating memories between tiers.

    Responsibilities:
    - Transfer working memory to episodic at session end
    - Extract facts from episodes to semantic memory
    - Learn procedures from successful episode patterns
    - Prune old/low-value memories
    """

    def __init__(
        self,
        session: AsyncSession,
        config: ConsolidationConfig | None = None,
        embedding_generator: Any | None = None,
    ) -> None:
        """
        Initialize consolidation service.

        Args:
            session: Database session
            config: Consolidation configuration
            embedding_generator: Optional embedding generator
        """
        self._session = session
        self._config = config or ConsolidationConfig()
        self._embedding_generator = embedding_generator
        self._fact_extractor: FactExtractor = get_fact_extractor()

        # Memory services (lazy initialized)
        self._episodic: EpisodicMemory | None = None
        self._semantic: SemanticMemory | None = None
        self._procedural: ProceduralMemory | None = None

    async def _get_episodic(self) -> EpisodicMemory:
        """Get or create episodic memory service."""
        if self._episodic is None:
            self._episodic = await EpisodicMemory.create(
                self._session, self._embedding_generator
            )
        return self._episodic

    async def _get_semantic(self) -> SemanticMemory:
        """Get or create semantic memory service."""
        if self._semantic is None:
            self._semantic = await SemanticMemory.create(
                self._session, self._embedding_generator
            )
        return self._semantic

    async def _get_procedural(self) -> ProceduralMemory:
        """Get or create procedural memory service."""
        if self._procedural is None:
            self._procedural = await ProceduralMemory.create(
                self._session, self._embedding_generator
            )
        return self._procedural

    # =========================================================================
    # Working -> Episodic Consolidation
    # =========================================================================

    async def consolidate_session(
        self,
        working_memory: WorkingMemory,
        project_id: UUID,
        session_id: str,
        task_type: str = "session_task",
        agent_instance_id: UUID | None = None,
        agent_type_id: UUID | None = None,
    ) -> SessionConsolidationResult:
        """
        Consolidate a session's working memory to episodic memory.

        Called at session end to transfer relevant session data
        into a persistent episode.

        Args:
            working_memory: The session's working memory
            project_id: Project ID
            session_id: Session ID
            task_type: Type of task performed
            agent_instance_id: Optional agent instance
            agent_type_id: Optional agent type

        Returns:
            SessionConsolidationResult with outcome details
        """
        result = SessionConsolidationResult(session_id=session_id)

        try:
            # Get task state
            task_state = await working_memory.get_task_state()

            # Check if there's enough content to consolidate
            if not self._should_consolidate_session(task_state):
                logger.debug(
                    f"Skipping consolidation for session {session_id}: insufficient content"
                )
                return result

            # Gather scratchpad entries
            scratchpad = await working_memory.get_scratchpad()
            result.scratchpad_entries = len(scratchpad)

            # Gather user variables
            all_data = await working_memory.get_all()
            result.variables_captured = len(all_data)

            # Determine outcome
            outcome = self._determine_session_outcome(task_state)

            # Build actions from scratchpad and variables
            actions = self._build_actions_from_session(
                scratchpad, all_data, task_state
            )

            # Build context summary
            context_summary = self._build_context_summary(task_state, all_data)

            # Extract lessons learned
            lessons = self._extract_session_lessons(task_state, outcome)

            # Calculate importance
            importance = self._calculate_session_importance(
                task_state, outcome, actions
            )

            # Create episode
            episode_data = EpisodeCreate(
                project_id=project_id,
                session_id=session_id,
                task_type=task_type,
                task_description=(
                    task_state.description if task_state else "Session task"
                ),
                actions=actions,
                context_summary=context_summary,
                outcome=outcome,
                outcome_details=task_state.status if task_state else "",
                duration_seconds=self._calculate_duration(task_state),
                tokens_used=0,  # Would need to track this in working memory
                lessons_learned=lessons,
                importance_score=importance,
                agent_instance_id=agent_instance_id,
                agent_type_id=agent_type_id,
            )

            episodic = await self._get_episodic()
            episode = await episodic.record_episode(episode_data)

            result.episode_created = True
            result.episode_id = episode.id

            logger.info(
                f"Consolidated session {session_id} to episode {episode.id} "
                f"({len(actions)} actions, outcome={outcome.value})"
            )

        except Exception as e:
            result.error = str(e)
            logger.exception(f"Failed to consolidate session {session_id}")

        return result

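    # Example (illustrative sketch): how a session-end hook might call
    # consolidate_session. The surrounding task and the way the WorkingMemory
    # instance is obtained are hypothetical and live outside this module; only
    # the call signature below matches this service.
    #
    #     service = MemoryConsolidationService(session)
    #     outcome = await service.consolidate_session(
    #         working_memory=working_memory,  # the session's WorkingMemory
    #         project_id=project_id,
    #         session_id=session_id,
    #         task_type="session_task",
    #     )
    #     if outcome.episode_created:
    #         logger.info(f"Session stored as episode {outcome.episode_id}")
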
    def _should_consolidate_session(self, task_state: TaskState | None) -> bool:
        """Check if session has enough content to consolidate."""
        if task_state is None:
            return False

        # Check minimum steps
        if task_state.current_step < self._config.min_steps_for_episode:
            return False

        return True

    def _determine_session_outcome(self, task_state: TaskState | None) -> Outcome:
        """Determine outcome from task state."""
        if task_state is None:
            return Outcome.PARTIAL

        status = task_state.status.lower() if task_state.status else ""
        progress = task_state.progress_percent

        if "success" in status or "complete" in status or progress >= 100:
            return Outcome.SUCCESS
        if "fail" in status or "error" in status:
            return Outcome.FAILURE
        if progress >= 50:
            return Outcome.PARTIAL
        return Outcome.FAILURE

    def _build_actions_from_session(
        self,
        scratchpad: list[str],
        variables: dict[str, Any],
        task_state: TaskState | None,
    ) -> list[dict[str, Any]]:
        """Build action list from session data."""
        actions: list[dict[str, Any]] = []

        # Add scratchpad entries as actions
        for i, entry in enumerate(scratchpad):
            actions.append(
                {
                    "step": i + 1,
                    "type": "reasoning",
                    "content": entry[:500],  # Truncate long entries
                }
            )

        # Add final state
        if task_state:
            actions.append(
                {
                    "step": len(scratchpad) + 1,
                    "type": "final_state",
                    "current_step": task_state.current_step,
                    "total_steps": task_state.total_steps,
                    "progress": task_state.progress_percent,
                    "status": task_state.status,
                }
            )

        return actions

    def _build_context_summary(
        self,
        task_state: TaskState | None,
        variables: dict[str, Any],
    ) -> str:
        """Build context summary from session data."""
        parts = []

        if task_state:
            parts.append(f"Task: {task_state.description}")
            parts.append(f"Progress: {task_state.progress_percent:.1f}%")
            parts.append(f"Steps: {task_state.current_step}/{task_state.total_steps}")

        # Include key variables
        key_vars = {k: v for k, v in variables.items() if len(str(v)) < 100}
        if key_vars:
            var_str = ", ".join(f"{k}={v}" for k, v in list(key_vars.items())[:5])
            parts.append(f"Variables: {var_str}")

        return "; ".join(parts) if parts else "Session completed"

    def _extract_session_lessons(
        self,
        task_state: TaskState | None,
        outcome: Outcome,
    ) -> list[str]:
        """Extract lessons from session."""
        lessons: list[str] = []

        if task_state and task_state.status:
            if outcome == Outcome.FAILURE:
                lessons.append(
                    f"Task failed at step {task_state.current_step}: {task_state.status}"
                )
            elif outcome == Outcome.SUCCESS:
                lessons.append(
                    f"Successfully completed in {task_state.current_step} steps"
                )

        return lessons

    def _calculate_session_importance(
        self,
        task_state: TaskState | None,
        outcome: Outcome,
        actions: list[dict[str, Any]],
    ) -> float:
        """Calculate importance score for session."""
        score = 0.5  # Base score

        # Failures are important to learn from
        if outcome == Outcome.FAILURE:
            score += 0.3

        # Many steps means complex task
        if task_state and task_state.total_steps >= 5:
            score += 0.1

        # Many actions means detailed reasoning
        if len(actions) >= 5:
            score += 0.1

        return min(1.0, score)

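    # Worked example for the score above (illustrative, not executed): a failed
    # session with total_steps=6 and 7 recorded actions scores
    # 0.5 (base) + 0.3 (failure) + 0.1 (>= 5 steps) + 0.1 (>= 5 actions) = 1.0,
    # which min(1.0, score) leaves at the cap; a short successful session with
    # few actions keeps the 0.5 base score.
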
    def _calculate_duration(self, task_state: TaskState | None) -> float:
        """Calculate session duration."""
        if task_state is None:
            return 0.0

        if task_state.started_at and task_state.updated_at:
            delta = task_state.updated_at - task_state.started_at
            return delta.total_seconds()

        return 0.0

    # =========================================================================
    # Episodic -> Semantic Consolidation
    # =========================================================================

    async def consolidate_episodes_to_facts(
        self,
        project_id: UUID,
        since: datetime | None = None,
        limit: int | None = None,
    ) -> ConsolidationResult:
        """
        Extract facts from episodic memories to semantic memory.

        Args:
            project_id: Project to consolidate
            since: Only process episodes since this time
            limit: Maximum episodes to process

        Returns:
            ConsolidationResult with extraction statistics
        """
        start_time = datetime.now(UTC)
        result = ConsolidationResult(
            source_type="episodic",
            target_type="semantic",
        )

        try:
            episodic = await self._get_episodic()
            semantic = await self._get_semantic()

            # Get episodes to process
            since_time = since or (datetime.now(UTC) - timedelta(days=1))
            episodes = await episodic.get_recent(
                project_id,
                limit=limit or self._config.batch_size,
                since=since_time,
            )

            for episode in episodes:
                result.items_processed += 1

                try:
                    # Extract facts using the extractor
                    extracted_facts = self._fact_extractor.extract_from_episode(episode)

                    for extracted_fact in extracted_facts:
                        if (
                            extracted_fact.confidence
                            < self._config.min_confidence_for_fact
                        ):
                            result.items_skipped += 1
                            continue

                        # Create fact (store_fact handles deduplication/reinforcement)
                        fact_create = extracted_fact.to_fact_create(
                            project_id=project_id,
                            source_episode_ids=[episode.id],
                        )

                        # store_fact automatically reinforces if fact already exists
                        fact = await semantic.store_fact(fact_create)

                        # Check if this was a new fact or reinforced an existing one
                        if fact.reinforcement_count == 1:
                            result.items_created += 1
                        else:
                            result.items_updated += 1

                except Exception as e:
                    result.errors.append(f"Episode {episode.id}: {e}")
                    logger.warning(
                        f"Failed to extract facts from episode {episode.id}: {e}"
                    )

        except Exception as e:
            result.errors.append(f"Consolidation failed: {e}")
            logger.exception("Failed episodic -> semantic consolidation")

        result.duration_seconds = (datetime.now(UTC) - start_time).total_seconds()

        logger.info(
            f"Episodic -> Semantic consolidation: "
            f"{result.items_processed} processed, "
            f"{result.items_created} created, "
            f"{result.items_updated} reinforced"
        )

        return result

    # =========================================================================
    # Episodic -> Procedural Consolidation
    # =========================================================================

    async def consolidate_episodes_to_procedures(
        self,
        project_id: UUID,
        agent_type_id: UUID | None = None,
        since: datetime | None = None,
    ) -> ConsolidationResult:
        """
        Learn procedures from patterns in episodic memories.

        Identifies recurring successful patterns and creates/updates
        procedures to capture them.

        Args:
            project_id: Project to consolidate
            agent_type_id: Optional filter by agent type
            since: Only process episodes since this time

        Returns:
            ConsolidationResult with procedure statistics
        """
        start_time = datetime.now(UTC)
        result = ConsolidationResult(
            source_type="episodic",
            target_type="procedural",
        )

        try:
            episodic = await self._get_episodic()
            procedural = await self._get_procedural()

            # Get successful episodes
            since_time = since or (datetime.now(UTC) - timedelta(days=7))
            episodes = await episodic.get_by_outcome(
                project_id,
                outcome=Outcome.SUCCESS,
                limit=self._config.batch_size,
                agent_instance_id=None,  # Get all agent instances
            )

            # Group by task type
            task_groups: dict[str, list[Episode]] = {}
            for episode in episodes:
                if episode.occurred_at >= since_time:
                    if episode.task_type not in task_groups:
                        task_groups[episode.task_type] = []
                    task_groups[episode.task_type].append(episode)

            result.items_processed = len(episodes)

            # Process each task type group
            for task_type, group in task_groups.items():
                if len(group) < self._config.min_episodes_for_procedure:
                    result.items_skipped += len(group)
                    continue

                try:
                    procedure_result = await self._learn_procedure_from_episodes(
                        procedural,
                        project_id,
                        agent_type_id,
                        task_type,
                        group,
                    )

                    if procedure_result == "created":
                        result.items_created += 1
                    elif procedure_result == "updated":
                        result.items_updated += 1
                    else:
                        result.items_skipped += 1

                except Exception as e:
                    result.errors.append(f"Task type '{task_type}': {e}")
                    logger.warning(f"Failed to learn procedure for '{task_type}': {e}")

        except Exception as e:
            result.errors.append(f"Consolidation failed: {e}")
            logger.exception("Failed episodic -> procedural consolidation")

        result.duration_seconds = (datetime.now(UTC) - start_time).total_seconds()

        logger.info(
            f"Episodic -> Procedural consolidation: "
            f"{result.items_processed} processed, "
            f"{result.items_created} created, "
            f"{result.items_updated} updated"
        )

        return result

    async def _learn_procedure_from_episodes(
        self,
        procedural: ProceduralMemory,
        project_id: UUID,
        agent_type_id: UUID | None,
        task_type: str,
        episodes: list[Episode],
    ) -> str:
        """Learn or update a procedure from a set of episodes."""
        # Calculate success rate for this pattern
        success_count = sum(1 for e in episodes if e.outcome == Outcome.SUCCESS)
        total_count = len(episodes)
        success_rate = success_count / total_count if total_count > 0 else 0

        if success_rate < self._config.min_success_rate_for_procedure:
            return "skipped"

        # Extract common steps from episodes
        steps = self._extract_common_steps(episodes)

        if len(steps) < self._config.min_steps_for_procedure:
            return "skipped"

        # Check for existing procedure
        matching = await procedural.find_matching(
            context=task_type,
            project_id=project_id,
            agent_type_id=agent_type_id,
            limit=1,
        )

        if matching:
            # Update existing procedure with new success
            await procedural.record_outcome(
                matching[0].id,
                success=True,
            )
            return "updated"
        else:
            # Create new procedure
            # Note: success_count starts at 1 in record_procedure
            procedure_data = ProcedureCreate(
                project_id=project_id,
                agent_type_id=agent_type_id,
                name=f"Procedure for {task_type}",
                trigger_pattern=task_type,
                steps=steps,
            )
            await procedural.record_procedure(procedure_data)
            return "created"

    def _extract_common_steps(self, episodes: list[Episode]) -> list[dict[str, Any]]:
        """Extract common action steps from multiple episodes."""
        # Simple heuristic: take the steps from the most successful episode
        # with the most detailed actions

        best_episode = max(
            episodes,
            key=lambda e: (
                e.outcome == Outcome.SUCCESS,
                len(e.actions),
                e.importance_score,
            ),
        )

        steps: list[dict[str, Any]] = []
        for i, action in enumerate(best_episode.actions):
            step = {
                "order": i + 1,
                "action": action.get("type", "action"),
                "description": action.get("content", str(action))[:500],
                "parameters": action,
            }
            steps.append(step)

        return steps

    # =========================================================================
    # Memory Pruning
    # =========================================================================

    async def prune_old_episodes(
        self,
        project_id: UUID,
        max_age_days: int | None = None,
        min_importance: float | None = None,
    ) -> ConsolidationResult:
        """
        Prune old, low-value episodes.

        Args:
            project_id: Project to prune
            max_age_days: Maximum age in days (default from config)
            min_importance: Minimum importance to keep (default from config)

        Returns:
            ConsolidationResult with pruning statistics
        """
        start_time = datetime.now(UTC)
        result = ConsolidationResult(
            source_type="episodic",
            target_type="pruned",
        )

        # Explicit None checks so callers can pass 0 / 0.0 without falling back
        # to the configured defaults
        max_age = (
            max_age_days
            if max_age_days is not None
            else self._config.max_episode_age_days
        )
        min_imp = (
            min_importance
            if min_importance is not None
            else self._config.min_importance_to_keep
        )
        cutoff_date = datetime.now(UTC) - timedelta(days=max_age)

        try:
            episodic = await self._get_episodic()

            # Get old episodes
            # Note: In production, this would use a more efficient query
            all_episodes = await episodic.get_recent(
                project_id,
                limit=self._config.batch_size * 10,
                since=cutoff_date - timedelta(days=365),  # Search past year
            )

            for episode in all_episodes:
                result.items_processed += 1

                # Check if should be pruned
                if not self._should_prune_episode(episode, cutoff_date, min_imp):
                    result.items_skipped += 1
                    continue

                try:
                    deleted = await episodic.delete(episode.id)
                    if deleted:
                        result.items_pruned += 1
                    else:
                        result.items_skipped += 1
                except Exception as e:
                    result.errors.append(f"Episode {episode.id}: {e}")

        except Exception as e:
            result.errors.append(f"Pruning failed: {e}")
            logger.exception("Failed episode pruning")

        result.duration_seconds = (datetime.now(UTC) - start_time).total_seconds()

        logger.info(
            f"Episode pruning: {result.items_processed} processed, "
            f"{result.items_pruned} pruned"
        )

        return result

    def _should_prune_episode(
        self,
        episode: Episode,
        cutoff_date: datetime,
        min_importance: float,
    ) -> bool:
        """Determine if an episode should be pruned."""
        # Keep recent episodes
        if episode.occurred_at >= cutoff_date:
            return False

        # Keep failures if configured
        if self._config.keep_all_failures and episode.outcome == Outcome.FAILURE:
            return False

        # Keep episodes with lessons if configured
        if self._config.keep_all_with_lessons and episode.lessons_learned:
            return False

        # Keep high-importance episodes
        if episode.importance_score >= min_importance:
            return False

        return True

    # =========================================================================
    # Nightly Consolidation
    # =========================================================================

    async def run_nightly_consolidation(
        self,
        project_id: UUID,
        agent_type_id: UUID | None = None,
    ) -> NightlyConsolidationResult:
        """
        Run full nightly consolidation workflow.

        This includes:
        1. Extract facts from recent episodes
        2. Learn procedures from successful patterns
        3. Prune old, low-value memories

        Args:
            project_id: Project to consolidate
            agent_type_id: Optional agent type filter

        Returns:
            NightlyConsolidationResult with all outcomes
        """
        result = NightlyConsolidationResult(started_at=datetime.now(UTC))

        logger.info(f"Starting nightly consolidation for project {project_id}")

        try:
            # Step 1: Episodic -> Semantic (last 24 hours)
            since_yesterday = datetime.now(UTC) - timedelta(days=1)
            result.episodic_to_semantic = await self.consolidate_episodes_to_facts(
                project_id=project_id,
                since=since_yesterday,
            )
            result.total_facts_created = result.episodic_to_semantic.items_created

            # Step 2: Episodic -> Procedural (last 7 days)
            since_week = datetime.now(UTC) - timedelta(days=7)
            result.episodic_to_procedural = (
                await self.consolidate_episodes_to_procedures(
                    project_id=project_id,
                    agent_type_id=agent_type_id,
                    since=since_week,
                )
            )
            result.total_procedures_created = (
                result.episodic_to_procedural.items_created
            )

            # Step 3: Prune old memories
            result.pruning = await self.prune_old_episodes(project_id=project_id)
            result.total_pruned = result.pruning.items_pruned

            # Calculate totals
            result.total_episodes_processed = (
                result.episodic_to_semantic.items_processed
                if result.episodic_to_semantic
                else 0
            ) + (
                result.episodic_to_procedural.items_processed
                if result.episodic_to_procedural
                else 0
            )

            # Collect all errors
            if result.episodic_to_semantic and result.episodic_to_semantic.errors:
                result.errors.extend(result.episodic_to_semantic.errors)
            if result.episodic_to_procedural and result.episodic_to_procedural.errors:
                result.errors.extend(result.episodic_to_procedural.errors)
            if result.pruning and result.pruning.errors:
                result.errors.extend(result.pruning.errors)

        except Exception as e:
            result.errors.append(f"Nightly consolidation failed: {e}")
            logger.exception("Nightly consolidation failed")

        result.completed_at = datetime.now(UTC)

        duration = (result.completed_at - result.started_at).total_seconds()
        logger.info(
            f"Nightly consolidation completed in {duration:.1f}s: "
            f"{result.total_facts_created} facts, "
            f"{result.total_procedures_created} procedures, "
            f"{result.total_pruned} pruned"
        )

        return result


# Singleton instance
_consolidation_service: MemoryConsolidationService | None = None


async def get_consolidation_service(
    session: AsyncSession,
    config: ConsolidationConfig | None = None,
) -> MemoryConsolidationService:
    """
    Get or create the memory consolidation service.

    Args:
        session: Database session
        config: Optional configuration

    Returns:
        MemoryConsolidationService instance
    """
    global _consolidation_service
    if _consolidation_service is None:
        _consolidation_service = MemoryConsolidationService(
            session=session, config=config
        )
    return _consolidation_service
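

# Example (illustrative sketch, not part of this module): driving the nightly
# workflow from an async context. The scheduled tasks that call this live
# elsewhere in the codebase; the session factory name used here is
# hypothetical.
#
#     async def nightly_job(project_id: UUID) -> dict[str, Any]:
#         async with async_session_factory() as session:  # hypothetical factory
#             service = await get_consolidation_service(session)
#             result = await service.run_nightly_consolidation(project_id)
#             return result.to_dict()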