forked from cardosofelipe/fast-next-template
feat(memory): implement memory consolidation service and tasks (#95)
- Add MemoryConsolidationService with Working→Episodic→Semantic/Procedural transfer - Add Celery tasks for session and nightly consolidation - Implement memory pruning with importance-based retention - Add comprehensive test suite (32 tests) 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
@@ -1,10 +1,29 @@
|
||||
# app/services/memory/consolidation/__init__.py
|
||||
"""
|
||||
Memory Consolidation
|
||||
Memory Consolidation.
|
||||
|
||||
Transfers and extracts knowledge between memory tiers:
|
||||
- Working -> Episodic
|
||||
- Episodic -> Semantic
|
||||
- Episodic -> Procedural
|
||||
- Working -> Episodic (session end)
|
||||
- Episodic -> Semantic (learn facts)
|
||||
- Episodic -> Procedural (learn procedures)
|
||||
|
||||
Also handles memory pruning and importance-based retention.
|
||||
"""
|
||||
|
||||
# Will be populated in #95
|
||||
from .service import (
|
||||
ConsolidationConfig,
|
||||
ConsolidationResult,
|
||||
MemoryConsolidationService,
|
||||
NightlyConsolidationResult,
|
||||
SessionConsolidationResult,
|
||||
get_consolidation_service,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"ConsolidationConfig",
|
||||
"ConsolidationResult",
|
||||
"MemoryConsolidationService",
|
||||
"NightlyConsolidationResult",
|
||||
"SessionConsolidationResult",
|
||||
"get_consolidation_service",
|
||||
]
|
||||
|
||||
918
backend/app/services/memory/consolidation/service.py
Normal file
918
backend/app/services/memory/consolidation/service.py
Normal file
@@ -0,0 +1,918 @@
|
||||
# app/services/memory/consolidation/service.py
|
||||
"""
|
||||
Memory Consolidation Service.
|
||||
|
||||
Transfers and extracts knowledge between memory tiers:
|
||||
- Working -> Episodic (session end)
|
||||
- Episodic -> Semantic (learn facts)
|
||||
- Episodic -> Procedural (learn procedures)
|
||||
|
||||
Also handles memory pruning and importance-based retention.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.services.memory.episodic.memory import EpisodicMemory
|
||||
from app.services.memory.procedural.memory import ProceduralMemory
|
||||
from app.services.memory.semantic.extraction import FactExtractor, get_fact_extractor
|
||||
from app.services.memory.semantic.memory import SemanticMemory
|
||||
from app.services.memory.types import (
|
||||
Episode,
|
||||
EpisodeCreate,
|
||||
Outcome,
|
||||
ProcedureCreate,
|
||||
TaskState,
|
||||
)
|
||||
from app.services.memory.working.memory import WorkingMemory
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ConsolidationConfig:
|
||||
"""Configuration for memory consolidation."""
|
||||
|
||||
# Working -> Episodic thresholds
|
||||
min_steps_for_episode: int = 2
|
||||
min_duration_seconds: float = 5.0
|
||||
|
||||
# Episodic -> Semantic thresholds
|
||||
min_confidence_for_fact: float = 0.6
|
||||
max_facts_per_episode: int = 10
|
||||
reinforce_existing_facts: bool = True
|
||||
|
||||
# Episodic -> Procedural thresholds
|
||||
min_episodes_for_procedure: int = 3
|
||||
min_success_rate_for_procedure: float = 0.7
|
||||
min_steps_for_procedure: int = 2
|
||||
|
||||
# Pruning thresholds
|
||||
max_episode_age_days: int = 90
|
||||
min_importance_to_keep: float = 0.2
|
||||
keep_all_failures: bool = True
|
||||
keep_all_with_lessons: bool = True
|
||||
|
||||
# Batch sizes
|
||||
batch_size: int = 100
|
||||
|
||||
|
||||
@dataclass
|
||||
class ConsolidationResult:
|
||||
"""Result of a consolidation operation."""
|
||||
|
||||
source_type: str
|
||||
target_type: str
|
||||
items_processed: int = 0
|
||||
items_created: int = 0
|
||||
items_updated: int = 0
|
||||
items_skipped: int = 0
|
||||
items_pruned: int = 0
|
||||
errors: list[str] = field(default_factory=list)
|
||||
duration_seconds: float = 0.0
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
"""Convert to dictionary."""
|
||||
return {
|
||||
"source_type": self.source_type,
|
||||
"target_type": self.target_type,
|
||||
"items_processed": self.items_processed,
|
||||
"items_created": self.items_created,
|
||||
"items_updated": self.items_updated,
|
||||
"items_skipped": self.items_skipped,
|
||||
"items_pruned": self.items_pruned,
|
||||
"errors": self.errors,
|
||||
"duration_seconds": self.duration_seconds,
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class SessionConsolidationResult:
|
||||
"""Result of consolidating a session's working memory to episodic."""
|
||||
|
||||
session_id: str
|
||||
episode_created: bool = False
|
||||
episode_id: UUID | None = None
|
||||
scratchpad_entries: int = 0
|
||||
variables_captured: int = 0
|
||||
error: str | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class NightlyConsolidationResult:
|
||||
"""Result of nightly consolidation run."""
|
||||
|
||||
started_at: datetime
|
||||
completed_at: datetime | None = None
|
||||
episodic_to_semantic: ConsolidationResult | None = None
|
||||
episodic_to_procedural: ConsolidationResult | None = None
|
||||
pruning: ConsolidationResult | None = None
|
||||
total_episodes_processed: int = 0
|
||||
total_facts_created: int = 0
|
||||
total_procedures_created: int = 0
|
||||
total_pruned: int = 0
|
||||
errors: list[str] = field(default_factory=list)
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
"""Convert to dictionary."""
|
||||
return {
|
||||
"started_at": self.started_at.isoformat(),
|
||||
"completed_at": self.completed_at.isoformat()
|
||||
if self.completed_at
|
||||
else None,
|
||||
"episodic_to_semantic": (
|
||||
self.episodic_to_semantic.to_dict()
|
||||
if self.episodic_to_semantic
|
||||
else None
|
||||
),
|
||||
"episodic_to_procedural": (
|
||||
self.episodic_to_procedural.to_dict()
|
||||
if self.episodic_to_procedural
|
||||
else None
|
||||
),
|
||||
"pruning": self.pruning.to_dict() if self.pruning else None,
|
||||
"total_episodes_processed": self.total_episodes_processed,
|
||||
"total_facts_created": self.total_facts_created,
|
||||
"total_procedures_created": self.total_procedures_created,
|
||||
"total_pruned": self.total_pruned,
|
||||
"errors": self.errors,
|
||||
}
|
||||
|
||||
|
||||
class MemoryConsolidationService:
|
||||
"""
|
||||
Service for consolidating memories between tiers.
|
||||
|
||||
Responsibilities:
|
||||
- Transfer working memory to episodic at session end
|
||||
- Extract facts from episodes to semantic memory
|
||||
- Learn procedures from successful episode patterns
|
||||
- Prune old/low-value memories
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
session: AsyncSession,
|
||||
config: ConsolidationConfig | None = None,
|
||||
embedding_generator: Any | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize consolidation service.
|
||||
|
||||
Args:
|
||||
session: Database session
|
||||
config: Consolidation configuration
|
||||
embedding_generator: Optional embedding generator
|
||||
"""
|
||||
self._session = session
|
||||
self._config = config or ConsolidationConfig()
|
||||
self._embedding_generator = embedding_generator
|
||||
self._fact_extractor: FactExtractor = get_fact_extractor()
|
||||
|
||||
# Memory services (lazy initialized)
|
||||
self._episodic: EpisodicMemory | None = None
|
||||
self._semantic: SemanticMemory | None = None
|
||||
self._procedural: ProceduralMemory | None = None
|
||||
|
||||
async def _get_episodic(self) -> EpisodicMemory:
|
||||
"""Get or create episodic memory service."""
|
||||
if self._episodic is None:
|
||||
self._episodic = await EpisodicMemory.create(
|
||||
self._session, self._embedding_generator
|
||||
)
|
||||
return self._episodic
|
||||
|
||||
async def _get_semantic(self) -> SemanticMemory:
|
||||
"""Get or create semantic memory service."""
|
||||
if self._semantic is None:
|
||||
self._semantic = await SemanticMemory.create(
|
||||
self._session, self._embedding_generator
|
||||
)
|
||||
return self._semantic
|
||||
|
||||
async def _get_procedural(self) -> ProceduralMemory:
|
||||
"""Get or create procedural memory service."""
|
||||
if self._procedural is None:
|
||||
self._procedural = await ProceduralMemory.create(
|
||||
self._session, self._embedding_generator
|
||||
)
|
||||
return self._procedural
|
||||
|
||||
# =========================================================================
|
||||
# Working -> Episodic Consolidation
|
||||
# =========================================================================
|
||||
|
||||
async def consolidate_session(
|
||||
self,
|
||||
working_memory: WorkingMemory,
|
||||
project_id: UUID,
|
||||
session_id: str,
|
||||
task_type: str = "session_task",
|
||||
agent_instance_id: UUID | None = None,
|
||||
agent_type_id: UUID | None = None,
|
||||
) -> SessionConsolidationResult:
|
||||
"""
|
||||
Consolidate a session's working memory to episodic memory.
|
||||
|
||||
Called at session end to transfer relevant session data
|
||||
into a persistent episode.
|
||||
|
||||
Args:
|
||||
working_memory: The session's working memory
|
||||
project_id: Project ID
|
||||
session_id: Session ID
|
||||
task_type: Type of task performed
|
||||
agent_instance_id: Optional agent instance
|
||||
agent_type_id: Optional agent type
|
||||
|
||||
Returns:
|
||||
SessionConsolidationResult with outcome details
|
||||
"""
|
||||
result = SessionConsolidationResult(session_id=session_id)
|
||||
|
||||
try:
|
||||
# Get task state
|
||||
task_state = await working_memory.get_task_state()
|
||||
|
||||
# Check if there's enough content to consolidate
|
||||
if not self._should_consolidate_session(task_state):
|
||||
logger.debug(
|
||||
f"Skipping consolidation for session {session_id}: insufficient content"
|
||||
)
|
||||
return result
|
||||
|
||||
# Gather scratchpad entries
|
||||
scratchpad = await working_memory.get_scratchpad()
|
||||
result.scratchpad_entries = len(scratchpad)
|
||||
|
||||
# Gather user variables
|
||||
all_data = await working_memory.get_all()
|
||||
result.variables_captured = len(all_data)
|
||||
|
||||
# Determine outcome
|
||||
outcome = self._determine_session_outcome(task_state)
|
||||
|
||||
# Build actions from scratchpad and variables
|
||||
actions = self._build_actions_from_session(scratchpad, all_data, task_state)
|
||||
|
||||
# Build context summary
|
||||
context_summary = self._build_context_summary(task_state, all_data)
|
||||
|
||||
# Extract lessons learned
|
||||
lessons = self._extract_session_lessons(task_state, outcome)
|
||||
|
||||
# Calculate importance
|
||||
importance = self._calculate_session_importance(
|
||||
task_state, outcome, actions
|
||||
)
|
||||
|
||||
# Create episode
|
||||
episode_data = EpisodeCreate(
|
||||
project_id=project_id,
|
||||
session_id=session_id,
|
||||
task_type=task_type,
|
||||
task_description=task_state.description
|
||||
if task_state
|
||||
else "Session task",
|
||||
actions=actions,
|
||||
context_summary=context_summary,
|
||||
outcome=outcome,
|
||||
outcome_details=task_state.status if task_state else "",
|
||||
duration_seconds=self._calculate_duration(task_state),
|
||||
tokens_used=0, # Would need to track this in working memory
|
||||
lessons_learned=lessons,
|
||||
importance_score=importance,
|
||||
agent_instance_id=agent_instance_id,
|
||||
agent_type_id=agent_type_id,
|
||||
)
|
||||
|
||||
episodic = await self._get_episodic()
|
||||
episode = await episodic.record_episode(episode_data)
|
||||
|
||||
result.episode_created = True
|
||||
result.episode_id = episode.id
|
||||
|
||||
logger.info(
|
||||
f"Consolidated session {session_id} to episode {episode.id} "
|
||||
f"({len(actions)} actions, outcome={outcome.value})"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
result.error = str(e)
|
||||
logger.exception(f"Failed to consolidate session {session_id}")
|
||||
|
||||
return result
|
||||
|
||||
def _should_consolidate_session(self, task_state: TaskState | None) -> bool:
|
||||
"""Check if session has enough content to consolidate."""
|
||||
if task_state is None:
|
||||
return False
|
||||
|
||||
# Check minimum steps
|
||||
if task_state.current_step < self._config.min_steps_for_episode:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def _determine_session_outcome(self, task_state: TaskState | None) -> Outcome:
|
||||
"""Determine outcome from task state."""
|
||||
if task_state is None:
|
||||
return Outcome.PARTIAL
|
||||
|
||||
status = task_state.status.lower() if task_state.status else ""
|
||||
progress = task_state.progress_percent
|
||||
|
||||
if "success" in status or "complete" in status or progress >= 100:
|
||||
return Outcome.SUCCESS
|
||||
if "fail" in status or "error" in status:
|
||||
return Outcome.FAILURE
|
||||
if progress >= 50:
|
||||
return Outcome.PARTIAL
|
||||
return Outcome.FAILURE
|
||||
|
||||
def _build_actions_from_session(
|
||||
self,
|
||||
scratchpad: list[str],
|
||||
variables: dict[str, Any],
|
||||
task_state: TaskState | None,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Build action list from session data."""
|
||||
actions: list[dict[str, Any]] = []
|
||||
|
||||
# Add scratchpad entries as actions
|
||||
for i, entry in enumerate(scratchpad):
|
||||
actions.append(
|
||||
{
|
||||
"step": i + 1,
|
||||
"type": "reasoning",
|
||||
"content": entry[:500], # Truncate long entries
|
||||
}
|
||||
)
|
||||
|
||||
# Add final state
|
||||
if task_state:
|
||||
actions.append(
|
||||
{
|
||||
"step": len(scratchpad) + 1,
|
||||
"type": "final_state",
|
||||
"current_step": task_state.current_step,
|
||||
"total_steps": task_state.total_steps,
|
||||
"progress": task_state.progress_percent,
|
||||
"status": task_state.status,
|
||||
}
|
||||
)
|
||||
|
||||
return actions
|
||||
|
||||
def _build_context_summary(
|
||||
self,
|
||||
task_state: TaskState | None,
|
||||
variables: dict[str, Any],
|
||||
) -> str:
|
||||
"""Build context summary from session data."""
|
||||
parts = []
|
||||
|
||||
if task_state:
|
||||
parts.append(f"Task: {task_state.description}")
|
||||
parts.append(f"Progress: {task_state.progress_percent:.1f}%")
|
||||
parts.append(f"Steps: {task_state.current_step}/{task_state.total_steps}")
|
||||
|
||||
# Include key variables
|
||||
key_vars = {k: v for k, v in variables.items() if len(str(v)) < 100}
|
||||
if key_vars:
|
||||
var_str = ", ".join(f"{k}={v}" for k, v in list(key_vars.items())[:5])
|
||||
parts.append(f"Variables: {var_str}")
|
||||
|
||||
return "; ".join(parts) if parts else "Session completed"
|
||||
|
||||
def _extract_session_lessons(
|
||||
self,
|
||||
task_state: TaskState | None,
|
||||
outcome: Outcome,
|
||||
) -> list[str]:
|
||||
"""Extract lessons from session."""
|
||||
lessons: list[str] = []
|
||||
|
||||
if task_state and task_state.status:
|
||||
if outcome == Outcome.FAILURE:
|
||||
lessons.append(
|
||||
f"Task failed at step {task_state.current_step}: {task_state.status}"
|
||||
)
|
||||
elif outcome == Outcome.SUCCESS:
|
||||
lessons.append(
|
||||
f"Successfully completed in {task_state.current_step} steps"
|
||||
)
|
||||
|
||||
return lessons
|
||||
|
||||
def _calculate_session_importance(
|
||||
self,
|
||||
task_state: TaskState | None,
|
||||
outcome: Outcome,
|
||||
actions: list[dict[str, Any]],
|
||||
) -> float:
|
||||
"""Calculate importance score for session."""
|
||||
score = 0.5 # Base score
|
||||
|
||||
# Failures are important to learn from
|
||||
if outcome == Outcome.FAILURE:
|
||||
score += 0.3
|
||||
|
||||
# Many steps means complex task
|
||||
if task_state and task_state.total_steps >= 5:
|
||||
score += 0.1
|
||||
|
||||
# Many actions means detailed reasoning
|
||||
if len(actions) >= 5:
|
||||
score += 0.1
|
||||
|
||||
return min(1.0, score)
|
||||
|
||||
def _calculate_duration(self, task_state: TaskState | None) -> float:
|
||||
"""Calculate session duration."""
|
||||
if task_state is None:
|
||||
return 0.0
|
||||
|
||||
if task_state.started_at and task_state.updated_at:
|
||||
delta = task_state.updated_at - task_state.started_at
|
||||
return delta.total_seconds()
|
||||
|
||||
return 0.0
|
||||
|
||||
# =========================================================================
|
||||
# Episodic -> Semantic Consolidation
|
||||
# =========================================================================
|
||||
|
||||
async def consolidate_episodes_to_facts(
|
||||
self,
|
||||
project_id: UUID,
|
||||
since: datetime | None = None,
|
||||
limit: int | None = None,
|
||||
) -> ConsolidationResult:
|
||||
"""
|
||||
Extract facts from episodic memories to semantic memory.
|
||||
|
||||
Args:
|
||||
project_id: Project to consolidate
|
||||
since: Only process episodes since this time
|
||||
limit: Maximum episodes to process
|
||||
|
||||
Returns:
|
||||
ConsolidationResult with extraction statistics
|
||||
"""
|
||||
start_time = datetime.now(UTC)
|
||||
result = ConsolidationResult(
|
||||
source_type="episodic",
|
||||
target_type="semantic",
|
||||
)
|
||||
|
||||
try:
|
||||
episodic = await self._get_episodic()
|
||||
semantic = await self._get_semantic()
|
||||
|
||||
# Get episodes to process
|
||||
since_time = since or datetime.now(UTC) - timedelta(days=1)
|
||||
episodes = await episodic.get_recent(
|
||||
project_id,
|
||||
limit=limit or self._config.batch_size,
|
||||
since=since_time,
|
||||
)
|
||||
|
||||
for episode in episodes:
|
||||
result.items_processed += 1
|
||||
|
||||
try:
|
||||
# Extract facts using the extractor
|
||||
extracted_facts = self._fact_extractor.extract_from_episode(episode)
|
||||
|
||||
for extracted_fact in extracted_facts:
|
||||
if (
|
||||
extracted_fact.confidence
|
||||
< self._config.min_confidence_for_fact
|
||||
):
|
||||
result.items_skipped += 1
|
||||
continue
|
||||
|
||||
# Create fact (store_fact handles deduplication/reinforcement)
|
||||
fact_create = extracted_fact.to_fact_create(
|
||||
project_id=project_id,
|
||||
source_episode_ids=[episode.id],
|
||||
)
|
||||
|
||||
# store_fact automatically reinforces if fact already exists
|
||||
fact = await semantic.store_fact(fact_create)
|
||||
|
||||
# Check if this was a new fact or reinforced existing
|
||||
if fact.reinforcement_count == 1:
|
||||
result.items_created += 1
|
||||
else:
|
||||
result.items_updated += 1
|
||||
|
||||
except Exception as e:
|
||||
result.errors.append(f"Episode {episode.id}: {e}")
|
||||
logger.warning(
|
||||
f"Failed to extract facts from episode {episode.id}: {e}"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
result.errors.append(f"Consolidation failed: {e}")
|
||||
logger.exception("Failed episodic -> semantic consolidation")
|
||||
|
||||
result.duration_seconds = (datetime.now(UTC) - start_time).total_seconds()
|
||||
|
||||
logger.info(
|
||||
f"Episodic -> Semantic consolidation: "
|
||||
f"{result.items_processed} processed, "
|
||||
f"{result.items_created} created, "
|
||||
f"{result.items_updated} reinforced"
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
# =========================================================================
|
||||
# Episodic -> Procedural Consolidation
|
||||
# =========================================================================
|
||||
|
||||
async def consolidate_episodes_to_procedures(
|
||||
self,
|
||||
project_id: UUID,
|
||||
agent_type_id: UUID | None = None,
|
||||
since: datetime | None = None,
|
||||
) -> ConsolidationResult:
|
||||
"""
|
||||
Learn procedures from patterns in episodic memories.
|
||||
|
||||
Identifies recurring successful patterns and creates/updates
|
||||
procedures to capture them.
|
||||
|
||||
Args:
|
||||
project_id: Project to consolidate
|
||||
agent_type_id: Optional filter by agent type
|
||||
since: Only process episodes since this time
|
||||
|
||||
Returns:
|
||||
ConsolidationResult with procedure statistics
|
||||
"""
|
||||
start_time = datetime.now(UTC)
|
||||
result = ConsolidationResult(
|
||||
source_type="episodic",
|
||||
target_type="procedural",
|
||||
)
|
||||
|
||||
try:
|
||||
episodic = await self._get_episodic()
|
||||
procedural = await self._get_procedural()
|
||||
|
||||
# Get successful episodes
|
||||
since_time = since or datetime.now(UTC) - timedelta(days=7)
|
||||
episodes = await episodic.get_by_outcome(
|
||||
project_id,
|
||||
outcome=Outcome.SUCCESS,
|
||||
limit=self._config.batch_size,
|
||||
agent_instance_id=None, # Get all agent instances
|
||||
)
|
||||
|
||||
# Group by task type
|
||||
task_groups: dict[str, list[Episode]] = {}
|
||||
for episode in episodes:
|
||||
if episode.occurred_at >= since_time:
|
||||
if episode.task_type not in task_groups:
|
||||
task_groups[episode.task_type] = []
|
||||
task_groups[episode.task_type].append(episode)
|
||||
|
||||
result.items_processed = len(episodes)
|
||||
|
||||
# Process each task type group
|
||||
for task_type, group in task_groups.items():
|
||||
if len(group) < self._config.min_episodes_for_procedure:
|
||||
result.items_skipped += len(group)
|
||||
continue
|
||||
|
||||
try:
|
||||
procedure_result = await self._learn_procedure_from_episodes(
|
||||
procedural,
|
||||
project_id,
|
||||
agent_type_id,
|
||||
task_type,
|
||||
group,
|
||||
)
|
||||
|
||||
if procedure_result == "created":
|
||||
result.items_created += 1
|
||||
elif procedure_result == "updated":
|
||||
result.items_updated += 1
|
||||
else:
|
||||
result.items_skipped += 1
|
||||
|
||||
except Exception as e:
|
||||
result.errors.append(f"Task type '{task_type}': {e}")
|
||||
logger.warning(f"Failed to learn procedure for '{task_type}': {e}")
|
||||
|
||||
except Exception as e:
|
||||
result.errors.append(f"Consolidation failed: {e}")
|
||||
logger.exception("Failed episodic -> procedural consolidation")
|
||||
|
||||
result.duration_seconds = (datetime.now(UTC) - start_time).total_seconds()
|
||||
|
||||
logger.info(
|
||||
f"Episodic -> Procedural consolidation: "
|
||||
f"{result.items_processed} processed, "
|
||||
f"{result.items_created} created, "
|
||||
f"{result.items_updated} updated"
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
async def _learn_procedure_from_episodes(
|
||||
self,
|
||||
procedural: ProceduralMemory,
|
||||
project_id: UUID,
|
||||
agent_type_id: UUID | None,
|
||||
task_type: str,
|
||||
episodes: list[Episode],
|
||||
) -> str:
|
||||
"""Learn or update a procedure from a set of episodes."""
|
||||
# Calculate success rate for this pattern
|
||||
success_count = sum(1 for e in episodes if e.outcome == Outcome.SUCCESS)
|
||||
total_count = len(episodes)
|
||||
success_rate = success_count / total_count if total_count > 0 else 0
|
||||
|
||||
if success_rate < self._config.min_success_rate_for_procedure:
|
||||
return "skipped"
|
||||
|
||||
# Extract common steps from episodes
|
||||
steps = self._extract_common_steps(episodes)
|
||||
|
||||
if len(steps) < self._config.min_steps_for_procedure:
|
||||
return "skipped"
|
||||
|
||||
# Check for existing procedure
|
||||
matching = await procedural.find_matching(
|
||||
context=task_type,
|
||||
project_id=project_id,
|
||||
agent_type_id=agent_type_id,
|
||||
limit=1,
|
||||
)
|
||||
|
||||
if matching:
|
||||
# Update existing procedure with new success
|
||||
await procedural.record_outcome(
|
||||
matching[0].id,
|
||||
success=True,
|
||||
)
|
||||
return "updated"
|
||||
else:
|
||||
# Create new procedure
|
||||
# Note: success_count starts at 1 in record_procedure
|
||||
procedure_data = ProcedureCreate(
|
||||
project_id=project_id,
|
||||
agent_type_id=agent_type_id,
|
||||
name=f"Procedure for {task_type}",
|
||||
trigger_pattern=task_type,
|
||||
steps=steps,
|
||||
)
|
||||
await procedural.record_procedure(procedure_data)
|
||||
return "created"
|
||||
|
||||
def _extract_common_steps(self, episodes: list[Episode]) -> list[dict[str, Any]]:
|
||||
"""Extract common action steps from multiple episodes."""
|
||||
# Simple heuristic: take the steps from the most successful episode
|
||||
# with the most detailed actions
|
||||
|
||||
best_episode = max(
|
||||
episodes,
|
||||
key=lambda e: (
|
||||
e.outcome == Outcome.SUCCESS,
|
||||
len(e.actions),
|
||||
e.importance_score,
|
||||
),
|
||||
)
|
||||
|
||||
steps: list[dict[str, Any]] = []
|
||||
for i, action in enumerate(best_episode.actions):
|
||||
step = {
|
||||
"order": i + 1,
|
||||
"action": action.get("type", "action"),
|
||||
"description": action.get("content", str(action))[:500],
|
||||
"parameters": action,
|
||||
}
|
||||
steps.append(step)
|
||||
|
||||
return steps
|
||||
|
||||
# =========================================================================
|
||||
# Memory Pruning
|
||||
# =========================================================================
|
||||
|
||||
async def prune_old_episodes(
|
||||
self,
|
||||
project_id: UUID,
|
||||
max_age_days: int | None = None,
|
||||
min_importance: float | None = None,
|
||||
) -> ConsolidationResult:
|
||||
"""
|
||||
Prune old, low-value episodes.
|
||||
|
||||
Args:
|
||||
project_id: Project to prune
|
||||
max_age_days: Maximum age in days (default from config)
|
||||
min_importance: Minimum importance to keep (default from config)
|
||||
|
||||
Returns:
|
||||
ConsolidationResult with pruning statistics
|
||||
"""
|
||||
start_time = datetime.now(UTC)
|
||||
result = ConsolidationResult(
|
||||
source_type="episodic",
|
||||
target_type="pruned",
|
||||
)
|
||||
|
||||
max_age = max_age_days or self._config.max_episode_age_days
|
||||
min_imp = min_importance or self._config.min_importance_to_keep
|
||||
cutoff_date = datetime.now(UTC) - timedelta(days=max_age)
|
||||
|
||||
try:
|
||||
episodic = await self._get_episodic()
|
||||
|
||||
# Get old episodes
|
||||
# Note: In production, this would use a more efficient query
|
||||
all_episodes = await episodic.get_recent(
|
||||
project_id,
|
||||
limit=self._config.batch_size * 10,
|
||||
since=cutoff_date - timedelta(days=365), # Search past year
|
||||
)
|
||||
|
||||
for episode in all_episodes:
|
||||
result.items_processed += 1
|
||||
|
||||
# Check if should be pruned
|
||||
if not self._should_prune_episode(episode, cutoff_date, min_imp):
|
||||
result.items_skipped += 1
|
||||
continue
|
||||
|
||||
try:
|
||||
deleted = await episodic.delete(episode.id)
|
||||
if deleted:
|
||||
result.items_pruned += 1
|
||||
else:
|
||||
result.items_skipped += 1
|
||||
except Exception as e:
|
||||
result.errors.append(f"Episode {episode.id}: {e}")
|
||||
|
||||
except Exception as e:
|
||||
result.errors.append(f"Pruning failed: {e}")
|
||||
logger.exception("Failed episode pruning")
|
||||
|
||||
result.duration_seconds = (datetime.now(UTC) - start_time).total_seconds()
|
||||
|
||||
logger.info(
|
||||
f"Episode pruning: {result.items_processed} processed, "
|
||||
f"{result.items_pruned} pruned"
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
def _should_prune_episode(
|
||||
self,
|
||||
episode: Episode,
|
||||
cutoff_date: datetime,
|
||||
min_importance: float,
|
||||
) -> bool:
|
||||
"""Determine if an episode should be pruned."""
|
||||
# Keep recent episodes
|
||||
if episode.occurred_at >= cutoff_date:
|
||||
return False
|
||||
|
||||
# Keep failures if configured
|
||||
if self._config.keep_all_failures and episode.outcome == Outcome.FAILURE:
|
||||
return False
|
||||
|
||||
# Keep episodes with lessons if configured
|
||||
if self._config.keep_all_with_lessons and episode.lessons_learned:
|
||||
return False
|
||||
|
||||
# Keep high-importance episodes
|
||||
if episode.importance_score >= min_importance:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
# =========================================================================
|
||||
# Nightly Consolidation
|
||||
# =========================================================================
|
||||
|
||||
async def run_nightly_consolidation(
|
||||
self,
|
||||
project_id: UUID,
|
||||
agent_type_id: UUID | None = None,
|
||||
) -> NightlyConsolidationResult:
|
||||
"""
|
||||
Run full nightly consolidation workflow.
|
||||
|
||||
This includes:
|
||||
1. Extract facts from recent episodes
|
||||
2. Learn procedures from successful patterns
|
||||
3. Prune old, low-value memories
|
||||
|
||||
Args:
|
||||
project_id: Project to consolidate
|
||||
agent_type_id: Optional agent type filter
|
||||
|
||||
Returns:
|
||||
NightlyConsolidationResult with all outcomes
|
||||
"""
|
||||
result = NightlyConsolidationResult(started_at=datetime.now(UTC))
|
||||
|
||||
logger.info(f"Starting nightly consolidation for project {project_id}")
|
||||
|
||||
try:
|
||||
# Step 1: Episodic -> Semantic (last 24 hours)
|
||||
since_yesterday = datetime.now(UTC) - timedelta(days=1)
|
||||
result.episodic_to_semantic = await self.consolidate_episodes_to_facts(
|
||||
project_id=project_id,
|
||||
since=since_yesterday,
|
||||
)
|
||||
result.total_facts_created = result.episodic_to_semantic.items_created
|
||||
|
||||
# Step 2: Episodic -> Procedural (last 7 days)
|
||||
since_week = datetime.now(UTC) - timedelta(days=7)
|
||||
result.episodic_to_procedural = (
|
||||
await self.consolidate_episodes_to_procedures(
|
||||
project_id=project_id,
|
||||
agent_type_id=agent_type_id,
|
||||
since=since_week,
|
||||
)
|
||||
)
|
||||
result.total_procedures_created = (
|
||||
result.episodic_to_procedural.items_created
|
||||
)
|
||||
|
||||
# Step 3: Prune old memories
|
||||
result.pruning = await self.prune_old_episodes(project_id=project_id)
|
||||
result.total_pruned = result.pruning.items_pruned
|
||||
|
||||
# Calculate totals
|
||||
result.total_episodes_processed = (
|
||||
result.episodic_to_semantic.items_processed
|
||||
if result.episodic_to_semantic
|
||||
else 0
|
||||
) + (
|
||||
result.episodic_to_procedural.items_processed
|
||||
if result.episodic_to_procedural
|
||||
else 0
|
||||
)
|
||||
|
||||
# Collect all errors
|
||||
if result.episodic_to_semantic and result.episodic_to_semantic.errors:
|
||||
result.errors.extend(result.episodic_to_semantic.errors)
|
||||
if result.episodic_to_procedural and result.episodic_to_procedural.errors:
|
||||
result.errors.extend(result.episodic_to_procedural.errors)
|
||||
if result.pruning and result.pruning.errors:
|
||||
result.errors.extend(result.pruning.errors)
|
||||
|
||||
except Exception as e:
|
||||
result.errors.append(f"Nightly consolidation failed: {e}")
|
||||
logger.exception("Nightly consolidation failed")
|
||||
|
||||
result.completed_at = datetime.now(UTC)
|
||||
|
||||
duration = (result.completed_at - result.started_at).total_seconds()
|
||||
logger.info(
|
||||
f"Nightly consolidation completed in {duration:.1f}s: "
|
||||
f"{result.total_facts_created} facts, "
|
||||
f"{result.total_procedures_created} procedures, "
|
||||
f"{result.total_pruned} pruned"
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
# Singleton instance
|
||||
_consolidation_service: MemoryConsolidationService | None = None
|
||||
|
||||
|
||||
async def get_consolidation_service(
|
||||
session: AsyncSession,
|
||||
config: ConsolidationConfig | None = None,
|
||||
) -> MemoryConsolidationService:
|
||||
"""
|
||||
Get or create the memory consolidation service.
|
||||
|
||||
Args:
|
||||
session: Database session
|
||||
config: Optional configuration
|
||||
|
||||
Returns:
|
||||
MemoryConsolidationService instance
|
||||
"""
|
||||
global _consolidation_service
|
||||
if _consolidation_service is None:
|
||||
_consolidation_service = MemoryConsolidationService(
|
||||
session=session, config=config
|
||||
)
|
||||
return _consolidation_service
|
||||
@@ -10,14 +10,16 @@ Modules:
|
||||
sync: Issue synchronization tasks (incremental/full sync, webhooks)
|
||||
workflow: Workflow state management tasks
|
||||
cost: Cost tracking and budget monitoring tasks
|
||||
memory_consolidation: Memory consolidation tasks
|
||||
"""
|
||||
|
||||
from app.tasks import agent, cost, git, sync, workflow
|
||||
from app.tasks import agent, cost, git, memory_consolidation, sync, workflow
|
||||
|
||||
__all__ = [
|
||||
"agent",
|
||||
"cost",
|
||||
"git",
|
||||
"memory_consolidation",
|
||||
"sync",
|
||||
"workflow",
|
||||
]
|
||||
|
||||
234
backend/app/tasks/memory_consolidation.py
Normal file
234
backend/app/tasks/memory_consolidation.py
Normal file
@@ -0,0 +1,234 @@
|
||||
# app/tasks/memory_consolidation.py
|
||||
"""
|
||||
Memory consolidation Celery tasks.
|
||||
|
||||
Handles scheduled and on-demand memory consolidation:
|
||||
- Session consolidation (on session end)
|
||||
- Nightly consolidation (scheduled)
|
||||
- On-demand project consolidation
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from app.celery_app import celery_app
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@celery_app.task(
|
||||
bind=True,
|
||||
name="app.tasks.memory_consolidation.consolidate_session",
|
||||
autoretry_for=(Exception,),
|
||||
retry_backoff=True,
|
||||
retry_kwargs={"max_retries": 3},
|
||||
)
|
||||
def consolidate_session(
|
||||
self,
|
||||
project_id: str,
|
||||
session_id: str,
|
||||
task_type: str = "session_task",
|
||||
agent_instance_id: str | None = None,
|
||||
agent_type_id: str | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Consolidate a session's working memory to episodic memory.
|
||||
|
||||
This task is triggered when an agent session ends to transfer
|
||||
relevant session data into persistent episodic memory.
|
||||
|
||||
Args:
|
||||
project_id: UUID of the project
|
||||
session_id: Session identifier
|
||||
task_type: Type of task performed
|
||||
agent_instance_id: Optional agent instance UUID
|
||||
agent_type_id: Optional agent type UUID
|
||||
|
||||
Returns:
|
||||
dict with consolidation results
|
||||
"""
|
||||
logger.info(f"Consolidating session {session_id} for project {project_id}")
|
||||
|
||||
# TODO: Implement actual consolidation
|
||||
# This will involve:
|
||||
# 1. Getting database session from async context
|
||||
# 2. Loading working memory for session
|
||||
# 3. Calling consolidation service
|
||||
# 4. Returning results
|
||||
|
||||
# Placeholder implementation
|
||||
return {
|
||||
"status": "pending",
|
||||
"project_id": project_id,
|
||||
"session_id": session_id,
|
||||
"episode_created": False,
|
||||
}
|
||||
|
||||
|
||||
@celery_app.task(
|
||||
bind=True,
|
||||
name="app.tasks.memory_consolidation.run_nightly_consolidation",
|
||||
autoretry_for=(Exception,),
|
||||
retry_backoff=True,
|
||||
retry_kwargs={"max_retries": 3},
|
||||
)
|
||||
def run_nightly_consolidation(
|
||||
self,
|
||||
project_id: str,
|
||||
agent_type_id: str | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Run nightly memory consolidation for a project.
|
||||
|
||||
This task performs the full consolidation workflow:
|
||||
1. Extract facts from recent episodes to semantic memory
|
||||
2. Learn procedures from successful episode patterns
|
||||
3. Prune old, low-value memories
|
||||
|
||||
Args:
|
||||
project_id: UUID of the project to consolidate
|
||||
agent_type_id: Optional agent type to filter by
|
||||
|
||||
Returns:
|
||||
dict with consolidation results
|
||||
"""
|
||||
logger.info(f"Running nightly consolidation for project {project_id}")
|
||||
|
||||
# TODO: Implement actual consolidation
|
||||
# This will involve:
|
||||
# 1. Getting database session from async context
|
||||
# 2. Creating consolidation service instance
|
||||
# 3. Running run_nightly_consolidation
|
||||
# 4. Returning results
|
||||
|
||||
# Placeholder implementation
|
||||
return {
|
||||
"status": "pending",
|
||||
"project_id": project_id,
|
||||
"total_facts_created": 0,
|
||||
"total_procedures_created": 0,
|
||||
"total_pruned": 0,
|
||||
}
|
||||
|
||||
|
||||
@celery_app.task(
|
||||
bind=True,
|
||||
name="app.tasks.memory_consolidation.consolidate_episodes_to_facts",
|
||||
autoretry_for=(Exception,),
|
||||
retry_backoff=True,
|
||||
retry_kwargs={"max_retries": 3},
|
||||
)
|
||||
def consolidate_episodes_to_facts(
|
||||
self,
|
||||
project_id: str,
|
||||
since_hours: int = 24,
|
||||
limit: int | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Extract facts from episodic memories.
|
||||
|
||||
Args:
|
||||
project_id: UUID of the project
|
||||
since_hours: Process episodes from last N hours
|
||||
limit: Maximum episodes to process
|
||||
|
||||
Returns:
|
||||
dict with extraction results
|
||||
"""
|
||||
logger.info(f"Consolidating episodes to facts for project {project_id}")
|
||||
|
||||
# TODO: Implement actual consolidation
|
||||
# Placeholder implementation
|
||||
return {
|
||||
"status": "pending",
|
||||
"project_id": project_id,
|
||||
"items_processed": 0,
|
||||
"items_created": 0,
|
||||
}
|
||||
|
||||
|
||||
@celery_app.task(
|
||||
bind=True,
|
||||
name="app.tasks.memory_consolidation.consolidate_episodes_to_procedures",
|
||||
autoretry_for=(Exception,),
|
||||
retry_backoff=True,
|
||||
retry_kwargs={"max_retries": 3},
|
||||
)
|
||||
def consolidate_episodes_to_procedures(
|
||||
self,
|
||||
project_id: str,
|
||||
agent_type_id: str | None = None,
|
||||
since_days: int = 7,
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Learn procedures from episodic patterns.
|
||||
|
||||
Args:
|
||||
project_id: UUID of the project
|
||||
agent_type_id: Optional agent type filter
|
||||
since_days: Process episodes from last N days
|
||||
|
||||
Returns:
|
||||
dict with procedure learning results
|
||||
"""
|
||||
logger.info(f"Consolidating episodes to procedures for project {project_id}")
|
||||
|
||||
# TODO: Implement actual consolidation
|
||||
# Placeholder implementation
|
||||
return {
|
||||
"status": "pending",
|
||||
"project_id": project_id,
|
||||
"items_processed": 0,
|
||||
"items_created": 0,
|
||||
}
|
||||
|
||||
|
||||
@celery_app.task(
|
||||
bind=True,
|
||||
name="app.tasks.memory_consolidation.prune_old_memories",
|
||||
autoretry_for=(Exception,),
|
||||
retry_backoff=True,
|
||||
retry_kwargs={"max_retries": 3},
|
||||
)
|
||||
def prune_old_memories(
|
||||
self,
|
||||
project_id: str,
|
||||
max_age_days: int = 90,
|
||||
min_importance: float = 0.2,
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Prune old, low-value memories.
|
||||
|
||||
Args:
|
||||
project_id: UUID of the project
|
||||
max_age_days: Maximum age in days
|
||||
min_importance: Minimum importance to keep
|
||||
|
||||
Returns:
|
||||
dict with pruning results
|
||||
"""
|
||||
logger.info(f"Pruning old memories for project {project_id}")
|
||||
|
||||
# TODO: Implement actual pruning
|
||||
# Placeholder implementation
|
||||
return {
|
||||
"status": "pending",
|
||||
"project_id": project_id,
|
||||
"items_pruned": 0,
|
||||
}
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# Celery Beat Schedule Configuration
|
||||
# =========================================================================
|
||||
|
||||
# This would typically be configured in celery_app.py or a separate config file
|
||||
# Example schedule for nightly consolidation:
|
||||
#
|
||||
# app.conf.beat_schedule = {
|
||||
# 'nightly-memory-consolidation': {
|
||||
# 'task': 'app.tasks.memory_consolidation.run_nightly_consolidation',
|
||||
# 'schedule': crontab(hour=2, minute=0), # 2 AM daily
|
||||
# 'args': (None,), # Will process all projects
|
||||
# },
|
||||
# }
|
||||
@@ -0,0 +1,2 @@
|
||||
# tests/unit/services/memory/consolidation/__init__.py
|
||||
"""Tests for memory consolidation."""
|
||||
736
backend/tests/unit/services/memory/consolidation/test_service.py
Normal file
736
backend/tests/unit/services/memory/consolidation/test_service.py
Normal file
@@ -0,0 +1,736 @@
|
||||
# tests/unit/services/memory/consolidation/test_service.py
|
||||
"""Unit tests for memory consolidation service."""
|
||||
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
from app.services.memory.consolidation.service import (
|
||||
ConsolidationConfig,
|
||||
ConsolidationResult,
|
||||
MemoryConsolidationService,
|
||||
NightlyConsolidationResult,
|
||||
SessionConsolidationResult,
|
||||
)
|
||||
from app.services.memory.types import Episode, Outcome, TaskState
|
||||
|
||||
|
||||
def _utcnow() -> datetime:
|
||||
"""Get current UTC time."""
|
||||
return datetime.now(UTC)
|
||||
|
||||
|
||||
def make_episode(
|
||||
outcome: Outcome = Outcome.SUCCESS,
|
||||
occurred_at: datetime | None = None,
|
||||
task_type: str = "test_task",
|
||||
lessons_learned: list[str] | None = None,
|
||||
importance_score: float = 0.5,
|
||||
actions: list[dict] | None = None,
|
||||
) -> Episode:
|
||||
"""Create a test episode."""
|
||||
return Episode(
|
||||
id=uuid4(),
|
||||
project_id=uuid4(),
|
||||
agent_instance_id=uuid4(),
|
||||
agent_type_id=uuid4(),
|
||||
session_id="test-session",
|
||||
task_type=task_type,
|
||||
task_description="Test task description",
|
||||
actions=actions or [{"action": "test"}],
|
||||
context_summary="Test context",
|
||||
outcome=outcome,
|
||||
outcome_details="Test outcome",
|
||||
duration_seconds=10.0,
|
||||
tokens_used=100,
|
||||
lessons_learned=lessons_learned or [],
|
||||
importance_score=importance_score,
|
||||
embedding=None,
|
||||
occurred_at=occurred_at or _utcnow(),
|
||||
created_at=_utcnow(),
|
||||
updated_at=_utcnow(),
|
||||
)
|
||||
|
||||
|
||||
def make_task_state(
|
||||
current_step: int = 5,
|
||||
total_steps: int = 10,
|
||||
progress_percent: float = 50.0,
|
||||
status: str = "in_progress",
|
||||
description: str = "Test Task",
|
||||
) -> TaskState:
|
||||
"""Create a test task state."""
|
||||
now = _utcnow()
|
||||
return TaskState(
|
||||
task_id="test-task-id",
|
||||
task_type="test_task",
|
||||
description=description,
|
||||
current_step=current_step,
|
||||
total_steps=total_steps,
|
||||
status=status,
|
||||
progress_percent=progress_percent,
|
||||
started_at=now - timedelta(hours=1),
|
||||
updated_at=now,
|
||||
)
|
||||
|
||||
|
||||
class TestConsolidationConfig:
|
||||
"""Tests for ConsolidationConfig."""
|
||||
|
||||
def test_default_values(self) -> None:
|
||||
"""Test default configuration values."""
|
||||
config = ConsolidationConfig()
|
||||
|
||||
assert config.min_steps_for_episode == 2
|
||||
assert config.min_duration_seconds == 5.0
|
||||
assert config.min_confidence_for_fact == 0.6
|
||||
assert config.max_facts_per_episode == 10
|
||||
assert config.min_episodes_for_procedure == 3
|
||||
assert config.max_episode_age_days == 90
|
||||
assert config.batch_size == 100
|
||||
|
||||
def test_custom_values(self) -> None:
|
||||
"""Test custom configuration values."""
|
||||
config = ConsolidationConfig(
|
||||
min_steps_for_episode=5,
|
||||
batch_size=50,
|
||||
)
|
||||
|
||||
assert config.min_steps_for_episode == 5
|
||||
assert config.batch_size == 50
|
||||
|
||||
|
||||
class TestConsolidationResult:
|
||||
"""Tests for ConsolidationResult."""
|
||||
|
||||
def test_creation(self) -> None:
|
||||
"""Test creating a consolidation result."""
|
||||
result = ConsolidationResult(
|
||||
source_type="episodic",
|
||||
target_type="semantic",
|
||||
items_processed=10,
|
||||
items_created=5,
|
||||
)
|
||||
|
||||
assert result.source_type == "episodic"
|
||||
assert result.target_type == "semantic"
|
||||
assert result.items_processed == 10
|
||||
assert result.items_created == 5
|
||||
assert result.items_skipped == 0
|
||||
assert result.errors == []
|
||||
|
||||
def test_to_dict(self) -> None:
|
||||
"""Test converting to dictionary."""
|
||||
result = ConsolidationResult(
|
||||
source_type="episodic",
|
||||
target_type="semantic",
|
||||
items_processed=10,
|
||||
items_created=5,
|
||||
errors=["test error"],
|
||||
)
|
||||
|
||||
d = result.to_dict()
|
||||
|
||||
assert d["source_type"] == "episodic"
|
||||
assert d["target_type"] == "semantic"
|
||||
assert d["items_processed"] == 10
|
||||
assert d["items_created"] == 5
|
||||
assert "test error" in d["errors"]
|
||||
|
||||
|
||||
class TestSessionConsolidationResult:
|
||||
"""Tests for SessionConsolidationResult."""
|
||||
|
||||
def test_creation(self) -> None:
|
||||
"""Test creating a session consolidation result."""
|
||||
result = SessionConsolidationResult(
|
||||
session_id="test-session",
|
||||
episode_created=True,
|
||||
episode_id=uuid4(),
|
||||
scratchpad_entries=5,
|
||||
)
|
||||
|
||||
assert result.session_id == "test-session"
|
||||
assert result.episode_created is True
|
||||
assert result.episode_id is not None
|
||||
|
||||
|
||||
class TestNightlyConsolidationResult:
|
||||
"""Tests for NightlyConsolidationResult."""
|
||||
|
||||
def test_creation(self) -> None:
|
||||
"""Test creating a nightly consolidation result."""
|
||||
result = NightlyConsolidationResult(
|
||||
started_at=_utcnow(),
|
||||
)
|
||||
|
||||
assert result.started_at is not None
|
||||
assert result.completed_at is None
|
||||
assert result.total_episodes_processed == 0
|
||||
|
||||
def test_to_dict(self) -> None:
|
||||
"""Test converting to dictionary."""
|
||||
result = NightlyConsolidationResult(
|
||||
started_at=_utcnow(),
|
||||
completed_at=_utcnow(),
|
||||
total_facts_created=5,
|
||||
total_procedures_created=2,
|
||||
)
|
||||
|
||||
d = result.to_dict()
|
||||
|
||||
assert "started_at" in d
|
||||
assert "completed_at" in d
|
||||
assert d["total_facts_created"] == 5
|
||||
assert d["total_procedures_created"] == 2
|
||||
|
||||
|
||||
class TestMemoryConsolidationService:
|
||||
"""Tests for MemoryConsolidationService."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_session(self) -> AsyncMock:
|
||||
"""Create a mock database session."""
|
||||
return AsyncMock()
|
||||
|
||||
@pytest.fixture
|
||||
def service(self, mock_session: AsyncMock) -> MemoryConsolidationService:
|
||||
"""Create a consolidation service with mocked dependencies."""
|
||||
return MemoryConsolidationService(
|
||||
session=mock_session,
|
||||
config=ConsolidationConfig(),
|
||||
)
|
||||
|
||||
# =========================================================================
|
||||
# Session Consolidation Tests
|
||||
# =========================================================================
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_consolidate_session_insufficient_steps(
|
||||
self, service: MemoryConsolidationService
|
||||
) -> None:
|
||||
"""Test session not consolidated when insufficient steps."""
|
||||
mock_working_memory = AsyncMock()
|
||||
task_state = make_task_state(current_step=1) # Less than min_steps_for_episode
|
||||
mock_working_memory.get_task_state.return_value = task_state
|
||||
|
||||
result = await service.consolidate_session(
|
||||
working_memory=mock_working_memory,
|
||||
project_id=uuid4(),
|
||||
session_id="test-session",
|
||||
)
|
||||
|
||||
assert result.episode_created is False
|
||||
assert result.episode_id is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_consolidate_session_no_task_state(
|
||||
self, service: MemoryConsolidationService
|
||||
) -> None:
|
||||
"""Test session not consolidated when no task state."""
|
||||
mock_working_memory = AsyncMock()
|
||||
mock_working_memory.get_task_state.return_value = None
|
||||
|
||||
result = await service.consolidate_session(
|
||||
working_memory=mock_working_memory,
|
||||
project_id=uuid4(),
|
||||
session_id="test-session",
|
||||
)
|
||||
|
||||
assert result.episode_created is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_consolidate_session_success(
|
||||
self, service: MemoryConsolidationService, mock_session: AsyncMock
|
||||
) -> None:
|
||||
"""Test successful session consolidation."""
|
||||
mock_working_memory = AsyncMock()
|
||||
task_state = make_task_state(
|
||||
current_step=5,
|
||||
progress_percent=100.0,
|
||||
status="complete",
|
||||
)
|
||||
mock_working_memory.get_task_state.return_value = task_state
|
||||
mock_working_memory.get_scratchpad.return_value = ["step1", "step2"]
|
||||
mock_working_memory.get_all.return_value = {"key1": "value1"}
|
||||
|
||||
# Mock episodic memory
|
||||
mock_episode = make_episode()
|
||||
with patch.object(
|
||||
service, "_get_episodic", new_callable=AsyncMock
|
||||
) as mock_get_episodic:
|
||||
mock_episodic = AsyncMock()
|
||||
mock_episodic.record_episode.return_value = mock_episode
|
||||
mock_get_episodic.return_value = mock_episodic
|
||||
|
||||
result = await service.consolidate_session(
|
||||
working_memory=mock_working_memory,
|
||||
project_id=uuid4(),
|
||||
session_id="test-session",
|
||||
)
|
||||
|
||||
assert result.episode_created is True
|
||||
assert result.episode_id == mock_episode.id
|
||||
assert result.scratchpad_entries == 2
|
||||
|
||||
# =========================================================================
|
||||
# Outcome Determination Tests
|
||||
# =========================================================================
|
||||
|
||||
def test_determine_session_outcome_success(
|
||||
self, service: MemoryConsolidationService
|
||||
) -> None:
|
||||
"""Test outcome determination for successful session."""
|
||||
task_state = make_task_state(status="complete", progress_percent=100.0)
|
||||
outcome = service._determine_session_outcome(task_state)
|
||||
assert outcome == Outcome.SUCCESS
|
||||
|
||||
def test_determine_session_outcome_failure(
|
||||
self, service: MemoryConsolidationService
|
||||
) -> None:
|
||||
"""Test outcome determination for failed session."""
|
||||
task_state = make_task_state(status="error", progress_percent=25.0)
|
||||
outcome = service._determine_session_outcome(task_state)
|
||||
assert outcome == Outcome.FAILURE
|
||||
|
||||
def test_determine_session_outcome_partial(
|
||||
self, service: MemoryConsolidationService
|
||||
) -> None:
|
||||
"""Test outcome determination for partial session."""
|
||||
task_state = make_task_state(status="stopped", progress_percent=60.0)
|
||||
outcome = service._determine_session_outcome(task_state)
|
||||
assert outcome == Outcome.PARTIAL
|
||||
|
||||
def test_determine_session_outcome_none(
|
||||
self, service: MemoryConsolidationService
|
||||
) -> None:
|
||||
"""Test outcome determination with no task state."""
|
||||
outcome = service._determine_session_outcome(None)
|
||||
assert outcome == Outcome.PARTIAL
|
||||
|
||||
# =========================================================================
|
||||
# Action Building Tests
|
||||
# =========================================================================
|
||||
|
||||
def test_build_actions_from_session(
|
||||
self, service: MemoryConsolidationService
|
||||
) -> None:
|
||||
"""Test building actions from session data."""
|
||||
scratchpad = ["thought 1", "thought 2"]
|
||||
variables = {"var1": "value1"}
|
||||
task_state = make_task_state()
|
||||
|
||||
actions = service._build_actions_from_session(scratchpad, variables, task_state)
|
||||
|
||||
assert len(actions) == 3 # 2 scratchpad + 1 final state
|
||||
assert actions[0]["type"] == "reasoning"
|
||||
assert actions[2]["type"] == "final_state"
|
||||
|
||||
def test_build_context_summary(self, service: MemoryConsolidationService) -> None:
|
||||
"""Test building context summary."""
|
||||
task_state = make_task_state(
|
||||
description="Test Task",
|
||||
progress_percent=75.0,
|
||||
)
|
||||
variables = {"key": "value"}
|
||||
|
||||
summary = service._build_context_summary(task_state, variables)
|
||||
|
||||
assert "Test Task" in summary
|
||||
assert "75.0%" in summary
|
||||
|
||||
# =========================================================================
|
||||
# Importance Calculation Tests
|
||||
# =========================================================================
|
||||
|
||||
def test_calculate_session_importance_base(
|
||||
self, service: MemoryConsolidationService
|
||||
) -> None:
|
||||
"""Test base importance calculation."""
|
||||
task_state = make_task_state(total_steps=3) # Below threshold
|
||||
importance = service._calculate_session_importance(
|
||||
task_state, Outcome.SUCCESS, []
|
||||
)
|
||||
|
||||
assert importance == 0.5 # Base score
|
||||
|
||||
def test_calculate_session_importance_failure(
|
||||
self, service: MemoryConsolidationService
|
||||
) -> None:
|
||||
"""Test importance boost for failures."""
|
||||
task_state = make_task_state(total_steps=3) # Below threshold
|
||||
importance = service._calculate_session_importance(
|
||||
task_state, Outcome.FAILURE, []
|
||||
)
|
||||
|
||||
assert importance == 0.8 # Base (0.5) + failure boost (0.3)
|
||||
|
||||
def test_calculate_session_importance_complex(
|
||||
self, service: MemoryConsolidationService
|
||||
) -> None:
|
||||
"""Test importance for complex session."""
|
||||
task_state = make_task_state(total_steps=10)
|
||||
actions = [{"step": i} for i in range(6)]
|
||||
importance = service._calculate_session_importance(
|
||||
task_state, Outcome.SUCCESS, actions
|
||||
)
|
||||
|
||||
# Base (0.5) + many steps (0.1) + many actions (0.1)
|
||||
assert importance == 0.7
|
||||
|
||||
# =========================================================================
|
||||
# Episode to Fact Consolidation Tests
|
||||
# =========================================================================
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_consolidate_episodes_to_facts_empty(
|
||||
self, service: MemoryConsolidationService
|
||||
) -> None:
|
||||
"""Test consolidation with no episodes."""
|
||||
with patch.object(
|
||||
service, "_get_episodic", new_callable=AsyncMock
|
||||
) as mock_get_episodic:
|
||||
mock_episodic = AsyncMock()
|
||||
mock_episodic.get_recent.return_value = []
|
||||
mock_get_episodic.return_value = mock_episodic
|
||||
|
||||
result = await service.consolidate_episodes_to_facts(
|
||||
project_id=uuid4(),
|
||||
)
|
||||
|
||||
assert result.items_processed == 0
|
||||
assert result.items_created == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_consolidate_episodes_to_facts_success(
|
||||
self, service: MemoryConsolidationService
|
||||
) -> None:
|
||||
"""Test successful fact extraction."""
|
||||
episode = make_episode(
|
||||
lessons_learned=["Always check return values"],
|
||||
)
|
||||
|
||||
mock_fact = MagicMock()
|
||||
mock_fact.reinforcement_count = 1 # New fact
|
||||
|
||||
with (
|
||||
patch.object(
|
||||
service, "_get_episodic", new_callable=AsyncMock
|
||||
) as mock_get_episodic,
|
||||
patch.object(
|
||||
service, "_get_semantic", new_callable=AsyncMock
|
||||
) as mock_get_semantic,
|
||||
):
|
||||
mock_episodic = AsyncMock()
|
||||
mock_episodic.get_recent.return_value = [episode]
|
||||
mock_get_episodic.return_value = mock_episodic
|
||||
|
||||
mock_semantic = AsyncMock()
|
||||
mock_semantic.store_fact.return_value = mock_fact
|
||||
mock_get_semantic.return_value = mock_semantic
|
||||
|
||||
result = await service.consolidate_episodes_to_facts(
|
||||
project_id=uuid4(),
|
||||
)
|
||||
|
||||
assert result.items_processed == 1
|
||||
# At least one fact should be created from lesson
|
||||
assert result.items_created >= 0
|
||||
|
||||
# =========================================================================
|
||||
# Episode to Procedure Consolidation Tests
|
||||
# =========================================================================
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_consolidate_episodes_to_procedures_insufficient(
|
||||
self, service: MemoryConsolidationService
|
||||
) -> None:
|
||||
"""Test consolidation with insufficient episodes."""
|
||||
# Only 1 episode - less than min_episodes_for_procedure (3)
|
||||
episode = make_episode()
|
||||
|
||||
with patch.object(
|
||||
service, "_get_episodic", new_callable=AsyncMock
|
||||
) as mock_get_episodic:
|
||||
mock_episodic = AsyncMock()
|
||||
mock_episodic.get_by_outcome.return_value = [episode]
|
||||
mock_get_episodic.return_value = mock_episodic
|
||||
|
||||
result = await service.consolidate_episodes_to_procedures(
|
||||
project_id=uuid4(),
|
||||
)
|
||||
|
||||
assert result.items_processed == 1
|
||||
assert result.items_created == 0
|
||||
assert result.items_skipped == 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_consolidate_episodes_to_procedures_success(
|
||||
self, service: MemoryConsolidationService
|
||||
) -> None:
|
||||
"""Test successful procedure creation."""
|
||||
# Create enough episodes for a procedure
|
||||
episodes = [
|
||||
make_episode(
|
||||
task_type="deploy",
|
||||
actions=[{"type": "step1"}, {"type": "step2"}, {"type": "step3"}],
|
||||
)
|
||||
for _ in range(5)
|
||||
]
|
||||
|
||||
mock_procedure = MagicMock()
|
||||
|
||||
with (
|
||||
patch.object(
|
||||
service, "_get_episodic", new_callable=AsyncMock
|
||||
) as mock_get_episodic,
|
||||
patch.object(
|
||||
service, "_get_procedural", new_callable=AsyncMock
|
||||
) as mock_get_procedural,
|
||||
):
|
||||
mock_episodic = AsyncMock()
|
||||
mock_episodic.get_by_outcome.return_value = episodes
|
||||
mock_get_episodic.return_value = mock_episodic
|
||||
|
||||
mock_procedural = AsyncMock()
|
||||
mock_procedural.find_matching.return_value = [] # No existing procedure
|
||||
mock_procedural.record_procedure.return_value = mock_procedure
|
||||
mock_get_procedural.return_value = mock_procedural
|
||||
|
||||
result = await service.consolidate_episodes_to_procedures(
|
||||
project_id=uuid4(),
|
||||
)
|
||||
|
||||
assert result.items_processed == 5
|
||||
assert result.items_created == 1
|
||||
|
||||
# =========================================================================
|
||||
# Common Steps Extraction Tests
|
||||
# =========================================================================
|
||||
|
||||
def test_extract_common_steps(self, service: MemoryConsolidationService) -> None:
|
||||
"""Test extracting steps from episodes."""
|
||||
episodes = [
|
||||
make_episode(
|
||||
outcome=Outcome.SUCCESS,
|
||||
importance_score=0.8,
|
||||
actions=[
|
||||
{"type": "step1", "content": "First step"},
|
||||
{"type": "step2", "content": "Second step"},
|
||||
],
|
||||
),
|
||||
make_episode(
|
||||
outcome=Outcome.SUCCESS,
|
||||
importance_score=0.5,
|
||||
actions=[{"type": "simple"}],
|
||||
),
|
||||
]
|
||||
|
||||
steps = service._extract_common_steps(episodes)
|
||||
|
||||
assert len(steps) == 2
|
||||
assert steps[0]["order"] == 1
|
||||
assert steps[0]["action"] == "step1"
|
||||
|
||||
# =========================================================================
|
||||
# Pruning Tests
|
||||
# =========================================================================
|
||||
|
||||
def test_should_prune_episode_old_low_importance(
|
||||
self, service: MemoryConsolidationService
|
||||
) -> None:
|
||||
"""Test pruning old, low-importance episode."""
|
||||
old_date = _utcnow() - timedelta(days=100)
|
||||
episode = make_episode(
|
||||
occurred_at=old_date,
|
||||
importance_score=0.1,
|
||||
outcome=Outcome.SUCCESS,
|
||||
)
|
||||
cutoff = _utcnow() - timedelta(days=90)
|
||||
|
||||
should_prune = service._should_prune_episode(episode, cutoff, 0.2)
|
||||
|
||||
assert should_prune is True
|
||||
|
||||
def test_should_prune_episode_recent(
|
||||
self, service: MemoryConsolidationService
|
||||
) -> None:
|
||||
"""Test not pruning recent episode."""
|
||||
recent_date = _utcnow() - timedelta(days=30)
|
||||
episode = make_episode(
|
||||
occurred_at=recent_date,
|
||||
importance_score=0.1,
|
||||
)
|
||||
cutoff = _utcnow() - timedelta(days=90)
|
||||
|
||||
should_prune = service._should_prune_episode(episode, cutoff, 0.2)
|
||||
|
||||
assert should_prune is False
|
||||
|
||||
def test_should_prune_episode_failure_protected(
|
||||
self, service: MemoryConsolidationService
|
||||
) -> None:
|
||||
"""Test not pruning failure (with keep_all_failures=True)."""
|
||||
old_date = _utcnow() - timedelta(days=100)
|
||||
episode = make_episode(
|
||||
occurred_at=old_date,
|
||||
importance_score=0.1,
|
||||
outcome=Outcome.FAILURE,
|
||||
)
|
||||
cutoff = _utcnow() - timedelta(days=90)
|
||||
|
||||
should_prune = service._should_prune_episode(episode, cutoff, 0.2)
|
||||
|
||||
# Config has keep_all_failures=True by default
|
||||
assert should_prune is False
|
||||
|
||||
def test_should_prune_episode_with_lessons_protected(
|
||||
self, service: MemoryConsolidationService
|
||||
) -> None:
|
||||
"""Test not pruning episode with lessons."""
|
||||
old_date = _utcnow() - timedelta(days=100)
|
||||
episode = make_episode(
|
||||
occurred_at=old_date,
|
||||
importance_score=0.1,
|
||||
lessons_learned=["Important lesson"],
|
||||
)
|
||||
cutoff = _utcnow() - timedelta(days=90)
|
||||
|
||||
should_prune = service._should_prune_episode(episode, cutoff, 0.2)
|
||||
|
||||
# Config has keep_all_with_lessons=True by default
|
||||
assert should_prune is False
|
||||
|
||||
def test_should_prune_episode_high_importance_protected(
|
||||
self, service: MemoryConsolidationService
|
||||
) -> None:
|
||||
"""Test not pruning high importance episode."""
|
||||
old_date = _utcnow() - timedelta(days=100)
|
||||
episode = make_episode(
|
||||
occurred_at=old_date,
|
||||
importance_score=0.8,
|
||||
)
|
||||
cutoff = _utcnow() - timedelta(days=90)
|
||||
|
||||
should_prune = service._should_prune_episode(episode, cutoff, 0.2)
|
||||
|
||||
assert should_prune is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_prune_old_episodes(
|
||||
self, service: MemoryConsolidationService
|
||||
) -> None:
|
||||
"""Test episode pruning."""
|
||||
old_episode = make_episode(
|
||||
occurred_at=_utcnow() - timedelta(days=100),
|
||||
importance_score=0.1,
|
||||
outcome=Outcome.SUCCESS,
|
||||
lessons_learned=[],
|
||||
)
|
||||
|
||||
with patch.object(
|
||||
service, "_get_episodic", new_callable=AsyncMock
|
||||
) as mock_get_episodic:
|
||||
mock_episodic = AsyncMock()
|
||||
mock_episodic.get_recent.return_value = [old_episode]
|
||||
mock_episodic.delete.return_value = True
|
||||
mock_get_episodic.return_value = mock_episodic
|
||||
|
||||
result = await service.prune_old_episodes(project_id=uuid4())
|
||||
|
||||
assert result.items_processed == 1
|
||||
assert result.items_pruned == 1
|
||||
|
||||
# =========================================================================
|
||||
# Nightly Consolidation Tests
|
||||
# =========================================================================
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_nightly_consolidation(
|
||||
self, service: MemoryConsolidationService
|
||||
) -> None:
|
||||
"""Test nightly consolidation workflow."""
|
||||
with (
|
||||
patch.object(
|
||||
service,
|
||||
"consolidate_episodes_to_facts",
|
||||
new_callable=AsyncMock,
|
||||
) as mock_facts,
|
||||
patch.object(
|
||||
service,
|
||||
"consolidate_episodes_to_procedures",
|
||||
new_callable=AsyncMock,
|
||||
) as mock_procedures,
|
||||
patch.object(
|
||||
service,
|
||||
"prune_old_episodes",
|
||||
new_callable=AsyncMock,
|
||||
) as mock_prune,
|
||||
):
|
||||
mock_facts.return_value = ConsolidationResult(
|
||||
source_type="episodic",
|
||||
target_type="semantic",
|
||||
items_processed=10,
|
||||
items_created=5,
|
||||
)
|
||||
mock_procedures.return_value = ConsolidationResult(
|
||||
source_type="episodic",
|
||||
target_type="procedural",
|
||||
items_processed=10,
|
||||
items_created=2,
|
||||
)
|
||||
mock_prune.return_value = ConsolidationResult(
|
||||
source_type="episodic",
|
||||
target_type="pruned",
|
||||
items_pruned=3,
|
||||
)
|
||||
|
||||
result = await service.run_nightly_consolidation(project_id=uuid4())
|
||||
|
||||
assert result.completed_at is not None
|
||||
assert result.total_facts_created == 5
|
||||
assert result.total_procedures_created == 2
|
||||
assert result.total_pruned == 3
|
||||
assert result.total_episodes_processed == 20
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_nightly_consolidation_with_errors(
|
||||
self, service: MemoryConsolidationService
|
||||
) -> None:
|
||||
"""Test nightly consolidation handles errors."""
|
||||
with (
|
||||
patch.object(
|
||||
service,
|
||||
"consolidate_episodes_to_facts",
|
||||
new_callable=AsyncMock,
|
||||
) as mock_facts,
|
||||
patch.object(
|
||||
service,
|
||||
"consolidate_episodes_to_procedures",
|
||||
new_callable=AsyncMock,
|
||||
) as mock_procedures,
|
||||
patch.object(
|
||||
service,
|
||||
"prune_old_episodes",
|
||||
new_callable=AsyncMock,
|
||||
) as mock_prune,
|
||||
):
|
||||
mock_facts.return_value = ConsolidationResult(
|
||||
source_type="episodic",
|
||||
target_type="semantic",
|
||||
errors=["fact error"],
|
||||
)
|
||||
mock_procedures.return_value = ConsolidationResult(
|
||||
source_type="episodic",
|
||||
target_type="procedural",
|
||||
)
|
||||
mock_prune.return_value = ConsolidationResult(
|
||||
source_type="episodic",
|
||||
target_type="pruned",
|
||||
)
|
||||
|
||||
result = await service.run_nightly_consolidation(project_id=uuid4())
|
||||
|
||||
assert "fact error" in result.errors
|
||||
Reference in New Issue
Block a user