forked from cardosofelipe/fast-next-template
feat(memory): implement memory consolidation service and tasks (#95)
- Add MemoryConsolidationService with Working→Episodic→Semantic/Procedural transfer - Add Celery tasks for session and nightly consolidation - Implement memory pruning with importance-based retention - Add comprehensive test suite (32 tests) 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
@@ -1,10 +1,29 @@
|
||||
# app/services/memory/consolidation/__init__.py
|
||||
"""
|
||||
Memory Consolidation
|
||||
Memory Consolidation.
|
||||
|
||||
Transfers and extracts knowledge between memory tiers:
|
||||
- Working -> Episodic
|
||||
- Episodic -> Semantic
|
||||
- Episodic -> Procedural
|
||||
- Working -> Episodic (session end)
|
||||
- Episodic -> Semantic (learn facts)
|
||||
- Episodic -> Procedural (learn procedures)
|
||||
|
||||
Also handles memory pruning and importance-based retention.
|
||||
"""
|
||||
|
||||
# Will be populated in #95
|
||||
from .service import (
|
||||
ConsolidationConfig,
|
||||
ConsolidationResult,
|
||||
MemoryConsolidationService,
|
||||
NightlyConsolidationResult,
|
||||
SessionConsolidationResult,
|
||||
get_consolidation_service,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"ConsolidationConfig",
|
||||
"ConsolidationResult",
|
||||
"MemoryConsolidationService",
|
||||
"NightlyConsolidationResult",
|
||||
"SessionConsolidationResult",
|
||||
"get_consolidation_service",
|
||||
]
|
||||
|
||||
918
backend/app/services/memory/consolidation/service.py
Normal file
918
backend/app/services/memory/consolidation/service.py
Normal file
@@ -0,0 +1,918 @@
|
||||
# app/services/memory/consolidation/service.py
|
||||
"""
|
||||
Memory Consolidation Service.
|
||||
|
||||
Transfers and extracts knowledge between memory tiers:
|
||||
- Working -> Episodic (session end)
|
||||
- Episodic -> Semantic (learn facts)
|
||||
- Episodic -> Procedural (learn procedures)
|
||||
|
||||
Also handles memory pruning and importance-based retention.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.services.memory.episodic.memory import EpisodicMemory
|
||||
from app.services.memory.procedural.memory import ProceduralMemory
|
||||
from app.services.memory.semantic.extraction import FactExtractor, get_fact_extractor
|
||||
from app.services.memory.semantic.memory import SemanticMemory
|
||||
from app.services.memory.types import (
|
||||
Episode,
|
||||
EpisodeCreate,
|
||||
Outcome,
|
||||
ProcedureCreate,
|
||||
TaskState,
|
||||
)
|
||||
from app.services.memory.working.memory import WorkingMemory
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ConsolidationConfig:
|
||||
"""Configuration for memory consolidation."""
|
||||
|
||||
# Working -> Episodic thresholds
|
||||
min_steps_for_episode: int = 2
|
||||
min_duration_seconds: float = 5.0
|
||||
|
||||
# Episodic -> Semantic thresholds
|
||||
min_confidence_for_fact: float = 0.6
|
||||
max_facts_per_episode: int = 10
|
||||
reinforce_existing_facts: bool = True
|
||||
|
||||
# Episodic -> Procedural thresholds
|
||||
min_episodes_for_procedure: int = 3
|
||||
min_success_rate_for_procedure: float = 0.7
|
||||
min_steps_for_procedure: int = 2
|
||||
|
||||
# Pruning thresholds
|
||||
max_episode_age_days: int = 90
|
||||
min_importance_to_keep: float = 0.2
|
||||
keep_all_failures: bool = True
|
||||
keep_all_with_lessons: bool = True
|
||||
|
||||
# Batch sizes
|
||||
batch_size: int = 100
|
||||
|
||||
|
||||
@dataclass
|
||||
class ConsolidationResult:
|
||||
"""Result of a consolidation operation."""
|
||||
|
||||
source_type: str
|
||||
target_type: str
|
||||
items_processed: int = 0
|
||||
items_created: int = 0
|
||||
items_updated: int = 0
|
||||
items_skipped: int = 0
|
||||
items_pruned: int = 0
|
||||
errors: list[str] = field(default_factory=list)
|
||||
duration_seconds: float = 0.0
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
"""Convert to dictionary."""
|
||||
return {
|
||||
"source_type": self.source_type,
|
||||
"target_type": self.target_type,
|
||||
"items_processed": self.items_processed,
|
||||
"items_created": self.items_created,
|
||||
"items_updated": self.items_updated,
|
||||
"items_skipped": self.items_skipped,
|
||||
"items_pruned": self.items_pruned,
|
||||
"errors": self.errors,
|
||||
"duration_seconds": self.duration_seconds,
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class SessionConsolidationResult:
|
||||
"""Result of consolidating a session's working memory to episodic."""
|
||||
|
||||
session_id: str
|
||||
episode_created: bool = False
|
||||
episode_id: UUID | None = None
|
||||
scratchpad_entries: int = 0
|
||||
variables_captured: int = 0
|
||||
error: str | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class NightlyConsolidationResult:
|
||||
"""Result of nightly consolidation run."""
|
||||
|
||||
started_at: datetime
|
||||
completed_at: datetime | None = None
|
||||
episodic_to_semantic: ConsolidationResult | None = None
|
||||
episodic_to_procedural: ConsolidationResult | None = None
|
||||
pruning: ConsolidationResult | None = None
|
||||
total_episodes_processed: int = 0
|
||||
total_facts_created: int = 0
|
||||
total_procedures_created: int = 0
|
||||
total_pruned: int = 0
|
||||
errors: list[str] = field(default_factory=list)
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
"""Convert to dictionary."""
|
||||
return {
|
||||
"started_at": self.started_at.isoformat(),
|
||||
"completed_at": self.completed_at.isoformat()
|
||||
if self.completed_at
|
||||
else None,
|
||||
"episodic_to_semantic": (
|
||||
self.episodic_to_semantic.to_dict()
|
||||
if self.episodic_to_semantic
|
||||
else None
|
||||
),
|
||||
"episodic_to_procedural": (
|
||||
self.episodic_to_procedural.to_dict()
|
||||
if self.episodic_to_procedural
|
||||
else None
|
||||
),
|
||||
"pruning": self.pruning.to_dict() if self.pruning else None,
|
||||
"total_episodes_processed": self.total_episodes_processed,
|
||||
"total_facts_created": self.total_facts_created,
|
||||
"total_procedures_created": self.total_procedures_created,
|
||||
"total_pruned": self.total_pruned,
|
||||
"errors": self.errors,
|
||||
}
|
||||
|
||||
|
||||
class MemoryConsolidationService:
|
||||
"""
|
||||
Service for consolidating memories between tiers.
|
||||
|
||||
Responsibilities:
|
||||
- Transfer working memory to episodic at session end
|
||||
- Extract facts from episodes to semantic memory
|
||||
- Learn procedures from successful episode patterns
|
||||
- Prune old/low-value memories
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
session: AsyncSession,
|
||||
config: ConsolidationConfig | None = None,
|
||||
embedding_generator: Any | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize consolidation service.
|
||||
|
||||
Args:
|
||||
session: Database session
|
||||
config: Consolidation configuration
|
||||
embedding_generator: Optional embedding generator
|
||||
"""
|
||||
self._session = session
|
||||
self._config = config or ConsolidationConfig()
|
||||
self._embedding_generator = embedding_generator
|
||||
self._fact_extractor: FactExtractor = get_fact_extractor()
|
||||
|
||||
# Memory services (lazy initialized)
|
||||
self._episodic: EpisodicMemory | None = None
|
||||
self._semantic: SemanticMemory | None = None
|
||||
self._procedural: ProceduralMemory | None = None
|
||||
|
||||
async def _get_episodic(self) -> EpisodicMemory:
|
||||
"""Get or create episodic memory service."""
|
||||
if self._episodic is None:
|
||||
self._episodic = await EpisodicMemory.create(
|
||||
self._session, self._embedding_generator
|
||||
)
|
||||
return self._episodic
|
||||
|
||||
async def _get_semantic(self) -> SemanticMemory:
|
||||
"""Get or create semantic memory service."""
|
||||
if self._semantic is None:
|
||||
self._semantic = await SemanticMemory.create(
|
||||
self._session, self._embedding_generator
|
||||
)
|
||||
return self._semantic
|
||||
|
||||
async def _get_procedural(self) -> ProceduralMemory:
|
||||
"""Get or create procedural memory service."""
|
||||
if self._procedural is None:
|
||||
self._procedural = await ProceduralMemory.create(
|
||||
self._session, self._embedding_generator
|
||||
)
|
||||
return self._procedural
|
||||
|
||||
# =========================================================================
|
||||
# Working -> Episodic Consolidation
|
||||
# =========================================================================
|
||||
|
||||
async def consolidate_session(
|
||||
self,
|
||||
working_memory: WorkingMemory,
|
||||
project_id: UUID,
|
||||
session_id: str,
|
||||
task_type: str = "session_task",
|
||||
agent_instance_id: UUID | None = None,
|
||||
agent_type_id: UUID | None = None,
|
||||
) -> SessionConsolidationResult:
|
||||
"""
|
||||
Consolidate a session's working memory to episodic memory.
|
||||
|
||||
Called at session end to transfer relevant session data
|
||||
into a persistent episode.
|
||||
|
||||
Args:
|
||||
working_memory: The session's working memory
|
||||
project_id: Project ID
|
||||
session_id: Session ID
|
||||
task_type: Type of task performed
|
||||
agent_instance_id: Optional agent instance
|
||||
agent_type_id: Optional agent type
|
||||
|
||||
Returns:
|
||||
SessionConsolidationResult with outcome details
|
||||
"""
|
||||
result = SessionConsolidationResult(session_id=session_id)
|
||||
|
||||
try:
|
||||
# Get task state
|
||||
task_state = await working_memory.get_task_state()
|
||||
|
||||
# Check if there's enough content to consolidate
|
||||
if not self._should_consolidate_session(task_state):
|
||||
logger.debug(
|
||||
f"Skipping consolidation for session {session_id}: insufficient content"
|
||||
)
|
||||
return result
|
||||
|
||||
# Gather scratchpad entries
|
||||
scratchpad = await working_memory.get_scratchpad()
|
||||
result.scratchpad_entries = len(scratchpad)
|
||||
|
||||
# Gather user variables
|
||||
all_data = await working_memory.get_all()
|
||||
result.variables_captured = len(all_data)
|
||||
|
||||
# Determine outcome
|
||||
outcome = self._determine_session_outcome(task_state)
|
||||
|
||||
# Build actions from scratchpad and variables
|
||||
actions = self._build_actions_from_session(scratchpad, all_data, task_state)
|
||||
|
||||
# Build context summary
|
||||
context_summary = self._build_context_summary(task_state, all_data)
|
||||
|
||||
# Extract lessons learned
|
||||
lessons = self._extract_session_lessons(task_state, outcome)
|
||||
|
||||
# Calculate importance
|
||||
importance = self._calculate_session_importance(
|
||||
task_state, outcome, actions
|
||||
)
|
||||
|
||||
# Create episode
|
||||
episode_data = EpisodeCreate(
|
||||
project_id=project_id,
|
||||
session_id=session_id,
|
||||
task_type=task_type,
|
||||
task_description=task_state.description
|
||||
if task_state
|
||||
else "Session task",
|
||||
actions=actions,
|
||||
context_summary=context_summary,
|
||||
outcome=outcome,
|
||||
outcome_details=task_state.status if task_state else "",
|
||||
duration_seconds=self._calculate_duration(task_state),
|
||||
tokens_used=0, # Would need to track this in working memory
|
||||
lessons_learned=lessons,
|
||||
importance_score=importance,
|
||||
agent_instance_id=agent_instance_id,
|
||||
agent_type_id=agent_type_id,
|
||||
)
|
||||
|
||||
episodic = await self._get_episodic()
|
||||
episode = await episodic.record_episode(episode_data)
|
||||
|
||||
result.episode_created = True
|
||||
result.episode_id = episode.id
|
||||
|
||||
logger.info(
|
||||
f"Consolidated session {session_id} to episode {episode.id} "
|
||||
f"({len(actions)} actions, outcome={outcome.value})"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
result.error = str(e)
|
||||
logger.exception(f"Failed to consolidate session {session_id}")
|
||||
|
||||
return result
|
||||
|
||||
def _should_consolidate_session(self, task_state: TaskState | None) -> bool:
|
||||
"""Check if session has enough content to consolidate."""
|
||||
if task_state is None:
|
||||
return False
|
||||
|
||||
# Check minimum steps
|
||||
if task_state.current_step < self._config.min_steps_for_episode:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def _determine_session_outcome(self, task_state: TaskState | None) -> Outcome:
|
||||
"""Determine outcome from task state."""
|
||||
if task_state is None:
|
||||
return Outcome.PARTIAL
|
||||
|
||||
status = task_state.status.lower() if task_state.status else ""
|
||||
progress = task_state.progress_percent
|
||||
|
||||
if "success" in status or "complete" in status or progress >= 100:
|
||||
return Outcome.SUCCESS
|
||||
if "fail" in status or "error" in status:
|
||||
return Outcome.FAILURE
|
||||
if progress >= 50:
|
||||
return Outcome.PARTIAL
|
||||
return Outcome.FAILURE
|
||||
|
||||
def _build_actions_from_session(
|
||||
self,
|
||||
scratchpad: list[str],
|
||||
variables: dict[str, Any],
|
||||
task_state: TaskState | None,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Build action list from session data."""
|
||||
actions: list[dict[str, Any]] = []
|
||||
|
||||
# Add scratchpad entries as actions
|
||||
for i, entry in enumerate(scratchpad):
|
||||
actions.append(
|
||||
{
|
||||
"step": i + 1,
|
||||
"type": "reasoning",
|
||||
"content": entry[:500], # Truncate long entries
|
||||
}
|
||||
)
|
||||
|
||||
# Add final state
|
||||
if task_state:
|
||||
actions.append(
|
||||
{
|
||||
"step": len(scratchpad) + 1,
|
||||
"type": "final_state",
|
||||
"current_step": task_state.current_step,
|
||||
"total_steps": task_state.total_steps,
|
||||
"progress": task_state.progress_percent,
|
||||
"status": task_state.status,
|
||||
}
|
||||
)
|
||||
|
||||
return actions
|
||||
|
||||
def _build_context_summary(
|
||||
self,
|
||||
task_state: TaskState | None,
|
||||
variables: dict[str, Any],
|
||||
) -> str:
|
||||
"""Build context summary from session data."""
|
||||
parts = []
|
||||
|
||||
if task_state:
|
||||
parts.append(f"Task: {task_state.description}")
|
||||
parts.append(f"Progress: {task_state.progress_percent:.1f}%")
|
||||
parts.append(f"Steps: {task_state.current_step}/{task_state.total_steps}")
|
||||
|
||||
# Include key variables
|
||||
key_vars = {k: v for k, v in variables.items() if len(str(v)) < 100}
|
||||
if key_vars:
|
||||
var_str = ", ".join(f"{k}={v}" for k, v in list(key_vars.items())[:5])
|
||||
parts.append(f"Variables: {var_str}")
|
||||
|
||||
return "; ".join(parts) if parts else "Session completed"
|
||||
|
||||
def _extract_session_lessons(
|
||||
self,
|
||||
task_state: TaskState | None,
|
||||
outcome: Outcome,
|
||||
) -> list[str]:
|
||||
"""Extract lessons from session."""
|
||||
lessons: list[str] = []
|
||||
|
||||
if task_state and task_state.status:
|
||||
if outcome == Outcome.FAILURE:
|
||||
lessons.append(
|
||||
f"Task failed at step {task_state.current_step}: {task_state.status}"
|
||||
)
|
||||
elif outcome == Outcome.SUCCESS:
|
||||
lessons.append(
|
||||
f"Successfully completed in {task_state.current_step} steps"
|
||||
)
|
||||
|
||||
return lessons
|
||||
|
||||
def _calculate_session_importance(
|
||||
self,
|
||||
task_state: TaskState | None,
|
||||
outcome: Outcome,
|
||||
actions: list[dict[str, Any]],
|
||||
) -> float:
|
||||
"""Calculate importance score for session."""
|
||||
score = 0.5 # Base score
|
||||
|
||||
# Failures are important to learn from
|
||||
if outcome == Outcome.FAILURE:
|
||||
score += 0.3
|
||||
|
||||
# Many steps means complex task
|
||||
if task_state and task_state.total_steps >= 5:
|
||||
score += 0.1
|
||||
|
||||
# Many actions means detailed reasoning
|
||||
if len(actions) >= 5:
|
||||
score += 0.1
|
||||
|
||||
return min(1.0, score)
|
||||
|
||||
def _calculate_duration(self, task_state: TaskState | None) -> float:
|
||||
"""Calculate session duration."""
|
||||
if task_state is None:
|
||||
return 0.0
|
||||
|
||||
if task_state.started_at and task_state.updated_at:
|
||||
delta = task_state.updated_at - task_state.started_at
|
||||
return delta.total_seconds()
|
||||
|
||||
return 0.0
|
||||
|
||||
# =========================================================================
|
||||
# Episodic -> Semantic Consolidation
|
||||
# =========================================================================
|
||||
|
||||
async def consolidate_episodes_to_facts(
|
||||
self,
|
||||
project_id: UUID,
|
||||
since: datetime | None = None,
|
||||
limit: int | None = None,
|
||||
) -> ConsolidationResult:
|
||||
"""
|
||||
Extract facts from episodic memories to semantic memory.
|
||||
|
||||
Args:
|
||||
project_id: Project to consolidate
|
||||
since: Only process episodes since this time
|
||||
limit: Maximum episodes to process
|
||||
|
||||
Returns:
|
||||
ConsolidationResult with extraction statistics
|
||||
"""
|
||||
start_time = datetime.now(UTC)
|
||||
result = ConsolidationResult(
|
||||
source_type="episodic",
|
||||
target_type="semantic",
|
||||
)
|
||||
|
||||
try:
|
||||
episodic = await self._get_episodic()
|
||||
semantic = await self._get_semantic()
|
||||
|
||||
# Get episodes to process
|
||||
since_time = since or datetime.now(UTC) - timedelta(days=1)
|
||||
episodes = await episodic.get_recent(
|
||||
project_id,
|
||||
limit=limit or self._config.batch_size,
|
||||
since=since_time,
|
||||
)
|
||||
|
||||
for episode in episodes:
|
||||
result.items_processed += 1
|
||||
|
||||
try:
|
||||
# Extract facts using the extractor
|
||||
extracted_facts = self._fact_extractor.extract_from_episode(episode)
|
||||
|
||||
for extracted_fact in extracted_facts:
|
||||
if (
|
||||
extracted_fact.confidence
|
||||
< self._config.min_confidence_for_fact
|
||||
):
|
||||
result.items_skipped += 1
|
||||
continue
|
||||
|
||||
# Create fact (store_fact handles deduplication/reinforcement)
|
||||
fact_create = extracted_fact.to_fact_create(
|
||||
project_id=project_id,
|
||||
source_episode_ids=[episode.id],
|
||||
)
|
||||
|
||||
# store_fact automatically reinforces if fact already exists
|
||||
fact = await semantic.store_fact(fact_create)
|
||||
|
||||
# Check if this was a new fact or reinforced existing
|
||||
if fact.reinforcement_count == 1:
|
||||
result.items_created += 1
|
||||
else:
|
||||
result.items_updated += 1
|
||||
|
||||
except Exception as e:
|
||||
result.errors.append(f"Episode {episode.id}: {e}")
|
||||
logger.warning(
|
||||
f"Failed to extract facts from episode {episode.id}: {e}"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
result.errors.append(f"Consolidation failed: {e}")
|
||||
logger.exception("Failed episodic -> semantic consolidation")
|
||||
|
||||
result.duration_seconds = (datetime.now(UTC) - start_time).total_seconds()
|
||||
|
||||
logger.info(
|
||||
f"Episodic -> Semantic consolidation: "
|
||||
f"{result.items_processed} processed, "
|
||||
f"{result.items_created} created, "
|
||||
f"{result.items_updated} reinforced"
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
# =========================================================================
|
||||
# Episodic -> Procedural Consolidation
|
||||
# =========================================================================
|
||||
|
||||
async def consolidate_episodes_to_procedures(
|
||||
self,
|
||||
project_id: UUID,
|
||||
agent_type_id: UUID | None = None,
|
||||
since: datetime | None = None,
|
||||
) -> ConsolidationResult:
|
||||
"""
|
||||
Learn procedures from patterns in episodic memories.
|
||||
|
||||
Identifies recurring successful patterns and creates/updates
|
||||
procedures to capture them.
|
||||
|
||||
Args:
|
||||
project_id: Project to consolidate
|
||||
agent_type_id: Optional filter by agent type
|
||||
since: Only process episodes since this time
|
||||
|
||||
Returns:
|
||||
ConsolidationResult with procedure statistics
|
||||
"""
|
||||
start_time = datetime.now(UTC)
|
||||
result = ConsolidationResult(
|
||||
source_type="episodic",
|
||||
target_type="procedural",
|
||||
)
|
||||
|
||||
try:
|
||||
episodic = await self._get_episodic()
|
||||
procedural = await self._get_procedural()
|
||||
|
||||
# Get successful episodes
|
||||
since_time = since or datetime.now(UTC) - timedelta(days=7)
|
||||
episodes = await episodic.get_by_outcome(
|
||||
project_id,
|
||||
outcome=Outcome.SUCCESS,
|
||||
limit=self._config.batch_size,
|
||||
agent_instance_id=None, # Get all agent instances
|
||||
)
|
||||
|
||||
# Group by task type
|
||||
task_groups: dict[str, list[Episode]] = {}
|
||||
for episode in episodes:
|
||||
if episode.occurred_at >= since_time:
|
||||
if episode.task_type not in task_groups:
|
||||
task_groups[episode.task_type] = []
|
||||
task_groups[episode.task_type].append(episode)
|
||||
|
||||
result.items_processed = len(episodes)
|
||||
|
||||
# Process each task type group
|
||||
for task_type, group in task_groups.items():
|
||||
if len(group) < self._config.min_episodes_for_procedure:
|
||||
result.items_skipped += len(group)
|
||||
continue
|
||||
|
||||
try:
|
||||
procedure_result = await self._learn_procedure_from_episodes(
|
||||
procedural,
|
||||
project_id,
|
||||
agent_type_id,
|
||||
task_type,
|
||||
group,
|
||||
)
|
||||
|
||||
if procedure_result == "created":
|
||||
result.items_created += 1
|
||||
elif procedure_result == "updated":
|
||||
result.items_updated += 1
|
||||
else:
|
||||
result.items_skipped += 1
|
||||
|
||||
except Exception as e:
|
||||
result.errors.append(f"Task type '{task_type}': {e}")
|
||||
logger.warning(f"Failed to learn procedure for '{task_type}': {e}")
|
||||
|
||||
except Exception as e:
|
||||
result.errors.append(f"Consolidation failed: {e}")
|
||||
logger.exception("Failed episodic -> procedural consolidation")
|
||||
|
||||
result.duration_seconds = (datetime.now(UTC) - start_time).total_seconds()
|
||||
|
||||
logger.info(
|
||||
f"Episodic -> Procedural consolidation: "
|
||||
f"{result.items_processed} processed, "
|
||||
f"{result.items_created} created, "
|
||||
f"{result.items_updated} updated"
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
async def _learn_procedure_from_episodes(
|
||||
self,
|
||||
procedural: ProceduralMemory,
|
||||
project_id: UUID,
|
||||
agent_type_id: UUID | None,
|
||||
task_type: str,
|
||||
episodes: list[Episode],
|
||||
) -> str:
|
||||
"""Learn or update a procedure from a set of episodes."""
|
||||
# Calculate success rate for this pattern
|
||||
success_count = sum(1 for e in episodes if e.outcome == Outcome.SUCCESS)
|
||||
total_count = len(episodes)
|
||||
success_rate = success_count / total_count if total_count > 0 else 0
|
||||
|
||||
if success_rate < self._config.min_success_rate_for_procedure:
|
||||
return "skipped"
|
||||
|
||||
# Extract common steps from episodes
|
||||
steps = self._extract_common_steps(episodes)
|
||||
|
||||
if len(steps) < self._config.min_steps_for_procedure:
|
||||
return "skipped"
|
||||
|
||||
# Check for existing procedure
|
||||
matching = await procedural.find_matching(
|
||||
context=task_type,
|
||||
project_id=project_id,
|
||||
agent_type_id=agent_type_id,
|
||||
limit=1,
|
||||
)
|
||||
|
||||
if matching:
|
||||
# Update existing procedure with new success
|
||||
await procedural.record_outcome(
|
||||
matching[0].id,
|
||||
success=True,
|
||||
)
|
||||
return "updated"
|
||||
else:
|
||||
# Create new procedure
|
||||
# Note: success_count starts at 1 in record_procedure
|
||||
procedure_data = ProcedureCreate(
|
||||
project_id=project_id,
|
||||
agent_type_id=agent_type_id,
|
||||
name=f"Procedure for {task_type}",
|
||||
trigger_pattern=task_type,
|
||||
steps=steps,
|
||||
)
|
||||
await procedural.record_procedure(procedure_data)
|
||||
return "created"
|
||||
|
||||
def _extract_common_steps(self, episodes: list[Episode]) -> list[dict[str, Any]]:
|
||||
"""Extract common action steps from multiple episodes."""
|
||||
# Simple heuristic: take the steps from the most successful episode
|
||||
# with the most detailed actions
|
||||
|
||||
best_episode = max(
|
||||
episodes,
|
||||
key=lambda e: (
|
||||
e.outcome == Outcome.SUCCESS,
|
||||
len(e.actions),
|
||||
e.importance_score,
|
||||
),
|
||||
)
|
||||
|
||||
steps: list[dict[str, Any]] = []
|
||||
for i, action in enumerate(best_episode.actions):
|
||||
step = {
|
||||
"order": i + 1,
|
||||
"action": action.get("type", "action"),
|
||||
"description": action.get("content", str(action))[:500],
|
||||
"parameters": action,
|
||||
}
|
||||
steps.append(step)
|
||||
|
||||
return steps
|
||||
|
||||
# =========================================================================
|
||||
# Memory Pruning
|
||||
# =========================================================================
|
||||
|
||||
async def prune_old_episodes(
|
||||
self,
|
||||
project_id: UUID,
|
||||
max_age_days: int | None = None,
|
||||
min_importance: float | None = None,
|
||||
) -> ConsolidationResult:
|
||||
"""
|
||||
Prune old, low-value episodes.
|
||||
|
||||
Args:
|
||||
project_id: Project to prune
|
||||
max_age_days: Maximum age in days (default from config)
|
||||
min_importance: Minimum importance to keep (default from config)
|
||||
|
||||
Returns:
|
||||
ConsolidationResult with pruning statistics
|
||||
"""
|
||||
start_time = datetime.now(UTC)
|
||||
result = ConsolidationResult(
|
||||
source_type="episodic",
|
||||
target_type="pruned",
|
||||
)
|
||||
|
||||
max_age = max_age_days or self._config.max_episode_age_days
|
||||
min_imp = min_importance or self._config.min_importance_to_keep
|
||||
cutoff_date = datetime.now(UTC) - timedelta(days=max_age)
|
||||
|
||||
try:
|
||||
episodic = await self._get_episodic()
|
||||
|
||||
# Get old episodes
|
||||
# Note: In production, this would use a more efficient query
|
||||
all_episodes = await episodic.get_recent(
|
||||
project_id,
|
||||
limit=self._config.batch_size * 10,
|
||||
since=cutoff_date - timedelta(days=365), # Search past year
|
||||
)
|
||||
|
||||
for episode in all_episodes:
|
||||
result.items_processed += 1
|
||||
|
||||
# Check if should be pruned
|
||||
if not self._should_prune_episode(episode, cutoff_date, min_imp):
|
||||
result.items_skipped += 1
|
||||
continue
|
||||
|
||||
try:
|
||||
deleted = await episodic.delete(episode.id)
|
||||
if deleted:
|
||||
result.items_pruned += 1
|
||||
else:
|
||||
result.items_skipped += 1
|
||||
except Exception as e:
|
||||
result.errors.append(f"Episode {episode.id}: {e}")
|
||||
|
||||
except Exception as e:
|
||||
result.errors.append(f"Pruning failed: {e}")
|
||||
logger.exception("Failed episode pruning")
|
||||
|
||||
result.duration_seconds = (datetime.now(UTC) - start_time).total_seconds()
|
||||
|
||||
logger.info(
|
||||
f"Episode pruning: {result.items_processed} processed, "
|
||||
f"{result.items_pruned} pruned"
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
def _should_prune_episode(
|
||||
self,
|
||||
episode: Episode,
|
||||
cutoff_date: datetime,
|
||||
min_importance: float,
|
||||
) -> bool:
|
||||
"""Determine if an episode should be pruned."""
|
||||
# Keep recent episodes
|
||||
if episode.occurred_at >= cutoff_date:
|
||||
return False
|
||||
|
||||
# Keep failures if configured
|
||||
if self._config.keep_all_failures and episode.outcome == Outcome.FAILURE:
|
||||
return False
|
||||
|
||||
# Keep episodes with lessons if configured
|
||||
if self._config.keep_all_with_lessons and episode.lessons_learned:
|
||||
return False
|
||||
|
||||
# Keep high-importance episodes
|
||||
if episode.importance_score >= min_importance:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
# =========================================================================
|
||||
# Nightly Consolidation
|
||||
# =========================================================================
|
||||
|
||||
async def run_nightly_consolidation(
|
||||
self,
|
||||
project_id: UUID,
|
||||
agent_type_id: UUID | None = None,
|
||||
) -> NightlyConsolidationResult:
|
||||
"""
|
||||
Run full nightly consolidation workflow.
|
||||
|
||||
This includes:
|
||||
1. Extract facts from recent episodes
|
||||
2. Learn procedures from successful patterns
|
||||
3. Prune old, low-value memories
|
||||
|
||||
Args:
|
||||
project_id: Project to consolidate
|
||||
agent_type_id: Optional agent type filter
|
||||
|
||||
Returns:
|
||||
NightlyConsolidationResult with all outcomes
|
||||
"""
|
||||
result = NightlyConsolidationResult(started_at=datetime.now(UTC))
|
||||
|
||||
logger.info(f"Starting nightly consolidation for project {project_id}")
|
||||
|
||||
try:
|
||||
# Step 1: Episodic -> Semantic (last 24 hours)
|
||||
since_yesterday = datetime.now(UTC) - timedelta(days=1)
|
||||
result.episodic_to_semantic = await self.consolidate_episodes_to_facts(
|
||||
project_id=project_id,
|
||||
since=since_yesterday,
|
||||
)
|
||||
result.total_facts_created = result.episodic_to_semantic.items_created
|
||||
|
||||
# Step 2: Episodic -> Procedural (last 7 days)
|
||||
since_week = datetime.now(UTC) - timedelta(days=7)
|
||||
result.episodic_to_procedural = (
|
||||
await self.consolidate_episodes_to_procedures(
|
||||
project_id=project_id,
|
||||
agent_type_id=agent_type_id,
|
||||
since=since_week,
|
||||
)
|
||||
)
|
||||
result.total_procedures_created = (
|
||||
result.episodic_to_procedural.items_created
|
||||
)
|
||||
|
||||
# Step 3: Prune old memories
|
||||
result.pruning = await self.prune_old_episodes(project_id=project_id)
|
||||
result.total_pruned = result.pruning.items_pruned
|
||||
|
||||
# Calculate totals
|
||||
result.total_episodes_processed = (
|
||||
result.episodic_to_semantic.items_processed
|
||||
if result.episodic_to_semantic
|
||||
else 0
|
||||
) + (
|
||||
result.episodic_to_procedural.items_processed
|
||||
if result.episodic_to_procedural
|
||||
else 0
|
||||
)
|
||||
|
||||
# Collect all errors
|
||||
if result.episodic_to_semantic and result.episodic_to_semantic.errors:
|
||||
result.errors.extend(result.episodic_to_semantic.errors)
|
||||
if result.episodic_to_procedural and result.episodic_to_procedural.errors:
|
||||
result.errors.extend(result.episodic_to_procedural.errors)
|
||||
if result.pruning and result.pruning.errors:
|
||||
result.errors.extend(result.pruning.errors)
|
||||
|
||||
except Exception as e:
|
||||
result.errors.append(f"Nightly consolidation failed: {e}")
|
||||
logger.exception("Nightly consolidation failed")
|
||||
|
||||
result.completed_at = datetime.now(UTC)
|
||||
|
||||
duration = (result.completed_at - result.started_at).total_seconds()
|
||||
logger.info(
|
||||
f"Nightly consolidation completed in {duration:.1f}s: "
|
||||
f"{result.total_facts_created} facts, "
|
||||
f"{result.total_procedures_created} procedures, "
|
||||
f"{result.total_pruned} pruned"
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
# Singleton instance
|
||||
_consolidation_service: MemoryConsolidationService | None = None
|
||||
|
||||
|
||||
async def get_consolidation_service(
|
||||
session: AsyncSession,
|
||||
config: ConsolidationConfig | None = None,
|
||||
) -> MemoryConsolidationService:
|
||||
"""
|
||||
Get or create the memory consolidation service.
|
||||
|
||||
Args:
|
||||
session: Database session
|
||||
config: Optional configuration
|
||||
|
||||
Returns:
|
||||
MemoryConsolidationService instance
|
||||
"""
|
||||
global _consolidation_service
|
||||
if _consolidation_service is None:
|
||||
_consolidation_service = MemoryConsolidationService(
|
||||
session=session, config=config
|
||||
)
|
||||
return _consolidation_service
|
||||
@@ -10,14 +10,16 @@ Modules:
|
||||
sync: Issue synchronization tasks (incremental/full sync, webhooks)
|
||||
workflow: Workflow state management tasks
|
||||
cost: Cost tracking and budget monitoring tasks
|
||||
memory_consolidation: Memory consolidation tasks
|
||||
"""
|
||||
|
||||
from app.tasks import agent, cost, git, sync, workflow
|
||||
from app.tasks import agent, cost, git, memory_consolidation, sync, workflow
|
||||
|
||||
__all__ = [
|
||||
"agent",
|
||||
"cost",
|
||||
"git",
|
||||
"memory_consolidation",
|
||||
"sync",
|
||||
"workflow",
|
||||
]
|
||||
|
||||
234
backend/app/tasks/memory_consolidation.py
Normal file
234
backend/app/tasks/memory_consolidation.py
Normal file
@@ -0,0 +1,234 @@
|
||||
# app/tasks/memory_consolidation.py
|
||||
"""
|
||||
Memory consolidation Celery tasks.
|
||||
|
||||
Handles scheduled and on-demand memory consolidation:
|
||||
- Session consolidation (on session end)
|
||||
- Nightly consolidation (scheduled)
|
||||
- On-demand project consolidation
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from app.celery_app import celery_app
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@celery_app.task(
|
||||
bind=True,
|
||||
name="app.tasks.memory_consolidation.consolidate_session",
|
||||
autoretry_for=(Exception,),
|
||||
retry_backoff=True,
|
||||
retry_kwargs={"max_retries": 3},
|
||||
)
|
||||
def consolidate_session(
|
||||
self,
|
||||
project_id: str,
|
||||
session_id: str,
|
||||
task_type: str = "session_task",
|
||||
agent_instance_id: str | None = None,
|
||||
agent_type_id: str | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Consolidate a session's working memory to episodic memory.
|
||||
|
||||
This task is triggered when an agent session ends to transfer
|
||||
relevant session data into persistent episodic memory.
|
||||
|
||||
Args:
|
||||
project_id: UUID of the project
|
||||
session_id: Session identifier
|
||||
task_type: Type of task performed
|
||||
agent_instance_id: Optional agent instance UUID
|
||||
agent_type_id: Optional agent type UUID
|
||||
|
||||
Returns:
|
||||
dict with consolidation results
|
||||
"""
|
||||
logger.info(f"Consolidating session {session_id} for project {project_id}")
|
||||
|
||||
# TODO: Implement actual consolidation
|
||||
# This will involve:
|
||||
# 1. Getting database session from async context
|
||||
# 2. Loading working memory for session
|
||||
# 3. Calling consolidation service
|
||||
# 4. Returning results
|
||||
|
||||
# Placeholder implementation
|
||||
return {
|
||||
"status": "pending",
|
||||
"project_id": project_id,
|
||||
"session_id": session_id,
|
||||
"episode_created": False,
|
||||
}
|
||||
|
||||
|
||||
@celery_app.task(
|
||||
bind=True,
|
||||
name="app.tasks.memory_consolidation.run_nightly_consolidation",
|
||||
autoretry_for=(Exception,),
|
||||
retry_backoff=True,
|
||||
retry_kwargs={"max_retries": 3},
|
||||
)
|
||||
def run_nightly_consolidation(
|
||||
self,
|
||||
project_id: str,
|
||||
agent_type_id: str | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Run nightly memory consolidation for a project.
|
||||
|
||||
This task performs the full consolidation workflow:
|
||||
1. Extract facts from recent episodes to semantic memory
|
||||
2. Learn procedures from successful episode patterns
|
||||
3. Prune old, low-value memories
|
||||
|
||||
Args:
|
||||
project_id: UUID of the project to consolidate
|
||||
agent_type_id: Optional agent type to filter by
|
||||
|
||||
Returns:
|
||||
dict with consolidation results
|
||||
"""
|
||||
logger.info(f"Running nightly consolidation for project {project_id}")
|
||||
|
||||
# TODO: Implement actual consolidation
|
||||
# This will involve:
|
||||
# 1. Getting database session from async context
|
||||
# 2. Creating consolidation service instance
|
||||
# 3. Running run_nightly_consolidation
|
||||
# 4. Returning results
|
||||
|
||||
# Placeholder implementation
|
||||
return {
|
||||
"status": "pending",
|
||||
"project_id": project_id,
|
||||
"total_facts_created": 0,
|
||||
"total_procedures_created": 0,
|
||||
"total_pruned": 0,
|
||||
}
|
||||
|
||||
|
||||
@celery_app.task(
|
||||
bind=True,
|
||||
name="app.tasks.memory_consolidation.consolidate_episodes_to_facts",
|
||||
autoretry_for=(Exception,),
|
||||
retry_backoff=True,
|
||||
retry_kwargs={"max_retries": 3},
|
||||
)
|
||||
def consolidate_episodes_to_facts(
|
||||
self,
|
||||
project_id: str,
|
||||
since_hours: int = 24,
|
||||
limit: int | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Extract facts from episodic memories.
|
||||
|
||||
Args:
|
||||
project_id: UUID of the project
|
||||
since_hours: Process episodes from last N hours
|
||||
limit: Maximum episodes to process
|
||||
|
||||
Returns:
|
||||
dict with extraction results
|
||||
"""
|
||||
logger.info(f"Consolidating episodes to facts for project {project_id}")
|
||||
|
||||
# TODO: Implement actual consolidation
|
||||
# Placeholder implementation
|
||||
return {
|
||||
"status": "pending",
|
||||
"project_id": project_id,
|
||||
"items_processed": 0,
|
||||
"items_created": 0,
|
||||
}
|
||||
|
||||
|
||||
@celery_app.task(
|
||||
bind=True,
|
||||
name="app.tasks.memory_consolidation.consolidate_episodes_to_procedures",
|
||||
autoretry_for=(Exception,),
|
||||
retry_backoff=True,
|
||||
retry_kwargs={"max_retries": 3},
|
||||
)
|
||||
def consolidate_episodes_to_procedures(
|
||||
self,
|
||||
project_id: str,
|
||||
agent_type_id: str | None = None,
|
||||
since_days: int = 7,
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Learn procedures from episodic patterns.
|
||||
|
||||
Args:
|
||||
project_id: UUID of the project
|
||||
agent_type_id: Optional agent type filter
|
||||
since_days: Process episodes from last N days
|
||||
|
||||
Returns:
|
||||
dict with procedure learning results
|
||||
"""
|
||||
logger.info(f"Consolidating episodes to procedures for project {project_id}")
|
||||
|
||||
# TODO: Implement actual consolidation
|
||||
# Placeholder implementation
|
||||
return {
|
||||
"status": "pending",
|
||||
"project_id": project_id,
|
||||
"items_processed": 0,
|
||||
"items_created": 0,
|
||||
}
|
||||
|
||||
|
||||
@celery_app.task(
|
||||
bind=True,
|
||||
name="app.tasks.memory_consolidation.prune_old_memories",
|
||||
autoretry_for=(Exception,),
|
||||
retry_backoff=True,
|
||||
retry_kwargs={"max_retries": 3},
|
||||
)
|
||||
def prune_old_memories(
|
||||
self,
|
||||
project_id: str,
|
||||
max_age_days: int = 90,
|
||||
min_importance: float = 0.2,
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Prune old, low-value memories.
|
||||
|
||||
Args:
|
||||
project_id: UUID of the project
|
||||
max_age_days: Maximum age in days
|
||||
min_importance: Minimum importance to keep
|
||||
|
||||
Returns:
|
||||
dict with pruning results
|
||||
"""
|
||||
logger.info(f"Pruning old memories for project {project_id}")
|
||||
|
||||
# TODO: Implement actual pruning
|
||||
# Placeholder implementation
|
||||
return {
|
||||
"status": "pending",
|
||||
"project_id": project_id,
|
||||
"items_pruned": 0,
|
||||
}
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# Celery Beat Schedule Configuration
|
||||
# =========================================================================
|
||||
|
||||
# This would typically be configured in celery_app.py or a separate config file
|
||||
# Example schedule for nightly consolidation:
|
||||
#
|
||||
# app.conf.beat_schedule = {
|
||||
# 'nightly-memory-consolidation': {
|
||||
# 'task': 'app.tasks.memory_consolidation.run_nightly_consolidation',
|
||||
# 'schedule': crontab(hour=2, minute=0), # 2 AM daily
|
||||
# 'args': (None,), # Will process all projects
|
||||
# },
|
||||
# }
|
||||
Reference in New Issue
Block a user