forked from cardosofelipe/fast-next-template
Auto-fixed linting errors and formatting issues:
- Removed unused imports (F401): pytest, Any, AnalysisType, MemoryType, OutcomeType
- Removed unused variable (F841): `hooks` variable in test
- Applied consistent formatting across memory service and test files

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
473 lines
16 KiB
Python
473 lines
16 KiB
Python
# tests/unit/services/memory/integration/test_lifecycle.py
|
|
"""Tests for Agent Lifecycle Hooks."""
|
|
|
|
from datetime import UTC, datetime
|
|
from unittest.mock import AsyncMock, MagicMock, patch
|
|
from uuid import uuid4
|
|
|
|
import pytest
|
|
|
|
from app.services.memory.integration.lifecycle import (
|
|
AgentLifecycleManager,
|
|
LifecycleEvent,
|
|
LifecycleHooks,
|
|
LifecycleResult,
|
|
get_lifecycle_manager,
|
|
)
|
|
from app.services.memory.types import Outcome
|
|
|
|
# Module-wide marks: run the async tests below under pytest-asyncio with a
# function-scoped event loop. Kept as a list so further module-level marks can
# be added next to it.
pytestmark = [pytest.mark.asyncio(loop_scope="function")]
|
|
|
|
|
|
@pytest.fixture()
def mock_session() -> MagicMock:
    """Provide a stand-in database session.

    The mock is returned unconfigured; tests that need specific session
    behavior must set it up themselves.
    """
    session = MagicMock()
    return session
|
|
|
|
|
|
@pytest.fixture()
def lifecycle_hooks() -> LifecycleHooks:
    """Provide a newly constructed ``LifecycleHooks`` registry per test."""
    hooks = LifecycleHooks()
    return hooks
|
|
|
|
|
|
@pytest.fixture()
def lifecycle_manager(mock_session: MagicMock) -> AgentLifecycleManager:
    """Build an ``AgentLifecycleManager`` bound to the mocked session fixture."""
    manager = AgentLifecycleManager(session=mock_session)
    return manager
|
|
|
|
|
|
class TestLifecycleEvent:
    """Construction tests for the ``LifecycleEvent`` dataclass."""

    def test_creates_event(self) -> None:
        """Required fields are stored and the defaulted fields are filled in."""
        pid, aid = uuid4(), uuid4()

        evt = LifecycleEvent(
            event_type="spawn",
            project_id=pid,
            agent_instance_id=aid,
        )

        # Values passed to the constructor come back unchanged ...
        assert evt.event_type == "spawn"
        assert evt.project_id == pid
        assert evt.agent_instance_id == aid
        # ... while the fields we did not pass receive defaults.
        assert evt.timestamp is not None
        assert evt.metadata == {}

    def test_with_optional_fields(self) -> None:
        """Optional session id and metadata are kept when supplied."""
        meta = {"reason": "completed"}

        evt = LifecycleEvent(
            event_type="terminate",
            project_id=uuid4(),
            agent_instance_id=uuid4(),
            session_id="sess-123",
            metadata=meta,
        )

        assert evt.session_id == "sess-123"
        assert evt.metadata["reason"] == "completed"
|
|
|
|
|
|
class TestLifecycleResult:
    """Construction tests for the ``LifecycleResult`` dataclass."""

    def test_success_result(self) -> None:
        """A success result exposes its flag, event type and data payload."""
        payload = {"session_id": "sess-123"}

        res = LifecycleResult(
            success=True,
            event_type="spawn",
            message="Agent spawned",
            data=payload,
            duration_ms=10.5,
        )

        assert res.success is True
        assert res.event_type == "spawn"
        assert res.data["session_id"] == "sess-123"

    def test_failure_result(self) -> None:
        """A failure result can be built from just flag, type and message."""
        res = LifecycleResult(
            success=False,
            event_type="resume",
            message="Checkpoint not found",
        )

        assert res.success is False
        assert res.message == "Checkpoint not found"
