# tests/unit/services/memory/integration/test_lifecycle.py
"""Tests for Agent Lifecycle Hooks."""

from datetime import UTC, datetime
from unittest.mock import AsyncMock, MagicMock, patch
from uuid import uuid4

import pytest

from app.services.memory.integration.lifecycle import (
    AgentLifecycleManager,
    LifecycleEvent,
    LifecycleHooks,
    LifecycleResult,
    get_lifecycle_manager,
)
from app.services.memory.types import Outcome

# The asyncio mark is applied per class / per test instead of through a
# module-level ``pytestmark``: several tests below (dataclass construction,
# hook registration) are synchronous, and pytest-asyncio warns when the
# asyncio mark lands on a non-async test.
async_test = pytest.mark.asyncio(loop_scope="function")


@pytest.fixture
def mock_session() -> MagicMock:
    """Create mock database session."""
    return MagicMock()


@pytest.fixture
def lifecycle_hooks() -> LifecycleHooks:
    """Create lifecycle hooks instance."""
    return LifecycleHooks()


@pytest.fixture
def lifecycle_manager(mock_session: MagicMock) -> AgentLifecycleManager:
    """Create lifecycle manager instance."""
    return AgentLifecycleManager(session=mock_session)


class TestLifecycleEvent:
    """Tests for LifecycleEvent dataclass."""

    def test_creates_event(self) -> None:
        """Should create event with required fields."""
        project_id = uuid4()
        agent_id = uuid4()

        event = LifecycleEvent(
            event_type="spawn",
            project_id=project_id,
            agent_instance_id=agent_id,
        )

        assert event.event_type == "spawn"
        assert event.project_id == project_id
        assert event.agent_instance_id == agent_id
        assert event.timestamp is not None
        assert event.metadata == {}

    def test_with_optional_fields(self) -> None:
        """Should include optional fields."""
        event = LifecycleEvent(
            event_type="terminate",
            project_id=uuid4(),
            agent_instance_id=uuid4(),
            session_id="sess-123",
            metadata={"reason": "completed"},
        )

        assert event.session_id == "sess-123"
        assert event.metadata["reason"] == "completed"


class TestLifecycleResult:
    """Tests for LifecycleResult dataclass."""

    def test_success_result(self) -> None:
        """Should create success result."""
        result = LifecycleResult(
            success=True,
            event_type="spawn",
            message="Agent spawned",
            data={"session_id": "sess-123"},
            duration_ms=10.5,
        )

        assert result.success is True
        assert result.event_type == "spawn"
        assert result.data["session_id"] == "sess-123"

    def test_failure_result(self) -> None:
        """Should create failure result."""
        result = LifecycleResult(
            success=False,
            event_type="resume",
            message="Checkpoint not found",
        )

        assert result.success is False
        assert result.message == "Checkpoint not found"


