feat(memory): integrate memory system with context engine (#97)
## Changes

### New Context Type
- Add MEMORY to ContextType enum for agent memory context
- Create MemoryContext class with subtypes (working, episodic, semantic, procedural)
- Factory methods: from_working_memory, from_episodic_memory, from_semantic_memory, from_procedural_memory

### Memory Context Source
- MemoryContextSource service fetches relevant memories for context assembly
- Configurable fetch limits per memory type
- Parallel fetching from all memory types

### Agent Lifecycle Hooks
- AgentLifecycleManager handles spawn, pause, resume, terminate events
- spawn: Initialize working memory with optional initial state
- pause: Create checkpoint of working memory
- resume: Restore from checkpoint
- terminate: Consolidate working memory to episodic memory
- LifecycleHooks for custom extension points

### Context Engine Integration
- Add memory_query parameter to assemble_context()
- Add session_id and agent_type_id for memory scoping
- Memory budget allocation (15% by default)
- set_memory_source() for runtime configuration

### Tests
- 48 new tests for MemoryContext, MemoryContextSource, and lifecycle hooks
- All 108 memory-related tests passing
- mypy and ruff checks passing

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
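For reviewers, a minimal usage sketch of the new surface, based on the signatures exercised by the tests in this PR. The context-engine wiring (`set_memory_source()` / `assemble_context(memory_query=...)`) is only described in the summary above, so its exact call shape here is an assumption.

```python
# Sketch only: the MemoryContextSource, MemoryFetchConfig, and lifecycle calls follow
# the tests below; `engine` and its call at the end are assumptions from the summary.
from uuid import uuid4

from app.services.memory.integration.context_source import (
    MemoryContextSource,
    MemoryFetchConfig,
)
from app.services.memory.integration.lifecycle import get_lifecycle_manager
from app.services.memory.types import Outcome


async def demo(session, engine) -> None:
    project_id, agent_id = uuid4(), uuid4()

    # spawn: initialize working memory for the session with optional initial state
    manager = await get_lifecycle_manager(session)
    await manager.spawn(
        project_id=project_id,
        agent_instance_id=agent_id,
        session_id="sess-123",
        initial_state={"current_task": "Build authentication"},
    )

    # Fetch relevant memories for context assembly; per-type limits are configurable
    source = MemoryContextSource(session=session)
    result = await source.fetch_context(
        query="authentication",
        project_id=project_id,
        session_id="sess-123",
        config=MemoryFetchConfig(semantic_limit=5),
    )
    print(result.by_type, f"{result.fetch_time_ms:.1f} ms")

    # Assumed context-engine hookup (name taken from the summary above)
    engine.set_memory_source(source)

    # terminate: consolidate working memory into an episodic record
    await manager.terminate(
        project_id=project_id,
        agent_instance_id=agent_id,
        session_id="sess-123",
        task_description="Completed task",
        outcome=Outcome.SUCCESS,
    )
```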
backend/tests/unit/services/context/types/test_memory.py (new file, 262 lines)
@@ -0,0 +1,262 @@
# tests/unit/services/context/types/test_memory.py
"""Tests for MemoryContext type."""

from datetime import UTC, datetime
from unittest.mock import MagicMock
from uuid import uuid4

import pytest

from app.services.context.types import ContextType
from app.services.context.types.memory import MemoryContext, MemorySubtype


class TestMemorySubtype:
    """Tests for MemorySubtype enum."""

    def test_all_types_defined(self) -> None:
        """All memory subtypes should be defined."""
        assert MemorySubtype.WORKING == "working"
        assert MemorySubtype.EPISODIC == "episodic"
        assert MemorySubtype.SEMANTIC == "semantic"
        assert MemorySubtype.PROCEDURAL == "procedural"

    def test_enum_values(self) -> None:
        """Enum values should match strings."""
        assert MemorySubtype.WORKING.value == "working"
        assert MemorySubtype("episodic") == MemorySubtype.EPISODIC


class TestMemoryContext:
    """Tests for MemoryContext class."""

    def test_get_type_returns_memory(self) -> None:
        """get_type should return MEMORY."""
        ctx = MemoryContext(content="test", source="test_source")
        assert ctx.get_type() == ContextType.MEMORY

    def test_default_values(self) -> None:
        """Default values should be set correctly."""
        ctx = MemoryContext(content="test", source="test_source")
        assert ctx.memory_subtype == MemorySubtype.EPISODIC
        assert ctx.memory_id is None
        assert ctx.relevance_score == 0.0
        assert ctx.importance == 0.5

    def test_to_dict_includes_memory_fields(self) -> None:
        """to_dict should include memory-specific fields."""
        ctx = MemoryContext(
            content="test content",
            source="test_source",
            memory_subtype=MemorySubtype.SEMANTIC,
            memory_id="mem-123",
            relevance_score=0.8,
            subject="User",
            predicate="prefers",
            object_value="dark mode",
        )

        data = ctx.to_dict()

        assert data["memory_subtype"] == "semantic"
        assert data["memory_id"] == "mem-123"
        assert data["relevance_score"] == 0.8
        assert data["subject"] == "User"
        assert data["predicate"] == "prefers"
        assert data["object_value"] == "dark mode"

    def test_from_dict(self) -> None:
        """from_dict should create correct MemoryContext."""
        data = {
            "content": "test content",
            "source": "test_source",
            "timestamp": "2024-01-01T00:00:00+00:00",
            "memory_subtype": "semantic",
            "memory_id": "mem-123",
            "relevance_score": 0.8,
            "subject": "Test",
        }

        ctx = MemoryContext.from_dict(data)

        assert ctx.content == "test content"
        assert ctx.memory_subtype == MemorySubtype.SEMANTIC
        assert ctx.memory_id == "mem-123"
        assert ctx.subject == "Test"


class TestMemoryContextFromWorkingMemory:
    """Tests for MemoryContext.from_working_memory."""

    def test_creates_working_memory_context(self) -> None:
        """Should create working memory context from key/value."""
        ctx = MemoryContext.from_working_memory(
            key="user_preferences",
            value={"theme": "dark"},
            source="working:sess-123",
            query="preferences",
        )

        assert ctx.memory_subtype == MemorySubtype.WORKING
        assert ctx.key == "user_preferences"
        assert "{'theme': 'dark'}" in ctx.content
        assert ctx.relevance_score == 1.0  # Working memory is always relevant
        assert ctx.importance == 0.8  # Higher importance

    def test_string_value(self) -> None:
        """Should handle string values."""
        ctx = MemoryContext.from_working_memory(
            key="current_task",
            value="Build authentication",
        )

        assert ctx.content == "Build authentication"


class TestMemoryContextFromEpisodicMemory:
    """Tests for MemoryContext.from_episodic_memory."""

    def test_creates_episodic_memory_context(self) -> None:
        """Should create episodic memory context from episode."""
        episode = MagicMock()
        episode.id = uuid4()
        episode.task_description = "Implemented login feature"
        episode.task_type = "feature_implementation"
        episode.outcome = MagicMock(value="success")
        episode.importance_score = 0.9
        episode.session_id = "sess-123"
        episode.occurred_at = datetime.now(UTC)
        episode.lessons_learned = ["Use proper validation"]

        ctx = MemoryContext.from_episodic_memory(episode, query="login")

        assert ctx.memory_subtype == MemorySubtype.EPISODIC
        assert ctx.memory_id == str(episode.id)
        assert ctx.content == "Implemented login feature"
        assert ctx.task_type == "feature_implementation"
        assert ctx.outcome == "success"
        assert ctx.importance == 0.9

    def test_handles_missing_outcome(self) -> None:
        """Should handle episodes with no outcome."""
        episode = MagicMock()
        episode.id = uuid4()
        episode.task_description = "WIP task"
        episode.outcome = None
        episode.importance_score = 0.5
        episode.occurred_at = None

        ctx = MemoryContext.from_episodic_memory(episode)

        assert ctx.outcome is None


class TestMemoryContextFromSemanticMemory:
    """Tests for MemoryContext.from_semantic_memory."""

    def test_creates_semantic_memory_context(self) -> None:
        """Should create semantic memory context from fact."""
        fact = MagicMock()
        fact.id = uuid4()
        fact.subject = "User"
        fact.predicate = "prefers"
        fact.object = "dark mode"
        fact.confidence = 0.95

        ctx = MemoryContext.from_semantic_memory(fact, query="user preferences")

        assert ctx.memory_subtype == MemorySubtype.SEMANTIC
        assert ctx.memory_id == str(fact.id)
        assert ctx.content == "User prefers dark mode"
        assert ctx.subject == "User"
        assert ctx.predicate == "prefers"
        assert ctx.object_value == "dark mode"
        assert ctx.relevance_score == 0.95


class TestMemoryContextFromProceduralMemory:
    """Tests for MemoryContext.from_procedural_memory."""

    def test_creates_procedural_memory_context(self) -> None:
        """Should create procedural memory context from procedure."""
        procedure = MagicMock()
        procedure.id = uuid4()
        procedure.name = "Deploy to Production"
        procedure.trigger_pattern = "When deploying to production"
        procedure.steps = [
            {"action": "run_tests"},
            {"action": "build_docker"},
            {"action": "deploy"},
        ]
        procedure.success_rate = 0.85
        procedure.success_count = 10
        procedure.failure_count = 2

        ctx = MemoryContext.from_procedural_memory(procedure, query="deploy")

        assert ctx.memory_subtype == MemorySubtype.PROCEDURAL
        assert ctx.memory_id == str(procedure.id)
        assert "Deploy to Production" in ctx.content
        assert "When deploying to production" in ctx.content
        assert ctx.trigger == "When deploying to production"
        assert ctx.success_rate == 0.85
        assert ctx.metadata["steps_count"] == 3
        assert ctx.metadata["execution_count"] == 12


class TestMemoryContextHelpers:
    """Tests for MemoryContext helper methods."""

    def test_is_working_memory(self) -> None:
        """is_working_memory should return True for working memory."""
        ctx = MemoryContext(
            content="test",
            source="test",
            memory_subtype=MemorySubtype.WORKING,
        )
        assert ctx.is_working_memory() is True
        assert ctx.is_episodic_memory() is False

    def test_is_episodic_memory(self) -> None:
        """is_episodic_memory should return True for episodic memory."""
        ctx = MemoryContext(
            content="test",
            source="test",
            memory_subtype=MemorySubtype.EPISODIC,
        )
        assert ctx.is_episodic_memory() is True
        assert ctx.is_semantic_memory() is False

    def test_is_semantic_memory(self) -> None:
        """is_semantic_memory should return True for semantic memory."""
        ctx = MemoryContext(
            content="test",
            source="test",
            memory_subtype=MemorySubtype.SEMANTIC,
        )
        assert ctx.is_semantic_memory() is True
        assert ctx.is_procedural_memory() is False

    def test_is_procedural_memory(self) -> None:
        """is_procedural_memory should return True for procedural memory."""
        ctx = MemoryContext(
            content="test",
            source="test",
            memory_subtype=MemorySubtype.PROCEDURAL,
        )
        assert ctx.is_procedural_memory() is True
        assert ctx.is_working_memory() is False

    def test_get_formatted_source(self) -> None:
        """get_formatted_source should return formatted string."""
        ctx = MemoryContext(
            content="test",
            source="episodic:12345678-1234-1234-1234-123456789012",
            memory_subtype=MemorySubtype.EPISODIC,
            memory_id="12345678-1234-1234-1234-123456789012",
        )

        formatted = ctx.get_formatted_source()

        assert "[episodic]" in formatted
        assert "12345678..." in formatted
backend/tests/unit/services/memory/integration/__init__.py (new file, 2 lines)
@@ -0,0 +1,2 @@
# tests/unit/services/memory/integration/__init__.py
"""Tests for memory integration module."""
backend/tests/unit/services/memory/integration/test_context_source.py (new file, 322 lines)
@@ -0,0 +1,322 @@
# tests/unit/services/memory/integration/test_context_source.py
"""Tests for MemoryContextSource service."""

from datetime import UTC, datetime
from unittest.mock import AsyncMock, MagicMock, patch
from uuid import uuid4

import pytest

from app.services.context.types.memory import MemorySubtype
from app.services.memory.integration.context_source import (
    MemoryContextSource,
    MemoryFetchConfig,
    MemoryFetchResult,
    get_memory_context_source,
)

pytestmark = pytest.mark.asyncio(loop_scope="function")


@pytest.fixture
def mock_session() -> MagicMock:
    """Create mock database session."""
    return MagicMock()


@pytest.fixture
def context_source(mock_session: MagicMock) -> MemoryContextSource:
    """Create MemoryContextSource instance."""
    return MemoryContextSource(session=mock_session)


class TestMemoryFetchConfig:
    """Tests for MemoryFetchConfig."""

    def test_default_values(self) -> None:
        """Default config values should be set correctly."""
        config = MemoryFetchConfig()

        assert config.working_limit == 10
        assert config.episodic_limit == 10
        assert config.semantic_limit == 15
        assert config.procedural_limit == 5
        assert config.episodic_days_back == 30
        assert config.min_relevance == 0.3
        assert config.include_working is True
        assert config.include_episodic is True
        assert config.include_semantic is True
        assert config.include_procedural is True

    def test_custom_values(self) -> None:
        """Custom config values should be respected."""
        config = MemoryFetchConfig(
            working_limit=5,
            include_working=False,
        )

        assert config.working_limit == 5
        assert config.include_working is False


class TestMemoryFetchResult:
    """Tests for MemoryFetchResult."""

    def test_stores_results(self) -> None:
        """Result should store contexts and metadata."""
        result = MemoryFetchResult(
            contexts=[],
            by_type={"working": 0, "episodic": 5, "semantic": 3, "procedural": 0},
            fetch_time_ms=15.5,
            query="test query",
        )

        assert result.contexts == []
        assert result.by_type["episodic"] == 5
        assert result.fetch_time_ms == 15.5
        assert result.query == "test query"


class TestMemoryContextSource:
    """Tests for MemoryContextSource service."""

    async def test_fetch_context_empty_when_no_sources(
        self,
        context_source: MemoryContextSource,
    ) -> None:
        """fetch_context should return empty when all sources fail."""
        config = MemoryFetchConfig(
            include_working=False,
            include_episodic=False,
            include_semantic=False,
            include_procedural=False,
        )

        result = await context_source.fetch_context(
            query="test",
            project_id=uuid4(),
            config=config,
        )

        assert len(result.contexts) == 0
        assert result.by_type == {
            "working": 0,
            "episodic": 0,
            "semantic": 0,
            "procedural": 0,
        }

    @patch("app.services.memory.integration.context_source.WorkingMemory")
    async def test_fetch_working_memory(
        self,
        mock_working_cls: MagicMock,
        context_source: MemoryContextSource,
    ) -> None:
        """Should fetch working memory when session_id provided."""
        # Setup mock - both keys should match the query "task"
        mock_working = AsyncMock()
        mock_working.list_keys = AsyncMock(return_value=["current_task", "task_state"])
        mock_working.get = AsyncMock(side_effect=lambda k: {"key": k, "value": "test"})
        mock_working_cls.for_session = AsyncMock(return_value=mock_working)

        config = MemoryFetchConfig(
            include_episodic=False,
            include_semantic=False,
            include_procedural=False,
        )

        result = await context_source.fetch_context(
            query="task",  # Both keys contain "task"
            project_id=uuid4(),
            session_id="sess-123",
            config=config,
        )

        assert result.by_type["working"] == 2
        assert all(
            c.memory_subtype == MemorySubtype.WORKING for c in result.contexts
        )

    @patch("app.services.memory.integration.context_source.EpisodicMemory")
    async def test_fetch_episodic_memory(
        self,
        mock_episodic_cls: MagicMock,
        context_source: MemoryContextSource,
    ) -> None:
        """Should fetch episodic memory."""
        # Setup mock episode
        mock_episode = MagicMock()
        mock_episode.id = uuid4()
        mock_episode.task_description = "Completed login feature"
        mock_episode.task_type = "feature"
        mock_episode.outcome = MagicMock(value="success")
        mock_episode.importance_score = 0.8
        mock_episode.occurred_at = datetime.now(UTC)
        mock_episode.lessons_learned = []

        mock_episodic = AsyncMock()
        mock_episodic.search_similar = AsyncMock(return_value=[mock_episode])
        mock_episodic.get_recent = AsyncMock(return_value=[])
        mock_episodic_cls.create = AsyncMock(return_value=mock_episodic)

        config = MemoryFetchConfig(
            include_working=False,
            include_semantic=False,
            include_procedural=False,
        )

        result = await context_source.fetch_context(
            query="login",
            project_id=uuid4(),
            config=config,
        )

        assert result.by_type["episodic"] == 1
        assert result.contexts[0].memory_subtype == MemorySubtype.EPISODIC
        assert "Completed login feature" in result.contexts[0].content

    @patch("app.services.memory.integration.context_source.SemanticMemory")
    async def test_fetch_semantic_memory(
        self,
        mock_semantic_cls: MagicMock,
        context_source: MemoryContextSource,
    ) -> None:
        """Should fetch semantic memory."""
        # Setup mock fact
        mock_fact = MagicMock()
        mock_fact.id = uuid4()
        mock_fact.subject = "User"
        mock_fact.predicate = "prefers"
        mock_fact.object = "dark mode"
        mock_fact.confidence = 0.9

        mock_semantic = AsyncMock()
        mock_semantic.search_facts = AsyncMock(return_value=[mock_fact])
        mock_semantic_cls.create = AsyncMock(return_value=mock_semantic)

        config = MemoryFetchConfig(
            include_working=False,
            include_episodic=False,
            include_procedural=False,
        )

        result = await context_source.fetch_context(
            query="preferences",
            project_id=uuid4(),
            config=config,
        )

        assert result.by_type["semantic"] == 1
        assert result.contexts[0].memory_subtype == MemorySubtype.SEMANTIC
        assert "User prefers dark mode" in result.contexts[0].content

    @patch("app.services.memory.integration.context_source.ProceduralMemory")
    async def test_fetch_procedural_memory(
        self,
        mock_procedural_cls: MagicMock,
        context_source: MemoryContextSource,
    ) -> None:
        """Should fetch procedural memory."""
        # Setup mock procedure
        mock_proc = MagicMock()
        mock_proc.id = uuid4()
        mock_proc.name = "Deploy"
        mock_proc.trigger_pattern = "When deploying"
        mock_proc.steps = [{"action": "build"}, {"action": "test"}]
        mock_proc.success_rate = 0.9
        mock_proc.success_count = 9
        mock_proc.failure_count = 1

        mock_procedural = AsyncMock()
        mock_procedural.find_matching = AsyncMock(return_value=[mock_proc])
        mock_procedural_cls.create = AsyncMock(return_value=mock_procedural)

        config = MemoryFetchConfig(
            include_working=False,
            include_episodic=False,
            include_semantic=False,
        )

        result = await context_source.fetch_context(
            query="deploy",
            project_id=uuid4(),
            config=config,
        )

        assert result.by_type["procedural"] == 1
        assert result.contexts[0].memory_subtype == MemorySubtype.PROCEDURAL
        assert "Deploy" in result.contexts[0].content

    async def test_results_sorted_by_relevance(
        self,
        context_source: MemoryContextSource,
    ) -> None:
        """Results should be sorted by relevance score."""
        with patch.object(
            context_source, "_fetch_episodic"
        ) as mock_ep, patch.object(
            context_source, "_fetch_semantic"
        ) as mock_sem:
            # Create contexts with different relevance scores
            from app.services.context.types.memory import MemoryContext

            ctx_low = MemoryContext(
                content="low relevance",
                source="test",
                relevance_score=0.3,
            )
            ctx_high = MemoryContext(
                content="high relevance",
                source="test",
                relevance_score=0.9,
            )

            mock_ep.return_value = [ctx_low]
            mock_sem.return_value = [ctx_high]

            config = MemoryFetchConfig(
                include_working=False,
                include_procedural=False,
            )

            result = await context_source.fetch_context(
                query="test",
                project_id=uuid4(),
                config=config,
            )

            # Higher relevance should come first
            assert result.contexts[0].relevance_score == 0.9
            assert result.contexts[1].relevance_score == 0.3

    @patch("app.services.memory.integration.context_source.WorkingMemory")
    async def test_fetch_all_working(
        self,
        mock_working_cls: MagicMock,
        context_source: MemoryContextSource,
    ) -> None:
        """fetch_all_working should return all working memory items."""
        mock_working = AsyncMock()
        mock_working.list_keys = AsyncMock(return_value=["key1", "key2", "key3"])
        mock_working.get = AsyncMock(return_value="value")
        mock_working_cls.for_session = AsyncMock(return_value=mock_working)

        contexts = await context_source.fetch_all_working(
            session_id="sess-123",
            project_id=uuid4(),
        )

        assert len(contexts) == 3
        assert all(c.memory_subtype == MemorySubtype.WORKING for c in contexts)


class TestGetMemoryContextSource:
    """Tests for factory function."""

    async def test_creates_instance(self) -> None:
        """Factory should create MemoryContextSource instance."""
        mock_session = MagicMock()

        source = await get_memory_context_source(mock_session)

        assert isinstance(source, MemoryContextSource)
backend/tests/unit/services/memory/integration/test_lifecycle.py (new file, 471 lines)
@@ -0,0 +1,471 @@
# tests/unit/services/memory/integration/test_lifecycle.py
"""Tests for Agent Lifecycle Hooks."""

from datetime import UTC, datetime
from unittest.mock import AsyncMock, MagicMock, patch
from uuid import uuid4

import pytest

from app.services.memory.integration.lifecycle import (
    AgentLifecycleManager,
    LifecycleEvent,
    LifecycleHooks,
    LifecycleResult,
    get_lifecycle_manager,
)
from app.services.memory.types import Outcome

pytestmark = pytest.mark.asyncio(loop_scope="function")


@pytest.fixture
def mock_session() -> MagicMock:
    """Create mock database session."""
    return MagicMock()


@pytest.fixture
def lifecycle_hooks() -> LifecycleHooks:
    """Create lifecycle hooks instance."""
    return LifecycleHooks()


@pytest.fixture
def lifecycle_manager(mock_session: MagicMock) -> AgentLifecycleManager:
    """Create lifecycle manager instance."""
    return AgentLifecycleManager(session=mock_session)


class TestLifecycleEvent:
    """Tests for LifecycleEvent dataclass."""

    def test_creates_event(self) -> None:
        """Should create event with required fields."""
        project_id = uuid4()
        agent_id = uuid4()

        event = LifecycleEvent(
            event_type="spawn",
            project_id=project_id,
            agent_instance_id=agent_id,
        )

        assert event.event_type == "spawn"
        assert event.project_id == project_id
        assert event.agent_instance_id == agent_id
        assert event.timestamp is not None
        assert event.metadata == {}

    def test_with_optional_fields(self) -> None:
        """Should include optional fields."""
        event = LifecycleEvent(
            event_type="terminate",
            project_id=uuid4(),
            agent_instance_id=uuid4(),
            session_id="sess-123",
            metadata={"reason": "completed"},
        )

        assert event.session_id == "sess-123"
        assert event.metadata["reason"] == "completed"


class TestLifecycleResult:
    """Tests for LifecycleResult dataclass."""

    def test_success_result(self) -> None:
        """Should create success result."""
        result = LifecycleResult(
            success=True,
            event_type="spawn",
            message="Agent spawned",
            data={"session_id": "sess-123"},
            duration_ms=10.5,
        )

        assert result.success is True
        assert result.event_type == "spawn"
        assert result.data["session_id"] == "sess-123"

    def test_failure_result(self) -> None:
        """Should create failure result."""
        result = LifecycleResult(
            success=False,
            event_type="resume",
            message="Checkpoint not found",
        )

        assert result.success is False
        assert result.message == "Checkpoint not found"


class TestLifecycleHooks:
    """Tests for LifecycleHooks class."""

    def test_register_spawn_hook(self, lifecycle_hooks: LifecycleHooks) -> None:
        """Should register spawn hook."""
        async def my_hook(event: LifecycleEvent) -> None:
            pass

        result = lifecycle_hooks.on_spawn(my_hook)

        assert result is my_hook
        assert my_hook in lifecycle_hooks._spawn_hooks

    def test_register_all_hooks(self, lifecycle_hooks: LifecycleHooks) -> None:
        """Should register hooks for all event types."""
        hooks = [
            lifecycle_hooks.on_spawn(AsyncMock()),
            lifecycle_hooks.on_pause(AsyncMock()),
            lifecycle_hooks.on_resume(AsyncMock()),
            lifecycle_hooks.on_terminate(AsyncMock()),
        ]

        assert len(lifecycle_hooks._spawn_hooks) == 1
        assert len(lifecycle_hooks._pause_hooks) == 1
        assert len(lifecycle_hooks._resume_hooks) == 1
        assert len(lifecycle_hooks._terminate_hooks) == 1

    async def test_run_spawn_hooks(self, lifecycle_hooks: LifecycleHooks) -> None:
        """Should run all spawn hooks."""
        hook1 = AsyncMock()
        hook2 = AsyncMock()
        lifecycle_hooks.on_spawn(hook1)
        lifecycle_hooks.on_spawn(hook2)

        event = LifecycleEvent(
            event_type="spawn",
            project_id=uuid4(),
            agent_instance_id=uuid4(),
        )

        await lifecycle_hooks.run_spawn_hooks(event)

        hook1.assert_called_once_with(event)
        hook2.assert_called_once_with(event)

    async def test_hook_failure_doesnt_stop_others(
        self, lifecycle_hooks: LifecycleHooks
    ) -> None:
        """Hook failure should not stop other hooks from running."""
        hook1 = AsyncMock(side_effect=ValueError("Oops"))
        hook2 = AsyncMock()
        lifecycle_hooks.on_pause(hook1)
        lifecycle_hooks.on_pause(hook2)

        event = LifecycleEvent(
            event_type="pause",
            project_id=uuid4(),
            agent_instance_id=uuid4(),
        )

        await lifecycle_hooks.run_pause_hooks(event)

        # hook2 should still be called even though hook1 failed
        hook2.assert_called_once()


class TestAgentLifecycleManagerSpawn:
    """Tests for AgentLifecycleManager.spawn."""

    @patch("app.services.memory.integration.lifecycle.WorkingMemory")
    async def test_spawn_creates_working_memory(
        self,
        mock_working_cls: MagicMock,
        lifecycle_manager: AgentLifecycleManager,
    ) -> None:
        """Spawn should create working memory for session."""
        mock_working = AsyncMock()
        mock_working.set = AsyncMock()
        mock_working_cls.for_session = AsyncMock(return_value=mock_working)

        result = await lifecycle_manager.spawn(
            project_id=uuid4(),
            agent_instance_id=uuid4(),
            session_id="sess-123",
        )

        assert result.success is True
        assert result.event_type == "spawn"
        mock_working_cls.for_session.assert_called_once()

    @patch("app.services.memory.integration.lifecycle.WorkingMemory")
    async def test_spawn_with_initial_state(
        self,
        mock_working_cls: MagicMock,
        lifecycle_manager: AgentLifecycleManager,
    ) -> None:
        """Spawn should populate initial state."""
        mock_working = AsyncMock()
        mock_working.set = AsyncMock()
        mock_working_cls.for_session = AsyncMock(return_value=mock_working)

        result = await lifecycle_manager.spawn(
            project_id=uuid4(),
            agent_instance_id=uuid4(),
            session_id="sess-123",
            initial_state={"key1": "value1", "key2": "value2"},
        )

        assert result.success is True
        assert result.data["initial_items"] == 2
        assert mock_working.set.call_count == 2

    @patch("app.services.memory.integration.lifecycle.WorkingMemory")
    async def test_spawn_runs_hooks(
        self,
        mock_working_cls: MagicMock,
        lifecycle_manager: AgentLifecycleManager,
    ) -> None:
        """Spawn should run registered hooks."""
        mock_working = AsyncMock()
        mock_working.set = AsyncMock()
        mock_working_cls.for_session = AsyncMock(return_value=mock_working)

        hook = AsyncMock()
        lifecycle_manager.hooks.on_spawn(hook)

        await lifecycle_manager.spawn(
            project_id=uuid4(),
            agent_instance_id=uuid4(),
            session_id="sess-123",
        )

        hook.assert_called_once()


class TestAgentLifecycleManagerPause:
    """Tests for AgentLifecycleManager.pause."""

    @patch("app.services.memory.integration.lifecycle.WorkingMemory")
    async def test_pause_creates_checkpoint(
        self,
        mock_working_cls: MagicMock,
        lifecycle_manager: AgentLifecycleManager,
    ) -> None:
        """Pause should create checkpoint of working memory."""
        mock_working = AsyncMock()
        mock_working.list_keys = AsyncMock(return_value=["key1", "key2"])
        mock_working.get = AsyncMock(return_value={"data": "test"})
        mock_working.set = AsyncMock()
        mock_working_cls.for_session = AsyncMock(return_value=mock_working)

        result = await lifecycle_manager.pause(
            project_id=uuid4(),
            agent_instance_id=uuid4(),
            session_id="sess-123",
            checkpoint_id="ckpt-001",
        )

        assert result.success is True
        assert result.event_type == "pause"
        assert result.data["checkpoint_id"] == "ckpt-001"
        assert result.data["items_saved"] == 2

        # Should save checkpoint with state
        mock_working.set.assert_called_once()
        call_args = mock_working.set.call_args
        # Check positional arg (first arg is key)
        assert "__checkpoint__ckpt-001" in call_args[0][0]

    @patch("app.services.memory.integration.lifecycle.WorkingMemory")
    async def test_pause_generates_checkpoint_id(
        self,
        mock_working_cls: MagicMock,
        lifecycle_manager: AgentLifecycleManager,
    ) -> None:
        """Pause should generate checkpoint ID if not provided."""
        mock_working = AsyncMock()
        mock_working.list_keys = AsyncMock(return_value=[])
        mock_working.set = AsyncMock()
        mock_working_cls.for_session = AsyncMock(return_value=mock_working)

        result = await lifecycle_manager.pause(
            project_id=uuid4(),
            agent_instance_id=uuid4(),
            session_id="sess-123",
        )

        assert result.success is True
        assert "checkpoint_id" in result.data
        assert result.data["checkpoint_id"].startswith("checkpoint_")


class TestAgentLifecycleManagerResume:
    """Tests for AgentLifecycleManager.resume."""

    @patch("app.services.memory.integration.lifecycle.WorkingMemory")
    async def test_resume_restores_checkpoint(
        self,
        mock_working_cls: MagicMock,
        lifecycle_manager: AgentLifecycleManager,
    ) -> None:
        """Resume should restore working memory from checkpoint."""
        checkpoint_data = {
            "state": {"key1": "value1", "key2": "value2"},
            "timestamp": datetime.now(UTC).isoformat(),
            "keys_count": 2,
        }

        mock_working = AsyncMock()
        mock_working.list_keys = AsyncMock(return_value=[])
        mock_working.get = AsyncMock(return_value=checkpoint_data)
        mock_working.set = AsyncMock()
        mock_working.delete = AsyncMock()
        mock_working_cls.for_session = AsyncMock(return_value=mock_working)

        result = await lifecycle_manager.resume(
            project_id=uuid4(),
            agent_instance_id=uuid4(),
            session_id="sess-123",
            checkpoint_id="ckpt-001",
        )

        assert result.success is True
        assert result.event_type == "resume"
        assert result.data["items_restored"] == 2

    @patch("app.services.memory.integration.lifecycle.WorkingMemory")
    async def test_resume_checkpoint_not_found(
        self,
        mock_working_cls: MagicMock,
        lifecycle_manager: AgentLifecycleManager,
    ) -> None:
        """Resume should fail if checkpoint not found."""
        mock_working = AsyncMock()
        mock_working.get = AsyncMock(return_value=None)
        mock_working_cls.for_session = AsyncMock(return_value=mock_working)

        result = await lifecycle_manager.resume(
            project_id=uuid4(),
            agent_instance_id=uuid4(),
            session_id="sess-123",
            checkpoint_id="nonexistent",
        )

        assert result.success is False
        assert "not found" in result.message.lower()


class TestAgentLifecycleManagerTerminate:
    """Tests for AgentLifecycleManager.terminate."""

    @patch("app.services.memory.integration.lifecycle.EpisodicMemory")
    @patch("app.services.memory.integration.lifecycle.WorkingMemory")
    async def test_terminate_consolidates_to_episodic(
        self,
        mock_working_cls: MagicMock,
        mock_episodic_cls: MagicMock,
        lifecycle_manager: AgentLifecycleManager,
    ) -> None:
        """Terminate should consolidate working memory to episodic."""
        mock_working = AsyncMock()
        mock_working.list_keys = AsyncMock(return_value=["key1", "key2"])
        mock_working.get = AsyncMock(return_value="value")
        mock_working.delete = AsyncMock()
        mock_working_cls.for_session = AsyncMock(return_value=mock_working)

        mock_episode = MagicMock()
        mock_episode.id = uuid4()

        mock_episodic = AsyncMock()
        mock_episodic.record_episode = AsyncMock(return_value=mock_episode)
        mock_episodic_cls.create = AsyncMock(return_value=mock_episodic)

        result = await lifecycle_manager.terminate(
            project_id=uuid4(),
            agent_instance_id=uuid4(),
            session_id="sess-123",
            task_description="Completed task",
            outcome=Outcome.SUCCESS,
        )

        assert result.success is True
        assert result.event_type == "terminate"
        assert result.data["episode_id"] == str(mock_episode.id)
        assert result.data["state_items_consolidated"] == 2
        mock_episodic.record_episode.assert_called_once()

    @patch("app.services.memory.integration.lifecycle.WorkingMemory")
    async def test_terminate_cleans_up_working(
        self,
        mock_working_cls: MagicMock,
        lifecycle_manager: AgentLifecycleManager,
    ) -> None:
        """Terminate should clean up working memory."""
        mock_working = AsyncMock()
        mock_working.list_keys = AsyncMock(return_value=["key1", "key2"])
        mock_working.get = AsyncMock(return_value="value")
        mock_working.delete = AsyncMock()
        mock_working_cls.for_session = AsyncMock(return_value=mock_working)

        result = await lifecycle_manager.terminate(
            project_id=uuid4(),
            agent_instance_id=uuid4(),
            session_id="sess-123",
            consolidate_to_episodic=False,
            cleanup_working=True,
        )

        assert result.success is True
        assert result.data["items_cleared"] == 2
        assert mock_working.delete.call_count == 2


class TestAgentLifecycleManagerListCheckpoints:
    """Tests for AgentLifecycleManager.list_checkpoints."""

    @patch("app.services.memory.integration.lifecycle.WorkingMemory")
    async def test_list_checkpoints(
        self,
        mock_working_cls: MagicMock,
        lifecycle_manager: AgentLifecycleManager,
    ) -> None:
        """Should list available checkpoints."""
        mock_working = AsyncMock()
        mock_working.list_keys = AsyncMock(
            return_value=[
                "__checkpoint__ckpt-001",
                "__checkpoint__ckpt-002",
                "regular_key",
            ]
        )
        mock_working.get = AsyncMock(
            return_value={
                "timestamp": "2024-01-01T00:00:00Z",
                "keys_count": 5,
            }
        )
        mock_working_cls.for_session = AsyncMock(return_value=mock_working)

        checkpoints = await lifecycle_manager.list_checkpoints(
            project_id=uuid4(),
            agent_instance_id=uuid4(),
            session_id="sess-123",
        )

        assert len(checkpoints) == 2
        assert checkpoints[0]["checkpoint_id"] == "ckpt-001"
        assert checkpoints[0]["keys_count"] == 5


class TestGetLifecycleManager:
    """Tests for factory function."""

    async def test_creates_instance(self) -> None:
        """Factory should create AgentLifecycleManager instance."""
        mock_session = MagicMock()

        manager = await get_lifecycle_manager(mock_session)

        assert isinstance(manager, AgentLifecycleManager)

    async def test_with_custom_hooks(self) -> None:
        """Factory should accept custom hooks."""
        mock_session = MagicMock()
        custom_hooks = LifecycleHooks()

        manager = await get_lifecycle_manager(mock_session, hooks=custom_hooks)

        assert manager.hooks is custom_hooks