|
|
|
|
|
|
class TestLifecycleHooks:
    """Tests for LifecycleHooks registration and dispatch."""

    def test_register_spawn_hook(self, lifecycle_hooks: LifecycleHooks) -> None:
        """Should register spawn hook and hand the same callable back."""

        async def my_hook(event: LifecycleEvent) -> None:
            pass

        result = lifecycle_hooks.on_spawn(my_hook)

        # Returning the argument unchanged is what permits decorator-style use.
        assert result is my_hook
        assert my_hook in lifecycle_hooks._spawn_hooks

    def test_register_all_hooks(self, lifecycle_hooks: LifecycleHooks) -> None:
        """Each on_* registrar should file its hook under the matching event."""
        spawn_hook = AsyncMock()
        pause_hook = AsyncMock()
        resume_hook = AsyncMock()
        terminate_hook = AsyncMock()

        # Registrars are invoked for their side effect only; their return
        # values are not needed here.
        lifecycle_hooks.on_spawn(spawn_hook)
        lifecycle_hooks.on_pause(pause_hook)
        lifecycle_hooks.on_resume(resume_hook)
        lifecycle_hooks.on_terminate(terminate_hook)

        assert len(lifecycle_hooks._spawn_hooks) == 1
        assert len(lifecycle_hooks._pause_hooks) == 1
        assert len(lifecycle_hooks._resume_hooks) == 1
        assert len(lifecycle_hooks._terminate_hooks) == 1

        # Length checks alone would not catch a hook routed to the wrong event.
        assert spawn_hook in lifecycle_hooks._spawn_hooks
        assert pause_hook in lifecycle_hooks._pause_hooks
        assert resume_hook in lifecycle_hooks._resume_hooks
        assert terminate_hook in lifecycle_hooks._terminate_hooks

    async def test_run_spawn_hooks(self, lifecycle_hooks: LifecycleHooks) -> None:
        """Should await every registered spawn hook with the event."""
        hook1 = AsyncMock()
        hook2 = AsyncMock()
        lifecycle_hooks.on_spawn(hook1)
        lifecycle_hooks.on_spawn(hook2)

        event = LifecycleEvent(
            event_type="spawn",
            project_id=uuid4(),
            agent_instance_id=uuid4(),
        )

        await lifecycle_hooks.run_spawn_hooks(event)

        # Hooks are coroutines: being called without being awaited would mean
        # they never actually ran, so check for the await explicitly.
        hook1.assert_awaited_once_with(event)
        hook2.assert_awaited_once_with(event)

    async def test_hook_failure_doesnt_stop_others(
        self, lifecycle_hooks: LifecycleHooks
    ) -> None:
        """Hook failure should not stop other hooks from running."""
        hook1 = AsyncMock(side_effect=ValueError("Oops"))
        hook2 = AsyncMock()
        lifecycle_hooks.on_pause(hook1)
        lifecycle_hooks.on_pause(hook2)

        event = LifecycleEvent(
            event_type="pause",
            project_id=uuid4(),
            agent_instance_id=uuid4(),
        )

        # Must not propagate the ValueError raised by hook1.
        await lifecycle_hooks.run_pause_hooks(event)

        # hook1 must really have run (and raised); otherwise this test would
        # pass vacuously. hook2 must still be awaited despite that failure.
        hook1.assert_awaited_once_with(event)
        hook2.assert_awaited_once_with(event)
|
|
|
|
|
|
class TestAgentLifecycleManagerSpawn:
    """Tests for ``AgentLifecycleManager.spawn``."""

    @staticmethod
    def _stub_working_memory(mock_working_cls: MagicMock) -> AsyncMock:
        """Make the patched ``WorkingMemory.for_session`` yield a fresh mock.

        Returns the async mock that stands in for the session's working memory.
        """
        working = AsyncMock()
        working.set = AsyncMock()
        mock_working_cls.for_session = AsyncMock(return_value=working)
        return working

    @patch("app.services.memory.integration.lifecycle.WorkingMemory")
    async def test_spawn_creates_working_memory(
        self,
        mock_working_cls: MagicMock,
        lifecycle_manager: AgentLifecycleManager,
    ) -> None:
        """Spawning obtains working memory for the new session."""
        self._stub_working_memory(mock_working_cls)

        res = await lifecycle_manager.spawn(
            project_id=uuid4(),
            agent_instance_id=uuid4(),
            session_id="sess-123",
        )

        assert res.success is True
        assert res.event_type == "spawn"
        mock_working_cls.for_session.assert_called_once()

    @patch("app.services.memory.integration.lifecycle.WorkingMemory")
    async def test_spawn_with_initial_state(
        self,
        mock_working_cls: MagicMock,
        lifecycle_manager: AgentLifecycleManager,
    ) -> None:
        """Each entry of ``initial_state`` is written into working memory."""
        working = self._stub_working_memory(mock_working_cls)
        seed = {"key1": "value1", "key2": "value2"}

        res = await lifecycle_manager.spawn(
            project_id=uuid4(),
            agent_instance_id=uuid4(),
            session_id="sess-123",
            initial_state=seed,
        )

        assert res.success is True
        assert res.data["initial_items"] == 2
        assert working.set.call_count == 2

    @patch("app.services.memory.integration.lifecycle.WorkingMemory")
    async def test_spawn_runs_hooks(
        self,
        mock_working_cls: MagicMock,
        lifecycle_manager: AgentLifecycleManager,
    ) -> None:
        """Hooks registered on the manager fire during spawn."""
        self._stub_working_memory(mock_working_cls)
        spawn_hook = AsyncMock()
        lifecycle_manager.hooks.on_spawn(spawn_hook)

        await lifecycle_manager.spawn(
            project_id=uuid4(),
            agent_instance_id=uuid4(),
            session_id="sess-123",
        )

        spawn_hook.assert_called_once()