class TestLifecycleHooks:
    """Tests for LifecycleHooks class."""

    def test_register_spawn_hook(self, lifecycle_hooks: LifecycleHooks) -> None:
        """Should register spawn hook."""

        async def my_hook(event: LifecycleEvent) -> None:
            pass

        result = lifecycle_hooks.on_spawn(my_hook)

        assert result is my_hook
        assert my_hook in lifecycle_hooks._spawn_hooks

    def test_register_all_hooks(self, lifecycle_hooks: LifecycleHooks) -> None:
        """Should register hooks for all event types."""
        spawn_hook = AsyncMock()
        pause_hook = AsyncMock()
        resume_hook = AsyncMock()
        terminate_hook = AsyncMock()

        lifecycle_hooks.on_spawn(spawn_hook)
        lifecycle_hooks.on_pause(pause_hook)
        lifecycle_hooks.on_resume(resume_hook)
        lifecycle_hooks.on_terminate(terminate_hook)

        # Each registry holds exactly the one hook registered for its event.
        assert len(lifecycle_hooks._spawn_hooks) == 1
        assert len(lifecycle_hooks._pause_hooks) == 1
        assert len(lifecycle_hooks._resume_hooks) == 1
        assert len(lifecycle_hooks._terminate_hooks) == 1
        assert spawn_hook in lifecycle_hooks._spawn_hooks
        assert pause_hook in lifecycle_hooks._pause_hooks
        assert resume_hook in lifecycle_hooks._resume_hooks
        assert terminate_hook in lifecycle_hooks._terminate_hooks

    @async_test
    async def test_run_spawn_hooks(self, lifecycle_hooks: LifecycleHooks) -> None:
        """Should run all spawn hooks."""
        hook1 = AsyncMock()
        hook2 = AsyncMock()

        lifecycle_hooks.on_spawn(hook1)
        lifecycle_hooks.on_spawn(hook2)

        event = LifecycleEvent(
            event_type="spawn",
            project_id=uuid4(),
            agent_instance_id=uuid4(),
        )

        await lifecycle_hooks.run_spawn_hooks(event)

        hook1.assert_called_once_with(event)
        hook2.assert_called_once_with(event)

    @async_test
    async def test_hook_failure_doesnt_stop_others(
        self, lifecycle_hooks: LifecycleHooks
    ) -> None:
        """Hook failure should not stop other hooks from running."""
        hook1 = AsyncMock(side_effect=ValueError("Oops"))
        hook2 = AsyncMock()

        lifecycle_hooks.on_pause(hook1)
        lifecycle_hooks.on_pause(hook2)

        event = LifecycleEvent(
            event_type="pause",
            project_id=uuid4(),
            agent_instance_id=uuid4(),
        )

        await lifecycle_hooks.run_pause_hooks(event)

        # The failing hook must actually have run; otherwise this test would
        # not exercise failure isolation at all.
        hook1.assert_called_once_with(event)
        # hook2 should still be called even though hook1 failed
        hook2.assert_called_once_with(event)


@async_test
class TestAgentLifecycleManagerSpawn:
    """Tests for AgentLifecycleManager.spawn."""

    @patch("app.services.memory.integration.lifecycle.WorkingMemory")
    async def test_spawn_creates_working_memory(
        self,
        mock_working_cls: MagicMock,
        lifecycle_manager: AgentLifecycleManager,
    ) -> None:
        """Spawn should create working memory for session."""
        mock_working = AsyncMock()
        mock_working.set = AsyncMock()
        mock_working_cls.for_session = AsyncMock(return_value=mock_working)

        result = await lifecycle_manager.spawn(
            project_id=uuid4(),
            agent_instance_id=uuid4(),
            session_id="sess-123",
        )

        assert result.success is True
        assert result.event_type == "spawn"
        mock_working_cls.for_session.assert_called_once()

    @patch("app.services.memory.integration.lifecycle.WorkingMemory")
    async def test_spawn_with_initial_state(
        self,
        mock_working_cls: MagicMock,
        lifecycle_manager: AgentLifecycleManager,
    ) -> None:
        """Spawn should populate initial state."""
        mock_working = AsyncMock()
        mock_working.set = AsyncMock()
        mock_working_cls.for_session = AsyncMock(return_value=mock_working)

        result = await lifecycle_manager.spawn(
            project_id=uuid4(),
            agent_instance_id=uuid4(),
            session_id="sess-123",
            initial_state={"key1": "value1", "key2": "value2"},
        )

        assert result.success is True
        assert result.data["initial_items"] == 2
        assert mock_working.set.call_count == 2

    @patch("app.services.memory.integration.lifecycle.WorkingMemory")
    async def test_spawn_runs_hooks(
        self,
        mock_working_cls: MagicMock,
        lifecycle_manager: AgentLifecycleManager,
    ) -> None:
        """Spawn should run registered hooks."""
        mock_working = AsyncMock()
        mock_working.set = AsyncMock()
        mock_working_cls.for_session = AsyncMock(return_value=mock_working)

        hook = AsyncMock()
        lifecycle_manager.hooks.on_spawn(hook)

        await lifecycle_manager.spawn(
            project_id=uuid4(),
            agent_instance_id=uuid4(),
            session_id="sess-123",
        )

        hook.assert_called_once()


