feat(memory): integrate memory system with context engine (#97)
## Changes ### New Context Type - Add MEMORY to ContextType enum for agent memory context - Create MemoryContext class with subtypes (working, episodic, semantic, procedural) - Factory methods: from_working_memory, from_episodic_memory, from_semantic_memory, from_procedural_memory ### Memory Context Source - MemoryContextSource service fetches relevant memories for context assembly - Configurable fetch limits per memory type - Parallel fetching from all memory types ### Agent Lifecycle Hooks - AgentLifecycleManager handles spawn, pause, resume, terminate events - spawn: Initialize working memory with optional initial state - pause: Create checkpoint of working memory - resume: Restore from checkpoint - terminate: Consolidate working memory to episodic memory - LifecycleHooks for custom extension points ### Context Engine Integration - Add memory_query parameter to assemble_context() - Add session_id and agent_type_id for memory scoping - Memory budget allocation (15% by default) - set_memory_source() for runtime configuration ### Tests - 48 new tests for MemoryContext, MemoryContextSource, and lifecycle hooks - All 108 memory-related tests passing - mypy and ruff checks passing 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
471
backend/tests/unit/services/memory/integration/test_lifecycle.py
Normal file
471
backend/tests/unit/services/memory/integration/test_lifecycle.py
Normal file
@@ -0,0 +1,471 @@
|
||||
# tests/unit/services/memory/integration/test_lifecycle.py
|
||||
"""Tests for Agent Lifecycle Hooks."""
|
||||
|
||||
from datetime import UTC, datetime
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
from app.services.memory.integration.lifecycle import (
|
||||
AgentLifecycleManager,
|
||||
LifecycleEvent,
|
||||
LifecycleHooks,
|
||||
LifecycleResult,
|
||||
get_lifecycle_manager,
|
||||
)
|
||||
from app.services.memory.types import Outcome
|
||||
|
||||
pytestmark = pytest.mark.asyncio(loop_scope="function")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_session() -> MagicMock:
|
||||
"""Create mock database session."""
|
||||
return MagicMock()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def lifecycle_hooks() -> LifecycleHooks:
|
||||
"""Create lifecycle hooks instance."""
|
||||
return LifecycleHooks()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def lifecycle_manager(mock_session: MagicMock) -> AgentLifecycleManager:
|
||||
"""Create lifecycle manager instance."""
|
||||
return AgentLifecycleManager(session=mock_session)
|
||||
|
||||
|
||||
class TestLifecycleEvent:
|
||||
"""Tests for LifecycleEvent dataclass."""
|
||||
|
||||
def test_creates_event(self) -> None:
|
||||
"""Should create event with required fields."""
|
||||
project_id = uuid4()
|
||||
agent_id = uuid4()
|
||||
|
||||
event = LifecycleEvent(
|
||||
event_type="spawn",
|
||||
project_id=project_id,
|
||||
agent_instance_id=agent_id,
|
||||
)
|
||||
|
||||
assert event.event_type == "spawn"
|
||||
assert event.project_id == project_id
|
||||
assert event.agent_instance_id == agent_id
|
||||
assert event.timestamp is not None
|
||||
assert event.metadata == {}
|
||||
|
||||
def test_with_optional_fields(self) -> None:
|
||||
"""Should include optional fields."""
|
||||
event = LifecycleEvent(
|
||||
event_type="terminate",
|
||||
project_id=uuid4(),
|
||||
agent_instance_id=uuid4(),
|
||||
session_id="sess-123",
|
||||
metadata={"reason": "completed"},
|
||||
)
|
||||
|
||||
assert event.session_id == "sess-123"
|
||||
assert event.metadata["reason"] == "completed"
|
||||
|
||||
|
||||
class TestLifecycleResult:
|
||||
"""Tests for LifecycleResult dataclass."""
|
||||
|
||||
def test_success_result(self) -> None:
|
||||
"""Should create success result."""
|
||||
result = LifecycleResult(
|
||||
success=True,
|
||||
event_type="spawn",
|
||||
message="Agent spawned",
|
||||
data={"session_id": "sess-123"},
|
||||
duration_ms=10.5,
|
||||
)
|
||||
|
||||
assert result.success is True
|
||||
assert result.event_type == "spawn"
|
||||
assert result.data["session_id"] == "sess-123"
|
||||
|
||||
def test_failure_result(self) -> None:
|
||||
"""Should create failure result."""
|
||||
result = LifecycleResult(
|
||||
success=False,
|
||||
event_type="resume",
|
||||
message="Checkpoint not found",
|
||||
)
|
||||
|
||||
assert result.success is False
|
||||
assert result.message == "Checkpoint not found"
|
||||
|
||||
|
||||
class TestLifecycleHooks:
|
||||
"""Tests for LifecycleHooks class."""
|
||||
|
||||
def test_register_spawn_hook(self, lifecycle_hooks: LifecycleHooks) -> None:
|
||||
"""Should register spawn hook."""
|
||||
async def my_hook(event: LifecycleEvent) -> None:
|
||||
pass
|
||||
|
||||
result = lifecycle_hooks.on_spawn(my_hook)
|
||||
|
||||
assert result is my_hook
|
||||
assert my_hook in lifecycle_hooks._spawn_hooks
|
||||
|
||||
def test_register_all_hooks(self, lifecycle_hooks: LifecycleHooks) -> None:
|
||||
"""Should register hooks for all event types."""
|
||||
hooks = [
|
||||
lifecycle_hooks.on_spawn(AsyncMock()),
|
||||
lifecycle_hooks.on_pause(AsyncMock()),
|
||||
lifecycle_hooks.on_resume(AsyncMock()),
|
||||
lifecycle_hooks.on_terminate(AsyncMock()),
|
||||
]
|
||||
|
||||
assert len(lifecycle_hooks._spawn_hooks) == 1
|
||||
assert len(lifecycle_hooks._pause_hooks) == 1
|
||||
assert len(lifecycle_hooks._resume_hooks) == 1
|
||||
assert len(lifecycle_hooks._terminate_hooks) == 1
|
||||
|
||||
async def test_run_spawn_hooks(self, lifecycle_hooks: LifecycleHooks) -> None:
|
||||
"""Should run all spawn hooks."""
|
||||
hook1 = AsyncMock()
|
||||
hook2 = AsyncMock()
|
||||
lifecycle_hooks.on_spawn(hook1)
|
||||
lifecycle_hooks.on_spawn(hook2)
|
||||
|
||||
event = LifecycleEvent(
|
||||
event_type="spawn",
|
||||
project_id=uuid4(),
|
||||
agent_instance_id=uuid4(),
|
||||
)
|
||||
|
||||
await lifecycle_hooks.run_spawn_hooks(event)
|
||||
|
||||
hook1.assert_called_once_with(event)
|
||||
hook2.assert_called_once_with(event)
|
||||
|
||||
async def test_hook_failure_doesnt_stop_others(
|
||||
self, lifecycle_hooks: LifecycleHooks
|
||||
) -> None:
|
||||
"""Hook failure should not stop other hooks from running."""
|
||||
hook1 = AsyncMock(side_effect=ValueError("Oops"))
|
||||
hook2 = AsyncMock()
|
||||
lifecycle_hooks.on_pause(hook1)
|
||||
lifecycle_hooks.on_pause(hook2)
|
||||
|
||||
event = LifecycleEvent(
|
||||
event_type="pause",
|
||||
project_id=uuid4(),
|
||||
agent_instance_id=uuid4(),
|
||||
)
|
||||
|
||||
await lifecycle_hooks.run_pause_hooks(event)
|
||||
|
||||
# hook2 should still be called even though hook1 failed
|
||||
hook2.assert_called_once()
|
||||
|
||||
|
||||
class TestAgentLifecycleManagerSpawn:
|
||||
"""Tests for AgentLifecycleManager.spawn."""
|
||||
|
||||
@patch("app.services.memory.integration.lifecycle.WorkingMemory")
|
||||
async def test_spawn_creates_working_memory(
|
||||
self,
|
||||
mock_working_cls: MagicMock,
|
||||
lifecycle_manager: AgentLifecycleManager,
|
||||
) -> None:
|
||||
"""Spawn should create working memory for session."""
|
||||
mock_working = AsyncMock()
|
||||
mock_working.set = AsyncMock()
|
||||
mock_working_cls.for_session = AsyncMock(return_value=mock_working)
|
||||
|
||||
result = await lifecycle_manager.spawn(
|
||||
project_id=uuid4(),
|
||||
agent_instance_id=uuid4(),
|
||||
session_id="sess-123",
|
||||
)
|
||||
|
||||
assert result.success is True
|
||||
assert result.event_type == "spawn"
|
||||
mock_working_cls.for_session.assert_called_once()
|
||||
|
||||
@patch("app.services.memory.integration.lifecycle.WorkingMemory")
|
||||
async def test_spawn_with_initial_state(
|
||||
self,
|
||||
mock_working_cls: MagicMock,
|
||||
lifecycle_manager: AgentLifecycleManager,
|
||||
) -> None:
|
||||
"""Spawn should populate initial state."""
|
||||
mock_working = AsyncMock()
|
||||
mock_working.set = AsyncMock()
|
||||
mock_working_cls.for_session = AsyncMock(return_value=mock_working)
|
||||
|
||||
result = await lifecycle_manager.spawn(
|
||||
project_id=uuid4(),
|
||||
agent_instance_id=uuid4(),
|
||||
session_id="sess-123",
|
||||
initial_state={"key1": "value1", "key2": "value2"},
|
||||
)
|
||||
|
||||
assert result.success is True
|
||||
assert result.data["initial_items"] == 2
|
||||
assert mock_working.set.call_count == 2
|
||||
|
||||
@patch("app.services.memory.integration.lifecycle.WorkingMemory")
|
||||
async def test_spawn_runs_hooks(
|
||||
self,
|
||||
mock_working_cls: MagicMock,
|
||||
lifecycle_manager: AgentLifecycleManager,
|
||||
) -> None:
|
||||
"""Spawn should run registered hooks."""
|
||||
mock_working = AsyncMock()
|
||||
mock_working.set = AsyncMock()
|
||||
mock_working_cls.for_session = AsyncMock(return_value=mock_working)
|
||||
|
||||
hook = AsyncMock()
|
||||
lifecycle_manager.hooks.on_spawn(hook)
|
||||
|
||||
await lifecycle_manager.spawn(
|
||||
project_id=uuid4(),
|
||||
agent_instance_id=uuid4(),
|
||||
session_id="sess-123",
|
||||
)
|
||||
|
||||
hook.assert_called_once()
|
||||
|
||||
|
||||
class TestAgentLifecycleManagerPause:
|
||||
"""Tests for AgentLifecycleManager.pause."""
|
||||
|
||||
@patch("app.services.memory.integration.lifecycle.WorkingMemory")
|
||||
async def test_pause_creates_checkpoint(
|
||||
self,
|
||||
mock_working_cls: MagicMock,
|
||||
lifecycle_manager: AgentLifecycleManager,
|
||||
) -> None:
|
||||
"""Pause should create checkpoint of working memory."""
|
||||
mock_working = AsyncMock()
|
||||
mock_working.list_keys = AsyncMock(return_value=["key1", "key2"])
|
||||
mock_working.get = AsyncMock(return_value={"data": "test"})
|
||||
mock_working.set = AsyncMock()
|
||||
mock_working_cls.for_session = AsyncMock(return_value=mock_working)
|
||||
|
||||
result = await lifecycle_manager.pause(
|
||||
project_id=uuid4(),
|
||||
agent_instance_id=uuid4(),
|
||||
session_id="sess-123",
|
||||
checkpoint_id="ckpt-001",
|
||||
)
|
||||
|
||||
assert result.success is True
|
||||
assert result.event_type == "pause"
|
||||
assert result.data["checkpoint_id"] == "ckpt-001"
|
||||
assert result.data["items_saved"] == 2
|
||||
|
||||
# Should save checkpoint with state
|
||||
mock_working.set.assert_called_once()
|
||||
call_args = mock_working.set.call_args
|
||||
# Check positional arg (first arg is key)
|
||||
assert "__checkpoint__ckpt-001" in call_args[0][0]
|
||||
|
||||
@patch("app.services.memory.integration.lifecycle.WorkingMemory")
|
||||
async def test_pause_generates_checkpoint_id(
|
||||
self,
|
||||
mock_working_cls: MagicMock,
|
||||
lifecycle_manager: AgentLifecycleManager,
|
||||
) -> None:
|
||||
"""Pause should generate checkpoint ID if not provided."""
|
||||
mock_working = AsyncMock()
|
||||
mock_working.list_keys = AsyncMock(return_value=[])
|
||||
mock_working.set = AsyncMock()
|
||||
mock_working_cls.for_session = AsyncMock(return_value=mock_working)
|
||||
|
||||
result = await lifecycle_manager.pause(
|
||||
project_id=uuid4(),
|
||||
agent_instance_id=uuid4(),
|
||||
session_id="sess-123",
|
||||
)
|
||||
|
||||
assert result.success is True
|
||||
assert "checkpoint_id" in result.data
|
||||
assert result.data["checkpoint_id"].startswith("checkpoint_")
|
||||
|
||||
|
||||
class TestAgentLifecycleManagerResume:
|
||||
"""Tests for AgentLifecycleManager.resume."""
|
||||
|
||||
@patch("app.services.memory.integration.lifecycle.WorkingMemory")
|
||||
async def test_resume_restores_checkpoint(
|
||||
self,
|
||||
mock_working_cls: MagicMock,
|
||||
lifecycle_manager: AgentLifecycleManager,
|
||||
) -> None:
|
||||
"""Resume should restore working memory from checkpoint."""
|
||||
checkpoint_data = {
|
||||
"state": {"key1": "value1", "key2": "value2"},
|
||||
"timestamp": datetime.now(UTC).isoformat(),
|
||||
"keys_count": 2,
|
||||
}
|
||||
|
||||
mock_working = AsyncMock()
|
||||
mock_working.list_keys = AsyncMock(return_value=[])
|
||||
mock_working.get = AsyncMock(return_value=checkpoint_data)
|
||||
mock_working.set = AsyncMock()
|
||||
mock_working.delete = AsyncMock()
|
||||
mock_working_cls.for_session = AsyncMock(return_value=mock_working)
|
||||
|
||||
result = await lifecycle_manager.resume(
|
||||
project_id=uuid4(),
|
||||
agent_instance_id=uuid4(),
|
||||
session_id="sess-123",
|
||||
checkpoint_id="ckpt-001",
|
||||
)
|
||||
|
||||
assert result.success is True
|
||||
assert result.event_type == "resume"
|
||||
assert result.data["items_restored"] == 2
|
||||
|
||||
@patch("app.services.memory.integration.lifecycle.WorkingMemory")
|
||||
async def test_resume_checkpoint_not_found(
|
||||
self,
|
||||
mock_working_cls: MagicMock,
|
||||
lifecycle_manager: AgentLifecycleManager,
|
||||
) -> None:
|
||||
"""Resume should fail if checkpoint not found."""
|
||||
mock_working = AsyncMock()
|
||||
mock_working.get = AsyncMock(return_value=None)
|
||||
mock_working_cls.for_session = AsyncMock(return_value=mock_working)
|
||||
|
||||
result = await lifecycle_manager.resume(
|
||||
project_id=uuid4(),
|
||||
agent_instance_id=uuid4(),
|
||||
session_id="sess-123",
|
||||
checkpoint_id="nonexistent",
|
||||
)
|
||||
|
||||
assert result.success is False
|
||||
assert "not found" in result.message.lower()
|
||||
|
||||
|
||||
class TestAgentLifecycleManagerTerminate:
|
||||
"""Tests for AgentLifecycleManager.terminate."""
|
||||
|
||||
@patch("app.services.memory.integration.lifecycle.EpisodicMemory")
|
||||
@patch("app.services.memory.integration.lifecycle.WorkingMemory")
|
||||
async def test_terminate_consolidates_to_episodic(
|
||||
self,
|
||||
mock_working_cls: MagicMock,
|
||||
mock_episodic_cls: MagicMock,
|
||||
lifecycle_manager: AgentLifecycleManager,
|
||||
) -> None:
|
||||
"""Terminate should consolidate working memory to episodic."""
|
||||
mock_working = AsyncMock()
|
||||
mock_working.list_keys = AsyncMock(return_value=["key1", "key2"])
|
||||
mock_working.get = AsyncMock(return_value="value")
|
||||
mock_working.delete = AsyncMock()
|
||||
mock_working_cls.for_session = AsyncMock(return_value=mock_working)
|
||||
|
||||
mock_episode = MagicMock()
|
||||
mock_episode.id = uuid4()
|
||||
|
||||
mock_episodic = AsyncMock()
|
||||
mock_episodic.record_episode = AsyncMock(return_value=mock_episode)
|
||||
mock_episodic_cls.create = AsyncMock(return_value=mock_episodic)
|
||||
|
||||
result = await lifecycle_manager.terminate(
|
||||
project_id=uuid4(),
|
||||
agent_instance_id=uuid4(),
|
||||
session_id="sess-123",
|
||||
task_description="Completed task",
|
||||
outcome=Outcome.SUCCESS,
|
||||
)
|
||||
|
||||
assert result.success is True
|
||||
assert result.event_type == "terminate"
|
||||
assert result.data["episode_id"] == str(mock_episode.id)
|
||||
assert result.data["state_items_consolidated"] == 2
|
||||
mock_episodic.record_episode.assert_called_once()
|
||||
|
||||
@patch("app.services.memory.integration.lifecycle.WorkingMemory")
|
||||
async def test_terminate_cleans_up_working(
|
||||
self,
|
||||
mock_working_cls: MagicMock,
|
||||
lifecycle_manager: AgentLifecycleManager,
|
||||
) -> None:
|
||||
"""Terminate should clean up working memory."""
|
||||
mock_working = AsyncMock()
|
||||
mock_working.list_keys = AsyncMock(return_value=["key1", "key2"])
|
||||
mock_working.get = AsyncMock(return_value="value")
|
||||
mock_working.delete = AsyncMock()
|
||||
mock_working_cls.for_session = AsyncMock(return_value=mock_working)
|
||||
|
||||
result = await lifecycle_manager.terminate(
|
||||
project_id=uuid4(),
|
||||
agent_instance_id=uuid4(),
|
||||
session_id="sess-123",
|
||||
consolidate_to_episodic=False,
|
||||
cleanup_working=True,
|
||||
)
|
||||
|
||||
assert result.success is True
|
||||
assert result.data["items_cleared"] == 2
|
||||
assert mock_working.delete.call_count == 2
|
||||
|
||||
|
||||
class TestAgentLifecycleManagerListCheckpoints:
|
||||
"""Tests for AgentLifecycleManager.list_checkpoints."""
|
||||
|
||||
@patch("app.services.memory.integration.lifecycle.WorkingMemory")
|
||||
async def test_list_checkpoints(
|
||||
self,
|
||||
mock_working_cls: MagicMock,
|
||||
lifecycle_manager: AgentLifecycleManager,
|
||||
) -> None:
|
||||
"""Should list available checkpoints."""
|
||||
mock_working = AsyncMock()
|
||||
mock_working.list_keys = AsyncMock(
|
||||
return_value=[
|
||||
"__checkpoint__ckpt-001",
|
||||
"__checkpoint__ckpt-002",
|
||||
"regular_key",
|
||||
]
|
||||
)
|
||||
mock_working.get = AsyncMock(
|
||||
return_value={
|
||||
"timestamp": "2024-01-01T00:00:00Z",
|
||||
"keys_count": 5,
|
||||
}
|
||||
)
|
||||
mock_working_cls.for_session = AsyncMock(return_value=mock_working)
|
||||
|
||||
checkpoints = await lifecycle_manager.list_checkpoints(
|
||||
project_id=uuid4(),
|
||||
agent_instance_id=uuid4(),
|
||||
session_id="sess-123",
|
||||
)
|
||||
|
||||
assert len(checkpoints) == 2
|
||||
assert checkpoints[0]["checkpoint_id"] == "ckpt-001"
|
||||
assert checkpoints[0]["keys_count"] == 5
|
||||
|
||||
|
||||
class TestGetLifecycleManager:
|
||||
"""Tests for factory function."""
|
||||
|
||||
async def test_creates_instance(self) -> None:
|
||||
"""Factory should create AgentLifecycleManager instance."""
|
||||
mock_session = MagicMock()
|
||||
|
||||
manager = await get_lifecycle_manager(mock_session)
|
||||
|
||||
assert isinstance(manager, AgentLifecycleManager)
|
||||
|
||||
async def test_with_custom_hooks(self) -> None:
|
||||
"""Factory should accept custom hooks."""
|
||||
mock_session = MagicMock()
|
||||
custom_hooks = LifecycleHooks()
|
||||
|
||||
manager = await get_lifecycle_manager(mock_session, hooks=custom_hooks)
|
||||
|
||||
assert manager.hooks is custom_hooks
|
||||
Reference in New Issue
Block a user