feat(memory): add procedural memory implementation (Issue #92)
Implements procedural memory for learned skills and procedures: Core functionality: - ProceduralMemory class for procedure storage/retrieval - record_procedure with duplicate detection and step merging - find_matching for context-based procedure search - record_outcome for success/failure tracking - get_best_procedure for finding highest success rate - update_steps for procedure refinement Supporting modules: - ProcedureMatcher: Keyword-based procedure matching - MatchResult/MatchContext: Matching result types - Success rate weighting in match scoring Test coverage: - 43 unit tests covering all modules - matching.py: 97% coverage - memory.py: 86% coverage 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
427
backend/tests/unit/services/memory/procedural/test_matching.py
Normal file
427
backend/tests/unit/services/memory/procedural/test_matching.py
Normal file
@@ -0,0 +1,427 @@
|
||||
# tests/unit/services/memory/procedural/test_matching.py
|
||||
"""Unit tests for procedure matching."""
|
||||
|
||||
from datetime import UTC, datetime
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
from app.services.memory.procedural.matching import (
|
||||
MatchContext,
|
||||
MatchResult,
|
||||
ProcedureMatcher,
|
||||
get_procedure_matcher,
|
||||
)
|
||||
from app.services.memory.types import Procedure
|
||||
|
||||
|
||||
def create_test_procedure(
|
||||
name: str = "deploy_api",
|
||||
trigger_pattern: str = "deploy.*api",
|
||||
success_count: int = 8,
|
||||
failure_count: int = 2,
|
||||
) -> Procedure:
|
||||
"""Create a test procedure for testing."""
|
||||
now = datetime.now(UTC)
|
||||
return Procedure(
|
||||
id=uuid4(),
|
||||
project_id=None,
|
||||
agent_type_id=None,
|
||||
name=name,
|
||||
trigger_pattern=trigger_pattern,
|
||||
steps=[
|
||||
{"order": 1, "action": "build"},
|
||||
{"order": 2, "action": "test"},
|
||||
{"order": 3, "action": "deploy"},
|
||||
],
|
||||
success_count=success_count,
|
||||
failure_count=failure_count,
|
||||
last_used=now,
|
||||
embedding=None,
|
||||
created_at=now,
|
||||
updated_at=now,
|
||||
)
|
||||
|
||||
|
||||
class TestMatchResult:
|
||||
"""Tests for MatchResult dataclass."""
|
||||
|
||||
def test_to_dict(self) -> None:
|
||||
"""Test converting match result to dictionary."""
|
||||
procedure = create_test_procedure()
|
||||
result = MatchResult(
|
||||
procedure=procedure,
|
||||
score=0.85,
|
||||
matched_terms=["deploy", "api"],
|
||||
match_type="keyword",
|
||||
)
|
||||
|
||||
data = result.to_dict()
|
||||
|
||||
assert "procedure_id" in data
|
||||
assert "procedure_name" in data
|
||||
assert data["score"] == 0.85
|
||||
assert data["matched_terms"] == ["deploy", "api"]
|
||||
assert data["match_type"] == "keyword"
|
||||
assert data["success_rate"] == 0.8
|
||||
|
||||
|
||||
class TestMatchContext:
|
||||
"""Tests for MatchContext dataclass."""
|
||||
|
||||
def test_default_values(self) -> None:
|
||||
"""Test default values."""
|
||||
context = MatchContext(query="deploy api")
|
||||
|
||||
assert context.query == "deploy api"
|
||||
assert context.task_type is None
|
||||
assert context.project_id is None
|
||||
assert context.max_results == 5
|
||||
assert context.min_score == 0.3
|
||||
assert context.require_success_rate is None
|
||||
|
||||
def test_with_all_values(self) -> None:
|
||||
"""Test with all values set."""
|
||||
project_id = uuid4()
|
||||
context = MatchContext(
|
||||
query="deploy api",
|
||||
task_type="deployment",
|
||||
project_id=project_id,
|
||||
max_results=10,
|
||||
min_score=0.5,
|
||||
require_success_rate=0.7,
|
||||
)
|
||||
|
||||
assert context.query == "deploy api"
|
||||
assert context.task_type == "deployment"
|
||||
assert context.project_id == project_id
|
||||
assert context.max_results == 10
|
||||
assert context.min_score == 0.5
|
||||
assert context.require_success_rate == 0.7
|
||||
|
||||
|
||||
class TestProcedureMatcher:
|
||||
"""Tests for ProcedureMatcher class."""
|
||||
|
||||
@pytest.fixture
|
||||
def matcher(self) -> ProcedureMatcher:
|
||||
"""Create a procedure matcher."""
|
||||
return ProcedureMatcher()
|
||||
|
||||
@pytest.fixture
|
||||
def procedures(self) -> list[Procedure]:
|
||||
"""Create test procedures."""
|
||||
return [
|
||||
create_test_procedure(
|
||||
name="deploy_api",
|
||||
trigger_pattern="deploy.*api",
|
||||
success_count=9,
|
||||
failure_count=1,
|
||||
),
|
||||
create_test_procedure(
|
||||
name="deploy_frontend",
|
||||
trigger_pattern="deploy.*frontend",
|
||||
success_count=7,
|
||||
failure_count=3,
|
||||
),
|
||||
create_test_procedure(
|
||||
name="build_project",
|
||||
trigger_pattern="build.*project",
|
||||
success_count=8,
|
||||
failure_count=2,
|
||||
),
|
||||
create_test_procedure(
|
||||
name="run_tests",
|
||||
trigger_pattern="test.*run",
|
||||
success_count=5,
|
||||
failure_count=5,
|
||||
),
|
||||
]
|
||||
|
||||
def test_match_exact_name(
|
||||
self,
|
||||
matcher: ProcedureMatcher,
|
||||
procedures: list[Procedure],
|
||||
) -> None:
|
||||
"""Test matching with exact name."""
|
||||
context = MatchContext(query="deploy_api")
|
||||
|
||||
results = matcher.match(procedures, context)
|
||||
|
||||
assert len(results) > 0
|
||||
# First result should be deploy_api
|
||||
assert results[0].procedure.name == "deploy_api"
|
||||
|
||||
def test_match_partial_terms(
|
||||
self,
|
||||
matcher: ProcedureMatcher,
|
||||
procedures: list[Procedure],
|
||||
) -> None:
|
||||
"""Test matching with partial terms."""
|
||||
context = MatchContext(query="deploy")
|
||||
|
||||
results = matcher.match(procedures, context)
|
||||
|
||||
assert len(results) >= 2
|
||||
# Both deploy procedures should match
|
||||
names = [r.procedure.name for r in results]
|
||||
assert "deploy_api" in names
|
||||
assert "deploy_frontend" in names
|
||||
|
||||
def test_match_with_task_type(
|
||||
self,
|
||||
matcher: ProcedureMatcher,
|
||||
procedures: list[Procedure],
|
||||
) -> None:
|
||||
"""Test matching with task type."""
|
||||
context = MatchContext(
|
||||
query="build something",
|
||||
task_type="build",
|
||||
)
|
||||
|
||||
results = matcher.match(procedures, context)
|
||||
|
||||
assert len(results) > 0
|
||||
assert results[0].procedure.name == "build_project"
|
||||
|
||||
def test_match_respects_min_score(
|
||||
self,
|
||||
matcher: ProcedureMatcher,
|
||||
procedures: list[Procedure],
|
||||
) -> None:
|
||||
"""Test that matching respects minimum score."""
|
||||
context = MatchContext(
|
||||
query="completely unrelated query xyz",
|
||||
min_score=0.5,
|
||||
)
|
||||
|
||||
results = matcher.match(procedures, context)
|
||||
|
||||
# Should not match anything with high min_score
|
||||
for result in results:
|
||||
assert result.score >= 0.5
|
||||
|
||||
def test_match_respects_success_rate_requirement(
|
||||
self,
|
||||
matcher: ProcedureMatcher,
|
||||
procedures: list[Procedure],
|
||||
) -> None:
|
||||
"""Test that matching respects success rate requirement."""
|
||||
context = MatchContext(
|
||||
query="deploy",
|
||||
require_success_rate=0.7,
|
||||
)
|
||||
|
||||
results = matcher.match(procedures, context)
|
||||
|
||||
for result in results:
|
||||
assert result.procedure.success_rate >= 0.7
|
||||
|
||||
def test_match_respects_max_results(
|
||||
self,
|
||||
matcher: ProcedureMatcher,
|
||||
procedures: list[Procedure],
|
||||
) -> None:
|
||||
"""Test that matching respects max results."""
|
||||
context = MatchContext(
|
||||
query="deploy",
|
||||
max_results=1,
|
||||
min_score=0.0,
|
||||
)
|
||||
|
||||
results = matcher.match(procedures, context)
|
||||
|
||||
assert len(results) <= 1
|
||||
|
||||
def test_match_sorts_by_score(
|
||||
self,
|
||||
matcher: ProcedureMatcher,
|
||||
procedures: list[Procedure],
|
||||
) -> None:
|
||||
"""Test that results are sorted by score."""
|
||||
context = MatchContext(query="deploy", min_score=0.0)
|
||||
|
||||
results = matcher.match(procedures, context)
|
||||
|
||||
if len(results) > 1:
|
||||
scores = [r.score for r in results]
|
||||
assert scores == sorted(scores, reverse=True)
|
||||
|
||||
def test_match_empty_procedures(
|
||||
self,
|
||||
matcher: ProcedureMatcher,
|
||||
) -> None:
|
||||
"""Test matching with empty procedures list."""
|
||||
context = MatchContext(query="deploy")
|
||||
|
||||
results = matcher.match([], context)
|
||||
|
||||
assert results == []
|
||||
|
||||
|
||||
class TestProcedureMatcherRankByRelevance:
|
||||
"""Tests for rank_by_relevance method."""
|
||||
|
||||
@pytest.fixture
|
||||
def matcher(self) -> ProcedureMatcher:
|
||||
"""Create a procedure matcher."""
|
||||
return ProcedureMatcher()
|
||||
|
||||
def test_rank_by_relevance(
|
||||
self,
|
||||
matcher: ProcedureMatcher,
|
||||
) -> None:
|
||||
"""Test ranking by relevance."""
|
||||
procedures = [
|
||||
create_test_procedure(name="unrelated", trigger_pattern="something else"),
|
||||
create_test_procedure(name="deploy_api", trigger_pattern="deploy.*api"),
|
||||
create_test_procedure(
|
||||
name="deploy_frontend", trigger_pattern="deploy.*frontend"
|
||||
),
|
||||
]
|
||||
|
||||
ranked = matcher.rank_by_relevance(procedures, "deploy")
|
||||
|
||||
# Deploy procedures should be ranked first
|
||||
assert ranked[0].name in ["deploy_api", "deploy_frontend"]
|
||||
|
||||
def test_rank_by_relevance_empty(
|
||||
self,
|
||||
matcher: ProcedureMatcher,
|
||||
) -> None:
|
||||
"""Test ranking empty list."""
|
||||
ranked = matcher.rank_by_relevance([], "deploy")
|
||||
|
||||
assert ranked == []
|
||||
|
||||
|
||||
class TestProcedureMatcherSuggestProcedures:
|
||||
"""Tests for suggest_procedures method."""
|
||||
|
||||
@pytest.fixture
|
||||
def matcher(self) -> ProcedureMatcher:
|
||||
"""Create a procedure matcher."""
|
||||
return ProcedureMatcher()
|
||||
|
||||
@pytest.fixture
|
||||
def procedures(self) -> list[Procedure]:
|
||||
"""Create test procedures."""
|
||||
return [
|
||||
create_test_procedure(
|
||||
name="deploy_api",
|
||||
trigger_pattern="deploy api",
|
||||
success_count=9,
|
||||
failure_count=1,
|
||||
),
|
||||
create_test_procedure(
|
||||
name="bad_deploy",
|
||||
trigger_pattern="deploy bad",
|
||||
success_count=2,
|
||||
failure_count=8,
|
||||
),
|
||||
]
|
||||
|
||||
def test_suggest_procedures(
|
||||
self,
|
||||
matcher: ProcedureMatcher,
|
||||
procedures: list[Procedure],
|
||||
) -> None:
|
||||
"""Test suggesting procedures."""
|
||||
suggestions = matcher.suggest_procedures(
|
||||
procedures,
|
||||
"deploy",
|
||||
min_success_rate=0.5,
|
||||
)
|
||||
|
||||
assert len(suggestions) > 0
|
||||
# Only high success rate should be suggested
|
||||
for s in suggestions:
|
||||
assert s.procedure.success_rate >= 0.5
|
||||
|
||||
def test_suggest_procedures_limits_results(
|
||||
self,
|
||||
matcher: ProcedureMatcher,
|
||||
procedures: list[Procedure],
|
||||
) -> None:
|
||||
"""Test that suggestions are limited."""
|
||||
suggestions = matcher.suggest_procedures(
|
||||
procedures,
|
||||
"deploy",
|
||||
max_suggestions=1,
|
||||
)
|
||||
|
||||
assert len(suggestions) <= 1
|
||||
|
||||
|
||||
class TestGetProcedureMatcher:
|
||||
"""Tests for singleton getter."""
|
||||
|
||||
def test_get_procedure_matcher_returns_instance(self) -> None:
|
||||
"""Test that getter returns instance."""
|
||||
matcher = get_procedure_matcher()
|
||||
|
||||
assert matcher is not None
|
||||
assert isinstance(matcher, ProcedureMatcher)
|
||||
|
||||
def test_get_procedure_matcher_returns_same_instance(self) -> None:
|
||||
"""Test that getter returns same instance (singleton)."""
|
||||
matcher1 = get_procedure_matcher()
|
||||
matcher2 = get_procedure_matcher()
|
||||
|
||||
assert matcher1 is matcher2
|
||||
|
||||
|
||||
class TestProcedureMatcherExtractTerms:
|
||||
"""Tests for term extraction."""
|
||||
|
||||
@pytest.fixture
|
||||
def matcher(self) -> ProcedureMatcher:
|
||||
"""Create a procedure matcher."""
|
||||
return ProcedureMatcher()
|
||||
|
||||
def test_extract_terms_basic(
|
||||
self,
|
||||
matcher: ProcedureMatcher,
|
||||
) -> None:
|
||||
"""Test basic term extraction."""
|
||||
terms = matcher._extract_terms("deploy the api")
|
||||
|
||||
assert "deploy" in terms
|
||||
assert "the" in terms
|
||||
assert "api" in terms
|
||||
|
||||
def test_extract_terms_removes_special_chars(
|
||||
self,
|
||||
matcher: ProcedureMatcher,
|
||||
) -> None:
|
||||
"""Test that special characters are removed."""
|
||||
terms = matcher._extract_terms("deploy.api!now")
|
||||
|
||||
assert "deploy" in terms
|
||||
assert "api" in terms
|
||||
assert "now" in terms
|
||||
assert "." not in terms
|
||||
assert "!" not in terms
|
||||
|
||||
def test_extract_terms_filters_short(
|
||||
self,
|
||||
matcher: ProcedureMatcher,
|
||||
) -> None:
|
||||
"""Test that short terms are filtered."""
|
||||
terms = matcher._extract_terms("a big api")
|
||||
|
||||
assert "a" not in terms
|
||||
assert "big" in terms
|
||||
assert "api" in terms
|
||||
|
||||
def test_extract_terms_lowercases(
|
||||
self,
|
||||
matcher: ProcedureMatcher,
|
||||
) -> None:
|
||||
"""Test that terms are lowercased."""
|
||||
terms = matcher._extract_terms("Deploy API")
|
||||
|
||||
assert "deploy" in terms
|
||||
assert "api" in terms
|
||||
assert "Deploy" not in terms
|
||||
assert "API" not in terms
|
||||
Reference in New Issue
Block a user