|
|
|
|
|
|
class TestAgentLifecycleManagerPause:
    """Tests for ``AgentLifecycleManager.pause``."""

    @patch("app.services.memory.integration.lifecycle.WorkingMemory")
    async def test_pause_creates_checkpoint(
        self,
        mock_working_cls: MagicMock,
        lifecycle_manager: AgentLifecycleManager,
    ) -> None:
        """Pausing saves working-memory state under the requested checkpoint id."""
        working = AsyncMock()
        working.list_keys = AsyncMock(return_value=["key1", "key2"])
        working.get = AsyncMock(return_value={"data": "test"})
        working.set = AsyncMock()
        mock_working_cls.for_session = AsyncMock(return_value=working)

        res = await lifecycle_manager.pause(
            project_id=uuid4(),
            agent_instance_id=uuid4(),
            session_id="sess-123",
            checkpoint_id="ckpt-001",
        )

        assert res.success is True
        assert res.event_type == "pause"
        assert res.data["checkpoint_id"] == "ckpt-001"
        assert res.data["items_saved"] == 2

        # A single write is expected; its first positional argument is the
        # storage key, which must embed the checkpoint id.
        working.set.assert_called_once()
        saved_key = working.set.call_args.args[0]
        assert "__checkpoint__ckpt-001" in saved_key

    @patch("app.services.memory.integration.lifecycle.WorkingMemory")
    async def test_pause_generates_checkpoint_id(
        self,
        mock_working_cls: MagicMock,
        lifecycle_manager: AgentLifecycleManager,
    ) -> None:
        """Without an explicit id, pause invents one with a known prefix."""
        working = AsyncMock()
        working.set = AsyncMock()
        working.list_keys = AsyncMock(return_value=[])
        mock_working_cls.for_session = AsyncMock(return_value=working)

        res = await lifecycle_manager.pause(
            project_id=uuid4(),
            agent_instance_id=uuid4(),
            session_id="sess-123",
        )

        assert res.success is True
        assert "checkpoint_id" in res.data
        assert res.data["checkpoint_id"].startswith("checkpoint_")
|
|
|
|
|
|
class TestAgentLifecycleManagerResume:
    """Tests for ``AgentLifecycleManager.resume``."""

    @patch("app.services.memory.integration.lifecycle.WorkingMemory")
    async def test_resume_restores_checkpoint(
        self,
        mock_working_cls: MagicMock,
        lifecycle_manager: AgentLifecycleManager,
    ) -> None:
        """Resuming from a stored checkpoint reports how many items came back."""
        snapshot = {
            "state": {"key1": "value1", "key2": "value2"},
            "timestamp": datetime.now(UTC).isoformat(),
            "keys_count": 2,
        }

        working = AsyncMock()
        working.list_keys = AsyncMock(return_value=[])
        working.get = AsyncMock(return_value=snapshot)
        working.set = AsyncMock()
        working.delete = AsyncMock()
        mock_working_cls.for_session = AsyncMock(return_value=working)

        res = await lifecycle_manager.resume(
            project_id=uuid4(),
            agent_instance_id=uuid4(),
            session_id="sess-123",
            checkpoint_id="ckpt-001",
        )

        assert res.success is True
        assert res.event_type == "resume"
        assert res.data["items_restored"] == 2

    @patch("app.services.memory.integration.lifecycle.WorkingMemory")
    async def test_resume_checkpoint_not_found(
        self,
        mock_working_cls: MagicMock,
        lifecycle_manager: AgentLifecycleManager,
    ) -> None:
        """A missing checkpoint yields an unsuccessful result, not an exception."""
        # ``get`` returning None simulates an absent checkpoint entry.
        empty = AsyncMock()
        empty.get = AsyncMock(return_value=None)
        mock_working_cls.for_session = AsyncMock(return_value=empty)

        res = await lifecycle_manager.resume(
            project_id=uuid4(),
            agent_instance_id=uuid4(),
            session_id="sess-123",
            checkpoint_id="nonexistent",
        )

        assert res.success is False
        assert "not found" in res.message.lower()
