feat(memory): add procedural memory implementation (Issue #92)

Implements procedural memory for learned skills and procedures:

Core functionality:
- ProceduralMemory class for procedure storage/retrieval
- record_procedure with duplicate detection and step merging
- find_matching for context-based procedure search
- record_outcome for success/failure tracking
- get_best_procedure for finding highest success rate
- update_steps for procedure refinement

Supporting modules:
- ProcedureMatcher: Keyword-based procedure matching
- MatchResult/MatchContext: Matching result types
- Success rate weighting in match scoring

Test coverage:
- 43 unit tests covering all modules
- matching.py: 97% coverage
- memory.py: 86% coverage

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
2026-01-05 02:31:32 +01:00
parent e946787a61
commit b818f17418
6 changed files with 2029 additions and 1 deletions

View File

@@ -0,0 +1,2 @@
# tests/unit/services/memory/procedural/__init__.py
"""Unit tests for procedural memory."""

View File

@@ -0,0 +1,427 @@
# tests/unit/services/memory/procedural/test_matching.py
"""Unit tests for procedure matching."""
from datetime import UTC, datetime
from uuid import uuid4
import pytest
from app.services.memory.procedural.matching import (
MatchContext,
MatchResult,
ProcedureMatcher,
get_procedure_matcher,
)
from app.services.memory.types import Procedure
def create_test_procedure(
name: str = "deploy_api",
trigger_pattern: str = "deploy.*api",
success_count: int = 8,
failure_count: int = 2,
) -> Procedure:
"""Create a test procedure for testing."""
now = datetime.now(UTC)
return Procedure(
id=uuid4(),
project_id=None,
agent_type_id=None,
name=name,
trigger_pattern=trigger_pattern,
steps=[
{"order": 1, "action": "build"},
{"order": 2, "action": "test"},
{"order": 3, "action": "deploy"},
],
success_count=success_count,
failure_count=failure_count,
last_used=now,
embedding=None,
created_at=now,
updated_at=now,
)
class TestMatchResult:
"""Tests for MatchResult dataclass."""
def test_to_dict(self) -> None:
"""Test converting match result to dictionary."""
procedure = create_test_procedure()
result = MatchResult(
procedure=procedure,
score=0.85,
matched_terms=["deploy", "api"],
match_type="keyword",
)
data = result.to_dict()
assert "procedure_id" in data
assert "procedure_name" in data
assert data["score"] == 0.85
assert data["matched_terms"] == ["deploy", "api"]
assert data["match_type"] == "keyword"
assert data["success_rate"] == 0.8
class TestMatchContext:
"""Tests for MatchContext dataclass."""
def test_default_values(self) -> None:
"""Test default values."""
context = MatchContext(query="deploy api")
assert context.query == "deploy api"
assert context.task_type is None
assert context.project_id is None
assert context.max_results == 5
assert context.min_score == 0.3
assert context.require_success_rate is None
def test_with_all_values(self) -> None:
"""Test with all values set."""
project_id = uuid4()
context = MatchContext(
query="deploy api",
task_type="deployment",
project_id=project_id,
max_results=10,
min_score=0.5,
require_success_rate=0.7,
)
assert context.query == "deploy api"
assert context.task_type == "deployment"
assert context.project_id == project_id
assert context.max_results == 10
assert context.min_score == 0.5
assert context.require_success_rate == 0.7
class TestProcedureMatcher:
"""Tests for ProcedureMatcher class."""
@pytest.fixture
def matcher(self) -> ProcedureMatcher:
"""Create a procedure matcher."""
return ProcedureMatcher()
@pytest.fixture
def procedures(self) -> list[Procedure]:
"""Create test procedures."""
return [
create_test_procedure(
name="deploy_api",
trigger_pattern="deploy.*api",
success_count=9,
failure_count=1,
),
create_test_procedure(
name="deploy_frontend",
trigger_pattern="deploy.*frontend",
success_count=7,
failure_count=3,
),
create_test_procedure(
name="build_project",
trigger_pattern="build.*project",
success_count=8,
failure_count=2,
),
create_test_procedure(
name="run_tests",
trigger_pattern="test.*run",
success_count=5,
failure_count=5,
),
]
def test_match_exact_name(
self,
matcher: ProcedureMatcher,
procedures: list[Procedure],
) -> None:
"""Test matching with exact name."""
context = MatchContext(query="deploy_api")
results = matcher.match(procedures, context)
assert len(results) > 0
# First result should be deploy_api
assert results[0].procedure.name == "deploy_api"
def test_match_partial_terms(
self,
matcher: ProcedureMatcher,
procedures: list[Procedure],
) -> None:
"""Test matching with partial terms."""
context = MatchContext(query="deploy")
results = matcher.match(procedures, context)
assert len(results) >= 2
# Both deploy procedures should match
names = [r.procedure.name for r in results]
assert "deploy_api" in names
assert "deploy_frontend" in names
def test_match_with_task_type(
self,
matcher: ProcedureMatcher,
procedures: list[Procedure],
) -> None:
"""Test matching with task type."""
context = MatchContext(
query="build something",
task_type="build",
)
results = matcher.match(procedures, context)
assert len(results) > 0
assert results[0].procedure.name == "build_project"
def test_match_respects_min_score(
self,
matcher: ProcedureMatcher,
procedures: list[Procedure],
) -> None:
"""Test that matching respects minimum score."""
context = MatchContext(
query="completely unrelated query xyz",
min_score=0.5,
)
results = matcher.match(procedures, context)
# Should not match anything with high min_score
for result in results:
assert result.score >= 0.5
def test_match_respects_success_rate_requirement(
self,
matcher: ProcedureMatcher,
procedures: list[Procedure],
) -> None:
"""Test that matching respects success rate requirement."""
context = MatchContext(
query="deploy",
require_success_rate=0.7,
)
results = matcher.match(procedures, context)
for result in results:
assert result.procedure.success_rate >= 0.7
def test_match_respects_max_results(
self,
matcher: ProcedureMatcher,
procedures: list[Procedure],
) -> None:
"""Test that matching respects max results."""
context = MatchContext(
query="deploy",
max_results=1,
min_score=0.0,
)
results = matcher.match(procedures, context)
assert len(results) <= 1
def test_match_sorts_by_score(
self,
matcher: ProcedureMatcher,
procedures: list[Procedure],
) -> None:
"""Test that results are sorted by score."""
context = MatchContext(query="deploy", min_score=0.0)
results = matcher.match(procedures, context)
if len(results) > 1:
scores = [r.score for r in results]
assert scores == sorted(scores, reverse=True)
def test_match_empty_procedures(
self,
matcher: ProcedureMatcher,
) -> None:
"""Test matching with empty procedures list."""
context = MatchContext(query="deploy")
results = matcher.match([], context)
assert results == []
class TestProcedureMatcherRankByRelevance:
"""Tests for rank_by_relevance method."""
@pytest.fixture
def matcher(self) -> ProcedureMatcher:
"""Create a procedure matcher."""
return ProcedureMatcher()
def test_rank_by_relevance(
self,
matcher: ProcedureMatcher,
) -> None:
"""Test ranking by relevance."""
procedures = [
create_test_procedure(name="unrelated", trigger_pattern="something else"),
create_test_procedure(name="deploy_api", trigger_pattern="deploy.*api"),
create_test_procedure(
name="deploy_frontend", trigger_pattern="deploy.*frontend"
),
]
ranked = matcher.rank_by_relevance(procedures, "deploy")
# Deploy procedures should be ranked first
assert ranked[0].name in ["deploy_api", "deploy_frontend"]
def test_rank_by_relevance_empty(
self,
matcher: ProcedureMatcher,
) -> None:
"""Test ranking empty list."""
ranked = matcher.rank_by_relevance([], "deploy")
assert ranked == []
class TestProcedureMatcherSuggestProcedures:
"""Tests for suggest_procedures method."""
@pytest.fixture
def matcher(self) -> ProcedureMatcher:
"""Create a procedure matcher."""
return ProcedureMatcher()
@pytest.fixture
def procedures(self) -> list[Procedure]:
"""Create test procedures."""
return [
create_test_procedure(
name="deploy_api",
trigger_pattern="deploy api",
success_count=9,
failure_count=1,
),
create_test_procedure(
name="bad_deploy",
trigger_pattern="deploy bad",
success_count=2,
failure_count=8,
),
]
def test_suggest_procedures(
self,
matcher: ProcedureMatcher,
procedures: list[Procedure],
) -> None:
"""Test suggesting procedures."""
suggestions = matcher.suggest_procedures(
procedures,
"deploy",
min_success_rate=0.5,
)
assert len(suggestions) > 0
# Only high success rate should be suggested
for s in suggestions:
assert s.procedure.success_rate >= 0.5
def test_suggest_procedures_limits_results(
self,
matcher: ProcedureMatcher,
procedures: list[Procedure],
) -> None:
"""Test that suggestions are limited."""
suggestions = matcher.suggest_procedures(
procedures,
"deploy",
max_suggestions=1,
)
assert len(suggestions) <= 1
class TestGetProcedureMatcher:
"""Tests for singleton getter."""
def test_get_procedure_matcher_returns_instance(self) -> None:
"""Test that getter returns instance."""
matcher = get_procedure_matcher()
assert matcher is not None
assert isinstance(matcher, ProcedureMatcher)
def test_get_procedure_matcher_returns_same_instance(self) -> None:
"""Test that getter returns same instance (singleton)."""
matcher1 = get_procedure_matcher()
matcher2 = get_procedure_matcher()
assert matcher1 is matcher2
class TestProcedureMatcherExtractTerms:
"""Tests for term extraction."""
@pytest.fixture
def matcher(self) -> ProcedureMatcher:
"""Create a procedure matcher."""
return ProcedureMatcher()
def test_extract_terms_basic(
self,
matcher: ProcedureMatcher,
) -> None:
"""Test basic term extraction."""
terms = matcher._extract_terms("deploy the api")
assert "deploy" in terms
assert "the" in terms
assert "api" in terms
def test_extract_terms_removes_special_chars(
self,
matcher: ProcedureMatcher,
) -> None:
"""Test that special characters are removed."""
terms = matcher._extract_terms("deploy.api!now")
assert "deploy" in terms
assert "api" in terms
assert "now" in terms
assert "." not in terms
assert "!" not in terms
def test_extract_terms_filters_short(
self,
matcher: ProcedureMatcher,
) -> None:
"""Test that short terms are filtered."""
terms = matcher._extract_terms("a big api")
assert "a" not in terms
assert "big" in terms
assert "api" in terms
def test_extract_terms_lowercases(
self,
matcher: ProcedureMatcher,
) -> None:
"""Test that terms are lowercased."""
terms = matcher._extract_terms("Deploy API")
assert "deploy" in terms
assert "api" in terms
assert "Deploy" not in terms
assert "API" not in terms

View File

@@ -0,0 +1,569 @@
# tests/unit/services/memory/procedural/test_memory.py
"""Unit tests for ProceduralMemory class."""
from datetime import UTC, datetime
from unittest.mock import AsyncMock, MagicMock
from uuid import uuid4
import pytest
from app.services.memory.procedural.memory import ProceduralMemory
from app.services.memory.types import ProcedureCreate, Step
def create_mock_procedure_model(
name="deploy_api",
trigger_pattern="deploy.*api",
project_id=None,
agent_type_id=None,
success_count=5,
failure_count=1,
):
"""Create a mock procedure model for testing."""
mock = MagicMock()
mock.id = uuid4()
mock.project_id = project_id
mock.agent_type_id = agent_type_id
mock.name = name
mock.trigger_pattern = trigger_pattern
mock.steps = [
{"order": 1, "action": "build", "parameters": {}},
{"order": 2, "action": "test", "parameters": {}},
{"order": 3, "action": "deploy", "parameters": {}},
]
mock.success_count = success_count
mock.failure_count = failure_count
mock.last_used = datetime.now(UTC)
mock.embedding = None
mock.created_at = datetime.now(UTC)
mock.updated_at = datetime.now(UTC)
mock.success_rate = (
success_count / (success_count + failure_count)
if (success_count + failure_count) > 0
else 0.0
)
mock.total_uses = success_count + failure_count
return mock
class TestProceduralMemoryInit:
"""Tests for ProceduralMemory initialization."""
def test_init_creates_memory(self) -> None:
"""Test that init creates memory instance."""
mock_session = AsyncMock()
memory = ProceduralMemory(session=mock_session)
assert memory._session is mock_session
def test_init_with_embedding_generator(self) -> None:
"""Test init with embedding generator."""
mock_session = AsyncMock()
mock_embedding_gen = AsyncMock()
memory = ProceduralMemory(
session=mock_session, embedding_generator=mock_embedding_gen
)
assert memory._embedding_generator is mock_embedding_gen
@pytest.mark.asyncio
async def test_create_factory_method(self) -> None:
"""Test create factory method."""
mock_session = AsyncMock()
memory = await ProceduralMemory.create(session=mock_session)
assert memory is not None
assert memory._session is mock_session
class TestProceduralMemoryRecordProcedure:
"""Tests for procedure recording methods."""
@pytest.fixture
def mock_session(self) -> AsyncMock:
"""Create a mock database session."""
session = AsyncMock()
session.add = MagicMock()
session.flush = AsyncMock()
session.refresh = AsyncMock()
# Mock no existing procedure found
mock_result = MagicMock()
mock_result.scalar_one_or_none.return_value = None
session.execute.return_value = mock_result
return session
@pytest.fixture
def memory(self, mock_session: AsyncMock) -> ProceduralMemory:
"""Create a ProceduralMemory instance."""
return ProceduralMemory(session=mock_session)
@pytest.mark.asyncio
async def test_record_new_procedure(
self,
memory: ProceduralMemory,
mock_session: AsyncMock,
) -> None:
"""Test recording a new procedure."""
procedure_data = ProcedureCreate(
name="build_project",
trigger_pattern="build.*project",
steps=[
{"order": 1, "action": "npm install"},
{"order": 2, "action": "npm run build"},
],
project_id=uuid4(),
)
result = await memory.record_procedure(procedure_data)
assert result.name == "build_project"
assert result.trigger_pattern == "build.*project"
mock_session.add.assert_called_once()
mock_session.flush.assert_called_once()
@pytest.mark.asyncio
async def test_record_updates_existing(
self,
memory: ProceduralMemory,
mock_session: AsyncMock,
) -> None:
"""Test that recording duplicate procedure updates existing."""
# Mock existing procedure found
existing_mock = create_mock_procedure_model()
find_result = MagicMock()
find_result.scalar_one_or_none.return_value = existing_mock
# Mock update result
updated_mock = create_mock_procedure_model(success_count=6)
update_result = MagicMock()
update_result.scalar_one.return_value = updated_mock
mock_session.execute.side_effect = [find_result, update_result]
procedure_data = ProcedureCreate(
name="deploy_api",
trigger_pattern="deploy.*api",
steps=[{"order": 1, "action": "deploy"}],
)
_ = await memory.record_procedure(procedure_data)
# Should have called execute twice (find + update)
assert mock_session.execute.call_count == 2
class TestProceduralMemoryFindMatching:
"""Tests for procedure matching methods."""
@pytest.fixture
def mock_session(self) -> AsyncMock:
"""Create a mock database session."""
session = AsyncMock()
mock_result = MagicMock()
mock_result.scalars.return_value.all.return_value = []
session.execute.return_value = mock_result
return session
@pytest.fixture
def memory(self, mock_session: AsyncMock) -> ProceduralMemory:
"""Create a ProceduralMemory instance."""
return ProceduralMemory(session=mock_session)
@pytest.mark.asyncio
async def test_find_matching(
self,
memory: ProceduralMemory,
) -> None:
"""Test finding matching procedures."""
results = await memory.find_matching("deploy api")
assert isinstance(results, list)
@pytest.mark.asyncio
async def test_find_matching_with_project_filter(
self,
memory: ProceduralMemory,
) -> None:
"""Test finding matching procedures with project filter."""
project_id = uuid4()
results = await memory.find_matching(
"deploy api",
project_id=project_id,
)
assert isinstance(results, list)
@pytest.mark.asyncio
async def test_find_matching_returns_results(
self,
memory: ProceduralMemory,
mock_session: AsyncMock,
) -> None:
"""Test that find_matching returns results."""
procedures = [
create_mock_procedure_model(name="deploy_api"),
create_mock_procedure_model(name="deploy_frontend"),
]
mock_result = MagicMock()
mock_result.scalars.return_value.all.return_value = procedures
mock_session.execute.return_value = mock_result
results = await memory.find_matching("deploy")
assert len(results) == 2
class TestProceduralMemoryGetBestProcedure:
"""Tests for get_best_procedure method."""
@pytest.fixture
def mock_session(self) -> AsyncMock:
"""Create a mock database session."""
session = AsyncMock()
return session
@pytest.fixture
def memory(self, mock_session: AsyncMock) -> ProceduralMemory:
"""Create a ProceduralMemory instance."""
return ProceduralMemory(session=mock_session)
@pytest.mark.asyncio
async def test_get_best_procedure_none(
self,
memory: ProceduralMemory,
mock_session: AsyncMock,
) -> None:
"""Test get_best_procedure returns None when no match."""
mock_result = MagicMock()
mock_result.scalars.return_value.all.return_value = []
mock_session.execute.return_value = mock_result
result = await memory.get_best_procedure("unknown_task")
assert result is None
@pytest.mark.asyncio
async def test_get_best_procedure_returns_highest_success_rate(
self,
memory: ProceduralMemory,
mock_session: AsyncMock,
) -> None:
"""Test get_best_procedure returns highest success rate."""
low_success = create_mock_procedure_model(
name="deploy_v1", success_count=3, failure_count=7
)
high_success = create_mock_procedure_model(
name="deploy_v2", success_count=9, failure_count=1
)
mock_result = MagicMock()
mock_result.scalars.return_value.all.return_value = [high_success, low_success]
mock_session.execute.return_value = mock_result
result = await memory.get_best_procedure("deploy")
assert result is not None
assert result.name == "deploy_v2"
class TestProceduralMemoryRecordOutcome:
"""Tests for outcome recording."""
@pytest.fixture
def mock_session(self) -> AsyncMock:
"""Create a mock database session."""
session = AsyncMock()
return session
@pytest.fixture
def memory(self, mock_session: AsyncMock) -> ProceduralMemory:
"""Create a ProceduralMemory instance."""
return ProceduralMemory(session=mock_session)
@pytest.mark.asyncio
async def test_record_outcome_success(
self,
memory: ProceduralMemory,
mock_session: AsyncMock,
) -> None:
"""Test recording a successful outcome."""
existing_mock = create_mock_procedure_model()
# First query: find procedure
find_result = MagicMock()
find_result.scalar_one_or_none.return_value = existing_mock
# Second query: update
updated_mock = create_mock_procedure_model(success_count=6)
update_result = MagicMock()
update_result.scalar_one.return_value = updated_mock
mock_session.execute.side_effect = [find_result, update_result]
result = await memory.record_outcome(existing_mock.id, success=True)
assert result.success_count == 6
@pytest.mark.asyncio
async def test_record_outcome_failure(
self,
memory: ProceduralMemory,
mock_session: AsyncMock,
) -> None:
"""Test recording a failure outcome."""
existing_mock = create_mock_procedure_model()
# First query: find procedure
find_result = MagicMock()
find_result.scalar_one_or_none.return_value = existing_mock
# Second query: update
updated_mock = create_mock_procedure_model(failure_count=2)
update_result = MagicMock()
update_result.scalar_one.return_value = updated_mock
mock_session.execute.side_effect = [find_result, update_result]
result = await memory.record_outcome(existing_mock.id, success=False)
assert result.failure_count == 2
@pytest.mark.asyncio
async def test_record_outcome_not_found(
self,
memory: ProceduralMemory,
mock_session: AsyncMock,
) -> None:
"""Test recording outcome for non-existent procedure raises error."""
mock_result = MagicMock()
mock_result.scalar_one_or_none.return_value = None
mock_session.execute.return_value = mock_result
with pytest.raises(ValueError, match="Procedure not found"):
await memory.record_outcome(uuid4(), success=True)
class TestProceduralMemoryUpdateSteps:
"""Tests for step updates."""
@pytest.fixture
def mock_session(self) -> AsyncMock:
"""Create a mock database session."""
session = AsyncMock()
return session
@pytest.fixture
def memory(self, mock_session: AsyncMock) -> ProceduralMemory:
"""Create a ProceduralMemory instance."""
return ProceduralMemory(session=mock_session)
@pytest.mark.asyncio
async def test_update_steps(
self,
memory: ProceduralMemory,
mock_session: AsyncMock,
) -> None:
"""Test updating steps."""
existing_mock = create_mock_procedure_model()
# First query: find procedure
find_result = MagicMock()
find_result.scalar_one_or_none.return_value = existing_mock
# Second query: update
updated_mock = create_mock_procedure_model()
updated_mock.steps = [
{"order": 1, "action": "new_step_1", "parameters": {}},
{"order": 2, "action": "new_step_2", "parameters": {}},
]
update_result = MagicMock()
update_result.scalar_one.return_value = updated_mock
mock_session.execute.side_effect = [find_result, update_result]
new_steps = [
Step(order=1, action="new_step_1"),
Step(order=2, action="new_step_2"),
]
result = await memory.update_steps(existing_mock.id, new_steps)
assert len(result.steps) == 2
@pytest.mark.asyncio
async def test_update_steps_not_found(
self,
memory: ProceduralMemory,
mock_session: AsyncMock,
) -> None:
"""Test updating steps for non-existent procedure raises error."""
mock_result = MagicMock()
mock_result.scalar_one_or_none.return_value = None
mock_session.execute.return_value = mock_result
with pytest.raises(ValueError, match="Procedure not found"):
await memory.update_steps(uuid4(), [Step(order=1, action="test")])
class TestProceduralMemoryStats:
"""Tests for statistics methods."""
@pytest.fixture
def mock_session(self) -> AsyncMock:
"""Create a mock database session."""
session = AsyncMock()
return session
@pytest.fixture
def memory(self, mock_session: AsyncMock) -> ProceduralMemory:
"""Create a ProceduralMemory instance."""
return ProceduralMemory(session=mock_session)
@pytest.mark.asyncio
async def test_get_stats_empty(
self,
memory: ProceduralMemory,
mock_session: AsyncMock,
) -> None:
"""Test getting stats for empty project."""
mock_result = MagicMock()
mock_result.scalars.return_value.all.return_value = []
mock_session.execute.return_value = mock_result
stats = await memory.get_stats(uuid4())
assert stats["total_procedures"] == 0
assert stats["avg_success_rate"] == 0.0
@pytest.mark.asyncio
async def test_get_stats_with_data(
self,
memory: ProceduralMemory,
mock_session: AsyncMock,
) -> None:
"""Test getting stats with data."""
procedures = [
create_mock_procedure_model(success_count=8, failure_count=2),
create_mock_procedure_model(success_count=6, failure_count=4),
]
mock_result = MagicMock()
mock_result.scalars.return_value.all.return_value = procedures
mock_session.execute.return_value = mock_result
stats = await memory.get_stats(uuid4())
assert stats["total_procedures"] == 2
assert stats["total_uses"] == 20 # (8+2) + (6+4)
@pytest.mark.asyncio
async def test_count(
self,
memory: ProceduralMemory,
mock_session: AsyncMock,
) -> None:
"""Test counting procedures."""
procedures = [create_mock_procedure_model() for _ in range(5)]
mock_result = MagicMock()
mock_result.scalars.return_value.all.return_value = procedures
mock_session.execute.return_value = mock_result
count = await memory.count(uuid4())
assert count == 5
class TestProceduralMemoryDelete:
"""Tests for delete operations."""
@pytest.fixture
def mock_session(self) -> AsyncMock:
"""Create a mock database session."""
session = AsyncMock()
return session
@pytest.fixture
def memory(self, mock_session: AsyncMock) -> ProceduralMemory:
"""Create a ProceduralMemory instance."""
return ProceduralMemory(session=mock_session)
@pytest.mark.asyncio
async def test_delete(
self,
memory: ProceduralMemory,
mock_session: AsyncMock,
) -> None:
"""Test deleting a procedure."""
existing_mock = create_mock_procedure_model()
mock_result = MagicMock()
mock_result.scalar_one_or_none.return_value = existing_mock
mock_session.execute.return_value = mock_result
mock_session.delete = AsyncMock()
result = await memory.delete(existing_mock.id)
assert result is True
mock_session.delete.assert_called_once()
@pytest.mark.asyncio
async def test_delete_not_found(
self,
memory: ProceduralMemory,
mock_session: AsyncMock,
) -> None:
"""Test deleting non-existent procedure."""
mock_result = MagicMock()
mock_result.scalar_one_or_none.return_value = None
mock_session.execute.return_value = mock_result
result = await memory.delete(uuid4())
assert result is False
class TestProceduralMemoryGetById:
"""Tests for get_by_id method."""
@pytest.fixture
def mock_session(self) -> AsyncMock:
"""Create a mock database session."""
session = AsyncMock()
return session
@pytest.fixture
def memory(self, mock_session: AsyncMock) -> ProceduralMemory:
"""Create a ProceduralMemory instance."""
return ProceduralMemory(session=mock_session)
@pytest.mark.asyncio
async def test_get_by_id(
self,
memory: ProceduralMemory,
mock_session: AsyncMock,
) -> None:
"""Test getting procedure by ID."""
existing_mock = create_mock_procedure_model()
mock_result = MagicMock()
mock_result.scalar_one_or_none.return_value = existing_mock
mock_session.execute.return_value = mock_result
result = await memory.get_by_id(existing_mock.id)
assert result is not None
assert result.name == "deploy_api"
@pytest.mark.asyncio
async def test_get_by_id_not_found(
self,
memory: ProceduralMemory,
mock_session: AsyncMock,
) -> None:
"""Test get_by_id returns None when not found."""
mock_result = MagicMock()
mock_result.scalar_one_or_none.return_value = None
mock_session.execute.return_value = mock_result
result = await memory.get_by_id(uuid4())
assert result is None