@async_test
class TestAgentLifecycleManagerPause:
    """Tests for AgentLifecycleManager.pause."""

    @patch("app.services.memory.integration.lifecycle.WorkingMemory")
    async def test_pause_creates_checkpoint(
        self,
        mock_working_cls: MagicMock,
        lifecycle_manager: AgentLifecycleManager,
    ) -> None:
        """Pause should create checkpoint of working memory."""
        mock_working = AsyncMock()
        mock_working.list_keys = AsyncMock(return_value=["key1", "key2"])
        mock_working.get = AsyncMock(return_value={"data": "test"})
        mock_working.set = AsyncMock()
        mock_working_cls.for_session = AsyncMock(return_value=mock_working)

        result = await lifecycle_manager.pause(
            project_id=uuid4(),
            agent_instance_id=uuid4(),
            session_id="sess-123",
            checkpoint_id="ckpt-001",
        )

        assert result.success is True
        assert result.event_type == "pause"
        assert result.data["checkpoint_id"] == "ckpt-001"
        assert result.data["items_saved"] == 2

        # Should save checkpoint with state
        mock_working.set.assert_called_once()
        call_args = mock_working.set.call_args
        # Check positional arg (first arg is key)
        assert "__checkpoint__ckpt-001" in call_args[0][0]

    @patch("app.services.memory.integration.lifecycle.WorkingMemory")
    async def test_pause_generates_checkpoint_id(
        self,
        mock_working_cls: MagicMock,
        lifecycle_manager: AgentLifecycleManager,
    ) -> None:
        """Pause should generate checkpoint ID if not provided."""
        mock_working = AsyncMock()
        mock_working.list_keys = AsyncMock(return_value=[])
        mock_working.set = AsyncMock()
        mock_working_cls.for_session = AsyncMock(return_value=mock_working)

        result = await lifecycle_manager.pause(
            project_id=uuid4(),
            agent_instance_id=uuid4(),
            session_id="sess-123",
        )

        assert result.success is True
        assert "checkpoint_id" in result.data
        assert result.data["checkpoint_id"].startswith("checkpoint_")


@async_test
class TestAgentLifecycleManagerResume:
    """Tests for AgentLifecycleManager.resume."""

    @patch("app.services.memory.integration.lifecycle.WorkingMemory")
    async def test_resume_restores_checkpoint(
        self,
        mock_working_cls: MagicMock,
        lifecycle_manager: AgentLifecycleManager,
    ) -> None:
        """Resume should restore working memory from checkpoint."""
        checkpoint_data = {
            "state": {"key1": "value1", "key2": "value2"},
            "timestamp": datetime.now(UTC).isoformat(),
            "keys_count": 2,
        }
        mock_working = AsyncMock()
        mock_working.list_keys = AsyncMock(return_value=[])
        mock_working.get = AsyncMock(return_value=checkpoint_data)
        mock_working.set = AsyncMock()
        mock_working.delete = AsyncMock()
        mock_working_cls.for_session = AsyncMock(return_value=mock_working)

        result = await lifecycle_manager.resume(
            project_id=uuid4(),
            agent_instance_id=uuid4(),
            session_id="sess-123",
            checkpoint_id="ckpt-001",
        )

        assert result.success is True
        assert result.event_type == "resume"
        assert result.data["items_restored"] == 2

    @patch("app.services.memory.integration.lifecycle.WorkingMemory")
    async def test_resume_checkpoint_not_found(
        self,
        mock_working_cls: MagicMock,
        lifecycle_manager: AgentLifecycleManager,
    ) -> None:
        """Resume should fail if checkpoint not found."""
        mock_working = AsyncMock()
        mock_working.get = AsyncMock(return_value=None)
        mock_working_cls.for_session = AsyncMock(return_value=mock_working)

        result = await lifecycle_manager.resume(
            project_id=uuid4(),
            agent_instance_id=uuid4(),
            session_id="sess-123",
            checkpoint_id="nonexistent",
        )

        assert result.success is False
        assert "not found" in result.message.lower()


