feat(memory): integrate memory system with context engine (#97)

## Changes

### New Context Type
- Add MEMORY to ContextType enum for agent memory context
- Create MemoryContext class with subtypes (working, episodic, semantic, procedural)
- Factory methods: from_working_memory, from_episodic_memory, from_semantic_memory, from_procedural_memory

### Memory Context Source
- MemoryContextSource service fetches relevant memories for context assembly
- Configurable fetch limits per memory type
- Parallel fetching from all memory types

### Agent Lifecycle Hooks
- AgentLifecycleManager handles spawn, pause, resume, terminate events
- spawn: Initialize working memory with optional initial state
- pause: Create checkpoint of working memory
- resume: Restore from checkpoint
- terminate: Consolidate working memory to episodic memory
- LifecycleHooks for custom extension points

### Context Engine Integration
- Add memory_query parameter to assemble_context()
- Add session_id and agent_type_id for memory scoping
- Memory budget allocation (15% by default)
- set_memory_source() for runtime configuration

### Tests
- 48 new tests for MemoryContext, MemoryContextSource, and lifecycle hooks
- All 108 memory-related tests passing
- mypy and ruff checks passing

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
2026-01-05 03:49:22 +01:00
parent 0b24d4c6cc
commit 30e5c68304
13 changed files with 2509 additions and 6 deletions

View File

@@ -0,0 +1,262 @@
# tests/unit/services/context/types/test_memory.py
"""Tests for MemoryContext type."""
from datetime import UTC, datetime
from unittest.mock import MagicMock
from uuid import uuid4
import pytest
from app.services.context.types import ContextType
from app.services.context.types.memory import MemoryContext, MemorySubtype
class TestMemorySubtype:
"""Tests for MemorySubtype enum."""
def test_all_types_defined(self) -> None:
"""All memory subtypes should be defined."""
assert MemorySubtype.WORKING == "working"
assert MemorySubtype.EPISODIC == "episodic"
assert MemorySubtype.SEMANTIC == "semantic"
assert MemorySubtype.PROCEDURAL == "procedural"
def test_enum_values(self) -> None:
"""Enum values should match strings."""
assert MemorySubtype.WORKING.value == "working"
assert MemorySubtype("episodic") == MemorySubtype.EPISODIC
class TestMemoryContext:
"""Tests for MemoryContext class."""
def test_get_type_returns_memory(self) -> None:
"""get_type should return MEMORY."""
ctx = MemoryContext(content="test", source="test_source")
assert ctx.get_type() == ContextType.MEMORY
def test_default_values(self) -> None:
"""Default values should be set correctly."""
ctx = MemoryContext(content="test", source="test_source")
assert ctx.memory_subtype == MemorySubtype.EPISODIC
assert ctx.memory_id is None
assert ctx.relevance_score == 0.0
assert ctx.importance == 0.5
def test_to_dict_includes_memory_fields(self) -> None:
"""to_dict should include memory-specific fields."""
ctx = MemoryContext(
content="test content",
source="test_source",
memory_subtype=MemorySubtype.SEMANTIC,
memory_id="mem-123",
relevance_score=0.8,
subject="User",
predicate="prefers",
object_value="dark mode",
)
data = ctx.to_dict()
assert data["memory_subtype"] == "semantic"
assert data["memory_id"] == "mem-123"
assert data["relevance_score"] == 0.8
assert data["subject"] == "User"
assert data["predicate"] == "prefers"
assert data["object_value"] == "dark mode"
def test_from_dict(self) -> None:
"""from_dict should create correct MemoryContext."""
data = {
"content": "test content",
"source": "test_source",
"timestamp": "2024-01-01T00:00:00+00:00",
"memory_subtype": "semantic",
"memory_id": "mem-123",
"relevance_score": 0.8,
"subject": "Test",
}
ctx = MemoryContext.from_dict(data)
assert ctx.content == "test content"
assert ctx.memory_subtype == MemorySubtype.SEMANTIC
assert ctx.memory_id == "mem-123"
assert ctx.subject == "Test"
class TestMemoryContextFromWorkingMemory:
"""Tests for MemoryContext.from_working_memory."""
def test_creates_working_memory_context(self) -> None:
"""Should create working memory context from key/value."""
ctx = MemoryContext.from_working_memory(
key="user_preferences",
value={"theme": "dark"},
source="working:sess-123",
query="preferences",
)
assert ctx.memory_subtype == MemorySubtype.WORKING
assert ctx.key == "user_preferences"
assert "{'theme': 'dark'}" in ctx.content
assert ctx.relevance_score == 1.0 # Working memory is always relevant
assert ctx.importance == 0.8 # Higher importance
def test_string_value(self) -> None:
"""Should handle string values."""
ctx = MemoryContext.from_working_memory(
key="current_task",
value="Build authentication",
)
assert ctx.content == "Build authentication"
class TestMemoryContextFromEpisodicMemory:
"""Tests for MemoryContext.from_episodic_memory."""
def test_creates_episodic_memory_context(self) -> None:
"""Should create episodic memory context from episode."""
episode = MagicMock()
episode.id = uuid4()
episode.task_description = "Implemented login feature"
episode.task_type = "feature_implementation"
episode.outcome = MagicMock(value="success")
episode.importance_score = 0.9
episode.session_id = "sess-123"
episode.occurred_at = datetime.now(UTC)
episode.lessons_learned = ["Use proper validation"]
ctx = MemoryContext.from_episodic_memory(episode, query="login")
assert ctx.memory_subtype == MemorySubtype.EPISODIC
assert ctx.memory_id == str(episode.id)
assert ctx.content == "Implemented login feature"
assert ctx.task_type == "feature_implementation"
assert ctx.outcome == "success"
assert ctx.importance == 0.9
def test_handles_missing_outcome(self) -> None:
"""Should handle episodes with no outcome."""
episode = MagicMock()
episode.id = uuid4()
episode.task_description = "WIP task"
episode.outcome = None
episode.importance_score = 0.5
episode.occurred_at = None
ctx = MemoryContext.from_episodic_memory(episode)
assert ctx.outcome is None
class TestMemoryContextFromSemanticMemory:
"""Tests for MemoryContext.from_semantic_memory."""
def test_creates_semantic_memory_context(self) -> None:
"""Should create semantic memory context from fact."""
fact = MagicMock()
fact.id = uuid4()
fact.subject = "User"
fact.predicate = "prefers"
fact.object = "dark mode"
fact.confidence = 0.95
ctx = MemoryContext.from_semantic_memory(fact, query="user preferences")
assert ctx.memory_subtype == MemorySubtype.SEMANTIC
assert ctx.memory_id == str(fact.id)
assert ctx.content == "User prefers dark mode"
assert ctx.subject == "User"
assert ctx.predicate == "prefers"
assert ctx.object_value == "dark mode"
assert ctx.relevance_score == 0.95
class TestMemoryContextFromProceduralMemory:
"""Tests for MemoryContext.from_procedural_memory."""
def test_creates_procedural_memory_context(self) -> None:
"""Should create procedural memory context from procedure."""
procedure = MagicMock()
procedure.id = uuid4()
procedure.name = "Deploy to Production"
procedure.trigger_pattern = "When deploying to production"
procedure.steps = [
{"action": "run_tests"},
{"action": "build_docker"},
{"action": "deploy"},
]
procedure.success_rate = 0.85
procedure.success_count = 10
procedure.failure_count = 2
ctx = MemoryContext.from_procedural_memory(procedure, query="deploy")
assert ctx.memory_subtype == MemorySubtype.PROCEDURAL
assert ctx.memory_id == str(procedure.id)
assert "Deploy to Production" in ctx.content
assert "When deploying to production" in ctx.content
assert ctx.trigger == "When deploying to production"
assert ctx.success_rate == 0.85
assert ctx.metadata["steps_count"] == 3
assert ctx.metadata["execution_count"] == 12
class TestMemoryContextHelpers:
"""Tests for MemoryContext helper methods."""
def test_is_working_memory(self) -> None:
"""is_working_memory should return True for working memory."""
ctx = MemoryContext(
content="test",
source="test",
memory_subtype=MemorySubtype.WORKING,
)
assert ctx.is_working_memory() is True
assert ctx.is_episodic_memory() is False
def test_is_episodic_memory(self) -> None:
"""is_episodic_memory should return True for episodic memory."""
ctx = MemoryContext(
content="test",
source="test",
memory_subtype=MemorySubtype.EPISODIC,
)
assert ctx.is_episodic_memory() is True
assert ctx.is_semantic_memory() is False
def test_is_semantic_memory(self) -> None:
"""is_semantic_memory should return True for semantic memory."""
ctx = MemoryContext(
content="test",
source="test",
memory_subtype=MemorySubtype.SEMANTIC,
)
assert ctx.is_semantic_memory() is True
assert ctx.is_procedural_memory() is False
def test_is_procedural_memory(self) -> None:
"""is_procedural_memory should return True for procedural memory."""
ctx = MemoryContext(
content="test",
source="test",
memory_subtype=MemorySubtype.PROCEDURAL,
)
assert ctx.is_procedural_memory() is True
assert ctx.is_working_memory() is False
def test_get_formatted_source(self) -> None:
"""get_formatted_source should return formatted string."""
ctx = MemoryContext(
content="test",
source="episodic:12345678-1234-1234-1234-123456789012",
memory_subtype=MemorySubtype.EPISODIC,
memory_id="12345678-1234-1234-1234-123456789012",
)
formatted = ctx.get_formatted_source()
assert "[episodic]" in formatted
assert "12345678..." in formatted

View File

@@ -0,0 +1,2 @@
# tests/unit/services/memory/integration/__init__.py
"""Tests for memory integration module."""

View File

@@ -0,0 +1,322 @@
# tests/unit/services/memory/integration/test_context_source.py
"""Tests for MemoryContextSource service."""
from datetime import UTC, datetime
from unittest.mock import AsyncMock, MagicMock, patch
from uuid import uuid4
import pytest
from app.services.context.types.memory import MemorySubtype
from app.services.memory.integration.context_source import (
MemoryContextSource,
MemoryFetchConfig,
MemoryFetchResult,
get_memory_context_source,
)
pytestmark = pytest.mark.asyncio(loop_scope="function")
@pytest.fixture
def mock_session() -> MagicMock:
"""Create mock database session."""
return MagicMock()
@pytest.fixture
def context_source(mock_session: MagicMock) -> MemoryContextSource:
"""Create MemoryContextSource instance."""
return MemoryContextSource(session=mock_session)
class TestMemoryFetchConfig:
"""Tests for MemoryFetchConfig."""
def test_default_values(self) -> None:
"""Default config values should be set correctly."""
config = MemoryFetchConfig()
assert config.working_limit == 10
assert config.episodic_limit == 10
assert config.semantic_limit == 15
assert config.procedural_limit == 5
assert config.episodic_days_back == 30
assert config.min_relevance == 0.3
assert config.include_working is True
assert config.include_episodic is True
assert config.include_semantic is True
assert config.include_procedural is True
def test_custom_values(self) -> None:
"""Custom config values should be respected."""
config = MemoryFetchConfig(
working_limit=5,
include_working=False,
)
assert config.working_limit == 5
assert config.include_working is False
class TestMemoryFetchResult:
"""Tests for MemoryFetchResult."""
def test_stores_results(self) -> None:
"""Result should store contexts and metadata."""
result = MemoryFetchResult(
contexts=[],
by_type={"working": 0, "episodic": 5, "semantic": 3, "procedural": 0},
fetch_time_ms=15.5,
query="test query",
)
assert result.contexts == []
assert result.by_type["episodic"] == 5
assert result.fetch_time_ms == 15.5
assert result.query == "test query"
class TestMemoryContextSource:
"""Tests for MemoryContextSource service."""
async def test_fetch_context_empty_when_no_sources(
self,
context_source: MemoryContextSource,
) -> None:
"""fetch_context should return empty when all sources fail."""
config = MemoryFetchConfig(
include_working=False,
include_episodic=False,
include_semantic=False,
include_procedural=False,
)
result = await context_source.fetch_context(
query="test",
project_id=uuid4(),
config=config,
)
assert len(result.contexts) == 0
assert result.by_type == {
"working": 0,
"episodic": 0,
"semantic": 0,
"procedural": 0,
}
@patch("app.services.memory.integration.context_source.WorkingMemory")
async def test_fetch_working_memory(
self,
mock_working_cls: MagicMock,
context_source: MemoryContextSource,
) -> None:
"""Should fetch working memory when session_id provided."""
# Setup mock - both keys should match the query "task"
mock_working = AsyncMock()
mock_working.list_keys = AsyncMock(return_value=["current_task", "task_state"])
mock_working.get = AsyncMock(side_effect=lambda k: {"key": k, "value": "test"})
mock_working_cls.for_session = AsyncMock(return_value=mock_working)
config = MemoryFetchConfig(
include_episodic=False,
include_semantic=False,
include_procedural=False,
)
result = await context_source.fetch_context(
query="task", # Both keys contain "task"
project_id=uuid4(),
session_id="sess-123",
config=config,
)
assert result.by_type["working"] == 2
assert all(
c.memory_subtype == MemorySubtype.WORKING for c in result.contexts
)
@patch("app.services.memory.integration.context_source.EpisodicMemory")
async def test_fetch_episodic_memory(
self,
mock_episodic_cls: MagicMock,
context_source: MemoryContextSource,
) -> None:
"""Should fetch episodic memory."""
# Setup mock episode
mock_episode = MagicMock()
mock_episode.id = uuid4()
mock_episode.task_description = "Completed login feature"
mock_episode.task_type = "feature"
mock_episode.outcome = MagicMock(value="success")
mock_episode.importance_score = 0.8
mock_episode.occurred_at = datetime.now(UTC)
mock_episode.lessons_learned = []
mock_episodic = AsyncMock()
mock_episodic.search_similar = AsyncMock(return_value=[mock_episode])
mock_episodic.get_recent = AsyncMock(return_value=[])
mock_episodic_cls.create = AsyncMock(return_value=mock_episodic)
config = MemoryFetchConfig(
include_working=False,
include_semantic=False,
include_procedural=False,
)
result = await context_source.fetch_context(
query="login",
project_id=uuid4(),
config=config,
)
assert result.by_type["episodic"] == 1
assert result.contexts[0].memory_subtype == MemorySubtype.EPISODIC
assert "Completed login feature" in result.contexts[0].content
@patch("app.services.memory.integration.context_source.SemanticMemory")
async def test_fetch_semantic_memory(
self,
mock_semantic_cls: MagicMock,
context_source: MemoryContextSource,
) -> None:
"""Should fetch semantic memory."""
# Setup mock fact
mock_fact = MagicMock()
mock_fact.id = uuid4()
mock_fact.subject = "User"
mock_fact.predicate = "prefers"
mock_fact.object = "dark mode"
mock_fact.confidence = 0.9
mock_semantic = AsyncMock()
mock_semantic.search_facts = AsyncMock(return_value=[mock_fact])
mock_semantic_cls.create = AsyncMock(return_value=mock_semantic)
config = MemoryFetchConfig(
include_working=False,
include_episodic=False,
include_procedural=False,
)
result = await context_source.fetch_context(
query="preferences",
project_id=uuid4(),
config=config,
)
assert result.by_type["semantic"] == 1
assert result.contexts[0].memory_subtype == MemorySubtype.SEMANTIC
assert "User prefers dark mode" in result.contexts[0].content
@patch("app.services.memory.integration.context_source.ProceduralMemory")
async def test_fetch_procedural_memory(
self,
mock_procedural_cls: MagicMock,
context_source: MemoryContextSource,
) -> None:
"""Should fetch procedural memory."""
# Setup mock procedure
mock_proc = MagicMock()
mock_proc.id = uuid4()
mock_proc.name = "Deploy"
mock_proc.trigger_pattern = "When deploying"
mock_proc.steps = [{"action": "build"}, {"action": "test"}]
mock_proc.success_rate = 0.9
mock_proc.success_count = 9
mock_proc.failure_count = 1
mock_procedural = AsyncMock()
mock_procedural.find_matching = AsyncMock(return_value=[mock_proc])
mock_procedural_cls.create = AsyncMock(return_value=mock_procedural)
config = MemoryFetchConfig(
include_working=False,
include_episodic=False,
include_semantic=False,
)
result = await context_source.fetch_context(
query="deploy",
project_id=uuid4(),
config=config,
)
assert result.by_type["procedural"] == 1
assert result.contexts[0].memory_subtype == MemorySubtype.PROCEDURAL
assert "Deploy" in result.contexts[0].content
async def test_results_sorted_by_relevance(
self,
context_source: MemoryContextSource,
) -> None:
"""Results should be sorted by relevance score."""
with patch.object(
context_source, "_fetch_episodic"
) as mock_ep, patch.object(
context_source, "_fetch_semantic"
) as mock_sem:
# Create contexts with different relevance scores
from app.services.context.types.memory import MemoryContext
ctx_low = MemoryContext(
content="low relevance",
source="test",
relevance_score=0.3,
)
ctx_high = MemoryContext(
content="high relevance",
source="test",
relevance_score=0.9,
)
mock_ep.return_value = [ctx_low]
mock_sem.return_value = [ctx_high]
config = MemoryFetchConfig(
include_working=False,
include_procedural=False,
)
result = await context_source.fetch_context(
query="test",
project_id=uuid4(),
config=config,
)
# Higher relevance should come first
assert result.contexts[0].relevance_score == 0.9
assert result.contexts[1].relevance_score == 0.3
@patch("app.services.memory.integration.context_source.WorkingMemory")
async def test_fetch_all_working(
self,
mock_working_cls: MagicMock,
context_source: MemoryContextSource,
) -> None:
"""fetch_all_working should return all working memory items."""
mock_working = AsyncMock()
mock_working.list_keys = AsyncMock(return_value=["key1", "key2", "key3"])
mock_working.get = AsyncMock(return_value="value")
mock_working_cls.for_session = AsyncMock(return_value=mock_working)
contexts = await context_source.fetch_all_working(
session_id="sess-123",
project_id=uuid4(),
)
assert len(contexts) == 3
assert all(c.memory_subtype == MemorySubtype.WORKING for c in contexts)
class TestGetMemoryContextSource:
"""Tests for factory function."""
async def test_creates_instance(self) -> None:
"""Factory should create MemoryContextSource instance."""
mock_session = MagicMock()
source = await get_memory_context_source(mock_session)
assert isinstance(source, MemoryContextSource)

View File

@@ -0,0 +1,471 @@
# tests/unit/services/memory/integration/test_lifecycle.py
"""Tests for Agent Lifecycle Hooks."""
from datetime import UTC, datetime
from unittest.mock import AsyncMock, MagicMock, patch
from uuid import uuid4
import pytest
from app.services.memory.integration.lifecycle import (
AgentLifecycleManager,
LifecycleEvent,
LifecycleHooks,
LifecycleResult,
get_lifecycle_manager,
)
from app.services.memory.types import Outcome
pytestmark = pytest.mark.asyncio(loop_scope="function")
@pytest.fixture
def mock_session() -> MagicMock:
"""Create mock database session."""
return MagicMock()
@pytest.fixture
def lifecycle_hooks() -> LifecycleHooks:
"""Create lifecycle hooks instance."""
return LifecycleHooks()
@pytest.fixture
def lifecycle_manager(mock_session: MagicMock) -> AgentLifecycleManager:
"""Create lifecycle manager instance."""
return AgentLifecycleManager(session=mock_session)
class TestLifecycleEvent:
"""Tests for LifecycleEvent dataclass."""
def test_creates_event(self) -> None:
"""Should create event with required fields."""
project_id = uuid4()
agent_id = uuid4()
event = LifecycleEvent(
event_type="spawn",
project_id=project_id,
agent_instance_id=agent_id,
)
assert event.event_type == "spawn"
assert event.project_id == project_id
assert event.agent_instance_id == agent_id
assert event.timestamp is not None
assert event.metadata == {}
def test_with_optional_fields(self) -> None:
"""Should include optional fields."""
event = LifecycleEvent(
event_type="terminate",
project_id=uuid4(),
agent_instance_id=uuid4(),
session_id="sess-123",
metadata={"reason": "completed"},
)
assert event.session_id == "sess-123"
assert event.metadata["reason"] == "completed"
class TestLifecycleResult:
"""Tests for LifecycleResult dataclass."""
def test_success_result(self) -> None:
"""Should create success result."""
result = LifecycleResult(
success=True,
event_type="spawn",
message="Agent spawned",
data={"session_id": "sess-123"},
duration_ms=10.5,
)
assert result.success is True
assert result.event_type == "spawn"
assert result.data["session_id"] == "sess-123"
def test_failure_result(self) -> None:
"""Should create failure result."""
result = LifecycleResult(
success=False,
event_type="resume",
message="Checkpoint not found",
)
assert result.success is False
assert result.message == "Checkpoint not found"
class TestLifecycleHooks:
"""Tests for LifecycleHooks class."""
def test_register_spawn_hook(self, lifecycle_hooks: LifecycleHooks) -> None:
"""Should register spawn hook."""
async def my_hook(event: LifecycleEvent) -> None:
pass
result = lifecycle_hooks.on_spawn(my_hook)
assert result is my_hook
assert my_hook in lifecycle_hooks._spawn_hooks
def test_register_all_hooks(self, lifecycle_hooks: LifecycleHooks) -> None:
"""Should register hooks for all event types."""
hooks = [
lifecycle_hooks.on_spawn(AsyncMock()),
lifecycle_hooks.on_pause(AsyncMock()),
lifecycle_hooks.on_resume(AsyncMock()),
lifecycle_hooks.on_terminate(AsyncMock()),
]
assert len(lifecycle_hooks._spawn_hooks) == 1
assert len(lifecycle_hooks._pause_hooks) == 1
assert len(lifecycle_hooks._resume_hooks) == 1
assert len(lifecycle_hooks._terminate_hooks) == 1
async def test_run_spawn_hooks(self, lifecycle_hooks: LifecycleHooks) -> None:
"""Should run all spawn hooks."""
hook1 = AsyncMock()
hook2 = AsyncMock()
lifecycle_hooks.on_spawn(hook1)
lifecycle_hooks.on_spawn(hook2)
event = LifecycleEvent(
event_type="spawn",
project_id=uuid4(),
agent_instance_id=uuid4(),
)
await lifecycle_hooks.run_spawn_hooks(event)
hook1.assert_called_once_with(event)
hook2.assert_called_once_with(event)
async def test_hook_failure_doesnt_stop_others(
self, lifecycle_hooks: LifecycleHooks
) -> None:
"""Hook failure should not stop other hooks from running."""
hook1 = AsyncMock(side_effect=ValueError("Oops"))
hook2 = AsyncMock()
lifecycle_hooks.on_pause(hook1)
lifecycle_hooks.on_pause(hook2)
event = LifecycleEvent(
event_type="pause",
project_id=uuid4(),
agent_instance_id=uuid4(),
)
await lifecycle_hooks.run_pause_hooks(event)
# hook2 should still be called even though hook1 failed
hook2.assert_called_once()
class TestAgentLifecycleManagerSpawn:
"""Tests for AgentLifecycleManager.spawn."""
@patch("app.services.memory.integration.lifecycle.WorkingMemory")
async def test_spawn_creates_working_memory(
self,
mock_working_cls: MagicMock,
lifecycle_manager: AgentLifecycleManager,
) -> None:
"""Spawn should create working memory for session."""
mock_working = AsyncMock()
mock_working.set = AsyncMock()
mock_working_cls.for_session = AsyncMock(return_value=mock_working)
result = await lifecycle_manager.spawn(
project_id=uuid4(),
agent_instance_id=uuid4(),
session_id="sess-123",
)
assert result.success is True
assert result.event_type == "spawn"
mock_working_cls.for_session.assert_called_once()
@patch("app.services.memory.integration.lifecycle.WorkingMemory")
async def test_spawn_with_initial_state(
self,
mock_working_cls: MagicMock,
lifecycle_manager: AgentLifecycleManager,
) -> None:
"""Spawn should populate initial state."""
mock_working = AsyncMock()
mock_working.set = AsyncMock()
mock_working_cls.for_session = AsyncMock(return_value=mock_working)
result = await lifecycle_manager.spawn(
project_id=uuid4(),
agent_instance_id=uuid4(),
session_id="sess-123",
initial_state={"key1": "value1", "key2": "value2"},
)
assert result.success is True
assert result.data["initial_items"] == 2
assert mock_working.set.call_count == 2
@patch("app.services.memory.integration.lifecycle.WorkingMemory")
async def test_spawn_runs_hooks(
self,
mock_working_cls: MagicMock,
lifecycle_manager: AgentLifecycleManager,
) -> None:
"""Spawn should run registered hooks."""
mock_working = AsyncMock()
mock_working.set = AsyncMock()
mock_working_cls.for_session = AsyncMock(return_value=mock_working)
hook = AsyncMock()
lifecycle_manager.hooks.on_spawn(hook)
await lifecycle_manager.spawn(
project_id=uuid4(),
agent_instance_id=uuid4(),
session_id="sess-123",
)
hook.assert_called_once()
class TestAgentLifecycleManagerPause:
"""Tests for AgentLifecycleManager.pause."""
@patch("app.services.memory.integration.lifecycle.WorkingMemory")
async def test_pause_creates_checkpoint(
self,
mock_working_cls: MagicMock,
lifecycle_manager: AgentLifecycleManager,
) -> None:
"""Pause should create checkpoint of working memory."""
mock_working = AsyncMock()
mock_working.list_keys = AsyncMock(return_value=["key1", "key2"])
mock_working.get = AsyncMock(return_value={"data": "test"})
mock_working.set = AsyncMock()
mock_working_cls.for_session = AsyncMock(return_value=mock_working)
result = await lifecycle_manager.pause(
project_id=uuid4(),
agent_instance_id=uuid4(),
session_id="sess-123",
checkpoint_id="ckpt-001",
)
assert result.success is True
assert result.event_type == "pause"
assert result.data["checkpoint_id"] == "ckpt-001"
assert result.data["items_saved"] == 2
# Should save checkpoint with state
mock_working.set.assert_called_once()
call_args = mock_working.set.call_args
# Check positional arg (first arg is key)
assert "__checkpoint__ckpt-001" in call_args[0][0]
@patch("app.services.memory.integration.lifecycle.WorkingMemory")
async def test_pause_generates_checkpoint_id(
self,
mock_working_cls: MagicMock,
lifecycle_manager: AgentLifecycleManager,
) -> None:
"""Pause should generate checkpoint ID if not provided."""
mock_working = AsyncMock()
mock_working.list_keys = AsyncMock(return_value=[])
mock_working.set = AsyncMock()
mock_working_cls.for_session = AsyncMock(return_value=mock_working)
result = await lifecycle_manager.pause(
project_id=uuid4(),
agent_instance_id=uuid4(),
session_id="sess-123",
)
assert result.success is True
assert "checkpoint_id" in result.data
assert result.data["checkpoint_id"].startswith("checkpoint_")
class TestAgentLifecycleManagerResume:
"""Tests for AgentLifecycleManager.resume."""
@patch("app.services.memory.integration.lifecycle.WorkingMemory")
async def test_resume_restores_checkpoint(
self,
mock_working_cls: MagicMock,
lifecycle_manager: AgentLifecycleManager,
) -> None:
"""Resume should restore working memory from checkpoint."""
checkpoint_data = {
"state": {"key1": "value1", "key2": "value2"},
"timestamp": datetime.now(UTC).isoformat(),
"keys_count": 2,
}
mock_working = AsyncMock()
mock_working.list_keys = AsyncMock(return_value=[])
mock_working.get = AsyncMock(return_value=checkpoint_data)
mock_working.set = AsyncMock()
mock_working.delete = AsyncMock()
mock_working_cls.for_session = AsyncMock(return_value=mock_working)
result = await lifecycle_manager.resume(
project_id=uuid4(),
agent_instance_id=uuid4(),
session_id="sess-123",
checkpoint_id="ckpt-001",
)
assert result.success is True
assert result.event_type == "resume"
assert result.data["items_restored"] == 2
@patch("app.services.memory.integration.lifecycle.WorkingMemory")
async def test_resume_checkpoint_not_found(
self,
mock_working_cls: MagicMock,
lifecycle_manager: AgentLifecycleManager,
) -> None:
"""Resume should fail if checkpoint not found."""
mock_working = AsyncMock()
mock_working.get = AsyncMock(return_value=None)
mock_working_cls.for_session = AsyncMock(return_value=mock_working)
result = await lifecycle_manager.resume(
project_id=uuid4(),
agent_instance_id=uuid4(),
session_id="sess-123",
checkpoint_id="nonexistent",
)
assert result.success is False
assert "not found" in result.message.lower()
class TestAgentLifecycleManagerTerminate:
"""Tests for AgentLifecycleManager.terminate."""
@patch("app.services.memory.integration.lifecycle.EpisodicMemory")
@patch("app.services.memory.integration.lifecycle.WorkingMemory")
async def test_terminate_consolidates_to_episodic(
self,
mock_working_cls: MagicMock,
mock_episodic_cls: MagicMock,
lifecycle_manager: AgentLifecycleManager,
) -> None:
"""Terminate should consolidate working memory to episodic."""
mock_working = AsyncMock()
mock_working.list_keys = AsyncMock(return_value=["key1", "key2"])
mock_working.get = AsyncMock(return_value="value")
mock_working.delete = AsyncMock()
mock_working_cls.for_session = AsyncMock(return_value=mock_working)
mock_episode = MagicMock()
mock_episode.id = uuid4()
mock_episodic = AsyncMock()
mock_episodic.record_episode = AsyncMock(return_value=mock_episode)
mock_episodic_cls.create = AsyncMock(return_value=mock_episodic)
result = await lifecycle_manager.terminate(
project_id=uuid4(),
agent_instance_id=uuid4(),
session_id="sess-123",
task_description="Completed task",
outcome=Outcome.SUCCESS,
)
assert result.success is True
assert result.event_type == "terminate"
assert result.data["episode_id"] == str(mock_episode.id)
assert result.data["state_items_consolidated"] == 2
mock_episodic.record_episode.assert_called_once()
@patch("app.services.memory.integration.lifecycle.WorkingMemory")
async def test_terminate_cleans_up_working(
self,
mock_working_cls: MagicMock,
lifecycle_manager: AgentLifecycleManager,
) -> None:
"""Terminate should clean up working memory."""
mock_working = AsyncMock()
mock_working.list_keys = AsyncMock(return_value=["key1", "key2"])
mock_working.get = AsyncMock(return_value="value")
mock_working.delete = AsyncMock()
mock_working_cls.for_session = AsyncMock(return_value=mock_working)
result = await lifecycle_manager.terminate(
project_id=uuid4(),
agent_instance_id=uuid4(),
session_id="sess-123",
consolidate_to_episodic=False,
cleanup_working=True,
)
assert result.success is True
assert result.data["items_cleared"] == 2
assert mock_working.delete.call_count == 2
class TestAgentLifecycleManagerListCheckpoints:
"""Tests for AgentLifecycleManager.list_checkpoints."""
@patch("app.services.memory.integration.lifecycle.WorkingMemory")
async def test_list_checkpoints(
self,
mock_working_cls: MagicMock,
lifecycle_manager: AgentLifecycleManager,
) -> None:
"""Should list available checkpoints."""
mock_working = AsyncMock()
mock_working.list_keys = AsyncMock(
return_value=[
"__checkpoint__ckpt-001",
"__checkpoint__ckpt-002",
"regular_key",
]
)
mock_working.get = AsyncMock(
return_value={
"timestamp": "2024-01-01T00:00:00Z",
"keys_count": 5,
}
)
mock_working_cls.for_session = AsyncMock(return_value=mock_working)
checkpoints = await lifecycle_manager.list_checkpoints(
project_id=uuid4(),
agent_instance_id=uuid4(),
session_id="sess-123",
)
assert len(checkpoints) == 2
assert checkpoints[0]["checkpoint_id"] == "ckpt-001"
assert checkpoints[0]["keys_count"] == 5
class TestGetLifecycleManager:
"""Tests for factory function."""
async def test_creates_instance(self) -> None:
"""Factory should create AgentLifecycleManager instance."""
mock_session = MagicMock()
manager = await get_lifecycle_manager(mock_session)
assert isinstance(manager, AgentLifecycleManager)
async def test_with_custom_hooks(self) -> None:
"""Factory should accept custom hooks."""
mock_session = MagicMock()
custom_hooks = LifecycleHooks()
manager = await get_lifecycle_manager(mock_session, hooks=custom_hooks)
assert manager.hooks is custom_hooks