|
|
|
|
|
|
class TestAgentLifecycleManagerTerminate:
    """Tests for ``AgentLifecycleManager.terminate``."""

    @staticmethod
    def _stub_two_key_working_memory(mock_working_cls: MagicMock) -> AsyncMock:
        """Install a working-memory mock holding two keys that both read "value".

        Returns the mock so tests can inspect calls made on it.
        """
        working = AsyncMock()
        working.list_keys = AsyncMock(return_value=["key1", "key2"])
        working.get = AsyncMock(return_value="value")
        working.delete = AsyncMock()
        mock_working_cls.for_session = AsyncMock(return_value=working)
        return working

    @patch("app.services.memory.integration.lifecycle.EpisodicMemory")
    @patch("app.services.memory.integration.lifecycle.WorkingMemory")
    async def test_terminate_consolidates_to_episodic(
        self,
        mock_working_cls: MagicMock,
        mock_episodic_cls: MagicMock,
        lifecycle_manager: AgentLifecycleManager,
    ) -> None:
        """Terminating records the session's working state as an episode."""
        self._stub_two_key_working_memory(mock_working_cls)

        episode = MagicMock()
        episode.id = uuid4()
        episodic = AsyncMock()
        episodic.record_episode = AsyncMock(return_value=episode)
        mock_episodic_cls.create = AsyncMock(return_value=episodic)

        res = await lifecycle_manager.terminate(
            project_id=uuid4(),
            agent_instance_id=uuid4(),
            session_id="sess-123",
            task_description="Completed task",
            outcome=Outcome.SUCCESS,
        )

        assert res.success is True
        assert res.event_type == "terminate"
        assert res.data["episode_id"] == str(episode.id)
        assert res.data["state_items_consolidated"] == 2
        episodic.record_episode.assert_called_once()

    @patch("app.services.memory.integration.lifecycle.WorkingMemory")
    async def test_terminate_cleans_up_working(
        self,
        mock_working_cls: MagicMock,
        lifecycle_manager: AgentLifecycleManager,
    ) -> None:
        """With cleanup enabled, every working-memory key gets deleted."""
        working = self._stub_two_key_working_memory(mock_working_cls)

        res = await lifecycle_manager.terminate(
            project_id=uuid4(),
            agent_instance_id=uuid4(),
            session_id="sess-123",
            consolidate_to_episodic=False,
            cleanup_working=True,
        )

        assert res.success is True
        assert res.data["items_cleared"] == 2
        assert working.delete.call_count == 2
|
|
|
|
|
|
class TestAgentLifecycleManagerListCheckpoints:
    """Tests for ``AgentLifecycleManager.list_checkpoints``."""

    @patch("app.services.memory.integration.lifecycle.WorkingMemory")
    async def test_list_checkpoints(
        self,
        mock_working_cls: MagicMock,
        lifecycle_manager: AgentLifecycleManager,
    ) -> None:
        """Only ``__checkpoint__``-prefixed keys are reported as checkpoints."""
        stored_keys = [
            "__checkpoint__ckpt-001",
            "__checkpoint__ckpt-002",
            "regular_key",
        ]
        checkpoint_meta = {
            "timestamp": "2024-01-01T00:00:00Z",
            "keys_count": 5,
        }

        working = AsyncMock()
        working.list_keys = AsyncMock(return_value=stored_keys)
        working.get = AsyncMock(return_value=checkpoint_meta)
        mock_working_cls.for_session = AsyncMock(return_value=working)

        found = await lifecycle_manager.list_checkpoints(
            project_id=uuid4(),
            agent_instance_id=uuid4(),
            session_id="sess-123",
        )

        assert len(found) == 2
        first = found[0]
        assert first["checkpoint_id"] == "ckpt-001"
        assert first["keys_count"] == 5
|
|
|
|
|
|
class TestGetLifecycleManager:
    """Tests for the ``get_lifecycle_manager`` factory."""

    async def test_creates_instance(self) -> None:
        """The factory returns an ``AgentLifecycleManager``."""
        manager = await get_lifecycle_manager(MagicMock())

        assert isinstance(manager, AgentLifecycleManager)

    async def test_with_custom_hooks(self) -> None:
        """Hooks handed to the factory are attached to the manager as-is."""
        hooks = LifecycleHooks()

        manager = await get_lifecycle_manager(MagicMock(), hooks=hooks)

        assert manager.hooks is hooks
|