@async_test
class TestAgentLifecycleManagerTerminate:
    """Tests for AgentLifecycleManager.terminate."""

    # Stacked patches are applied bottom-up, so the WorkingMemory mock is the
    # first injected argument and the EpisodicMemory mock the second.
    @patch("app.services.memory.integration.lifecycle.EpisodicMemory")
    @patch("app.services.memory.integration.lifecycle.WorkingMemory")
    async def test_terminate_consolidates_to_episodic(
        self,
        mock_working_cls: MagicMock,
        mock_episodic_cls: MagicMock,
        lifecycle_manager: AgentLifecycleManager,
    ) -> None:
        """Terminate should consolidate working memory to episodic."""
        mock_working = AsyncMock()
        mock_working.list_keys = AsyncMock(return_value=["key1", "key2"])
        mock_working.get = AsyncMock(return_value="value")
        mock_working.delete = AsyncMock()
        mock_working_cls.for_session = AsyncMock(return_value=mock_working)

        mock_episode = MagicMock()
        mock_episode.id = uuid4()

        mock_episodic = AsyncMock()
        mock_episodic.record_episode = AsyncMock(return_value=mock_episode)
        mock_episodic_cls.create = AsyncMock(return_value=mock_episodic)

        result = await lifecycle_manager.terminate(
            project_id=uuid4(),
            agent_instance_id=uuid4(),
            session_id="sess-123",
            task_description="Completed task",
            outcome=Outcome.SUCCESS,
        )

        assert result.success is True
        assert result.event_type == "terminate"
        assert result.data["episode_id"] == str(mock_episode.id)
        assert result.data["state_items_consolidated"] == 2
        mock_episodic.record_episode.assert_called_once()

    @patch("app.services.memory.integration.lifecycle.WorkingMemory")
    async def test_terminate_cleans_up_working(
        self,
        mock_working_cls: MagicMock,
        lifecycle_manager: AgentLifecycleManager,
    ) -> None:
        """Terminate should clean up working memory."""
        mock_working = AsyncMock()
        mock_working.list_keys = AsyncMock(return_value=["key1", "key2"])
        mock_working.get = AsyncMock(return_value="value")
        mock_working.delete = AsyncMock()
        mock_working_cls.for_session = AsyncMock(return_value=mock_working)

        result = await lifecycle_manager.terminate(
            project_id=uuid4(),
            agent_instance_id=uuid4(),
            session_id="sess-123",
            consolidate_to_episodic=False,
            cleanup_working=True,
        )

        assert result.success is True
        assert result.data["items_cleared"] == 2
        assert mock_working.delete.call_count == 2


@async_test
class TestAgentLifecycleManagerListCheckpoints:
    """Tests for AgentLifecycleManager.list_checkpoints."""

    @patch("app.services.memory.integration.lifecycle.WorkingMemory")
    async def test_list_checkpoints(
        self,
        mock_working_cls: MagicMock,
        lifecycle_manager: AgentLifecycleManager,
    ) -> None:
        """Should list available checkpoints."""
        mock_working = AsyncMock()
        mock_working.list_keys = AsyncMock(
            return_value=[
                "__checkpoint__ckpt-001",
                "__checkpoint__ckpt-002",
                "regular_key",
            ]
        )
        mock_working.get = AsyncMock(
            return_value={
                "timestamp": "2024-01-01T00:00:00Z",
                "keys_count": 5,
            }
        )
        mock_working_cls.for_session = AsyncMock(return_value=mock_working)

        checkpoints = await lifecycle_manager.list_checkpoints(
            project_id=uuid4(),
            agent_instance_id=uuid4(),
            session_id="sess-123",
        )

        assert len(checkpoints) == 2
        # Only the "__checkpoint__"-prefixed keys are reported; "regular_key"
        # must be filtered out.
        assert {cp["checkpoint_id"] for cp in checkpoints} == {
            "ckpt-001",
            "ckpt-002",
        }
        assert checkpoints[0]["checkpoint_id"] == "ckpt-001"
        assert checkpoints[0]["keys_count"] == 5


@async_test
class TestGetLifecycleManager:
    """Tests for factory function."""

    async def test_creates_instance(self) -> None:
        """Factory should create AgentLifecycleManager instance."""
        mock_session = MagicMock()
        manager = await get_lifecycle_manager(mock_session)
        assert isinstance(manager, AgentLifecycleManager)

    async def test_with_custom_hooks(self) -> None:
        """Factory should accept custom hooks."""
        mock_session = MagicMock()
        custom_hooks = LifecycleHooks()

        manager = await get_lifecycle_manager(mock_session, hooks=custom_hooks)

        assert manager.hooks is custom_hooks