feat(memory): add procedural memory implementation (Issue #92)

Implements procedural memory for learned skills and procedures:

Core functionality:
- ProceduralMemory class for procedure storage/retrieval
- record_procedure with duplicate detection and step merging
- find_matching for context-based procedure search
- record_outcome for success/failure tracking
- get_best_procedure for finding highest success rate
- update_steps for procedure refinement

Supporting modules:
- ProcedureMatcher: Keyword-based procedure matching
- MatchResult/MatchContext: Matching result types
- Success rate weighting in match scoring

Test coverage:
- 43 unit tests covering all modules
- matching.py: 97% coverage
- memory.py: 86% coverage

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
2026-01-05 02:31:32 +01:00
parent e946787a61
commit b818f17418
6 changed files with 2029 additions and 1 deletions

View File

@@ -1,7 +1,22 @@
# app/services/memory/procedural/__init__.py
"""
Procedural Memory
Learned skills and procedures from successful task patterns.
"""
# Will be populated in #92
from .matching import (
MatchContext,
MatchResult,
ProcedureMatcher,
get_procedure_matcher,
)
from .memory import ProceduralMemory
__all__ = [
"MatchContext",
"MatchResult",
"ProceduralMemory",
"ProcedureMatcher",
"get_procedure_matcher",
]

View File

@@ -0,0 +1,291 @@
# app/services/memory/procedural/matching.py
"""
Procedure Matching.
Provides utilities for matching procedures to contexts,
ranking procedures by relevance, and suggesting procedures.
"""
import logging
import re
from dataclasses import dataclass, field
from typing import Any, ClassVar
from app.services.memory.types import Procedure
logger = logging.getLogger(__name__)
@dataclass
class MatchResult:
"""Result of a procedure match."""
procedure: Procedure
score: float
matched_terms: list[str] = field(default_factory=list)
match_type: str = "keyword" # keyword, semantic, pattern
def to_dict(self) -> dict[str, Any]:
"""Convert to dictionary."""
return {
"procedure_id": str(self.procedure.id),
"procedure_name": self.procedure.name,
"score": self.score,
"matched_terms": self.matched_terms,
"match_type": self.match_type,
"success_rate": self.procedure.success_rate,
}
@dataclass
class MatchContext:
"""Context for procedure matching."""
query: str
task_type: str | None = None
project_id: Any | None = None
agent_type_id: Any | None = None
max_results: int = 5
min_score: float = 0.3
require_success_rate: float | None = None
class ProcedureMatcher:
"""
Matches procedures to contexts using multiple strategies.
Matching strategies:
- Keyword matching on trigger pattern and name
- Pattern-based matching using regex
- Success rate weighting
In production, this would be augmented with vector similarity search.
"""
# Common task-related keywords for boosting
TASK_KEYWORDS: ClassVar[set[str]] = {
"create",
"update",
"delete",
"fix",
"implement",
"add",
"remove",
"refactor",
"test",
"deploy",
"configure",
"setup",
"build",
"debug",
"optimize",
}
def __init__(self) -> None:
"""Initialize the matcher."""
self._compiled_patterns: dict[str, re.Pattern[str]] = {}
def match(
self,
procedures: list[Procedure],
context: MatchContext,
) -> list[MatchResult]:
"""
Match procedures against a context.
Args:
procedures: List of procedures to match
context: Matching context
Returns:
List of match results, sorted by score (highest first)
"""
results: list[MatchResult] = []
query_terms = self._extract_terms(context.query)
query_lower = context.query.lower()
for procedure in procedures:
score, matched = self._calculate_match_score(
procedure=procedure,
query_terms=query_terms,
query_lower=query_lower,
context=context,
)
if score >= context.min_score:
# Apply success rate boost
if context.require_success_rate is not None:
if procedure.success_rate < context.require_success_rate:
continue
# Boost score based on success rate
success_boost = procedure.success_rate * 0.2
final_score = min(1.0, score + success_boost)
results.append(
MatchResult(
procedure=procedure,
score=final_score,
matched_terms=matched,
match_type="keyword",
)
)
# Sort by score descending
results.sort(key=lambda r: r.score, reverse=True)
return results[: context.max_results]
def _extract_terms(self, text: str) -> list[str]:
"""Extract searchable terms from text."""
# Remove special characters and split
clean = re.sub(r"[^\w\s-]", " ", text.lower())
terms = clean.split()
# Filter out very short terms
return [t for t in terms if len(t) >= 2]
def _calculate_match_score(
self,
procedure: Procedure,
query_terms: list[str],
query_lower: str,
context: MatchContext,
) -> tuple[float, list[str]]:
"""
Calculate match score between procedure and query.
Returns:
Tuple of (score, matched_terms)
"""
score = 0.0
matched: list[str] = []
trigger_lower = procedure.trigger_pattern.lower()
name_lower = procedure.name.lower()
# Exact name match - high score
if name_lower in query_lower or query_lower in name_lower:
score += 0.5
matched.append(f"name:{procedure.name}")
# Trigger pattern match
if trigger_lower in query_lower or query_lower in trigger_lower:
score += 0.4
matched.append(f"trigger:{procedure.trigger_pattern[:30]}")
# Term-by-term matching
for term in query_terms:
if term in trigger_lower:
score += 0.1
matched.append(term)
elif term in name_lower:
score += 0.08
matched.append(term)
# Boost for task keywords
if term in self.TASK_KEYWORDS:
if term in trigger_lower or term in name_lower:
score += 0.05
# Task type match if provided
if context.task_type:
task_type_lower = context.task_type.lower()
if task_type_lower in trigger_lower or task_type_lower in name_lower:
score += 0.3
matched.append(f"task_type:{context.task_type}")
# Regex pattern matching on trigger
try:
pattern = self._get_or_compile_pattern(trigger_lower)
if pattern and pattern.search(query_lower):
score += 0.25
matched.append("pattern_match")
except re.error:
pass # Invalid regex, skip pattern matching
return min(1.0, score), matched
def _get_or_compile_pattern(self, pattern: str) -> re.Pattern[str] | None:
"""Get or compile a regex pattern with caching."""
if pattern in self._compiled_patterns:
return self._compiled_patterns[pattern]
# Only compile if it looks like a regex pattern
if not any(c in pattern for c in r"\.*+?[]{}|()^$"):
return None
try:
compiled = re.compile(pattern, re.IGNORECASE)
self._compiled_patterns[pattern] = compiled
return compiled
except re.error:
return None
def rank_by_relevance(
self,
procedures: list[Procedure],
task_type: str,
) -> list[Procedure]:
"""
Rank procedures by relevance to a task type.
Args:
procedures: Procedures to rank
task_type: Task type for relevance
Returns:
Procedures sorted by relevance
"""
context = MatchContext(
query=task_type,
task_type=task_type,
min_score=0.0,
max_results=len(procedures),
)
results = self.match(procedures, context)
return [r.procedure for r in results]
def suggest_procedures(
self,
procedures: list[Procedure],
query: str,
min_success_rate: float = 0.5,
max_suggestions: int = 3,
) -> list[MatchResult]:
"""
Suggest the best procedures for a query.
Only suggests procedures with sufficient success rate.
Args:
procedures: Available procedures
query: Query/context
min_success_rate: Minimum success rate to suggest
max_suggestions: Maximum suggestions
Returns:
List of procedure suggestions
"""
context = MatchContext(
query=query,
max_results=max_suggestions,
min_score=0.2,
require_success_rate=min_success_rate,
)
return self.match(procedures, context)
# Singleton matcher instance
_matcher: ProcedureMatcher | None = None
def get_procedure_matcher() -> ProcedureMatcher:
"""Get the singleton procedure matcher instance."""
global _matcher
if _matcher is None:
_matcher = ProcedureMatcher()
return _matcher

View File

@@ -0,0 +1,724 @@
# app/services/memory/procedural/memory.py
"""
Procedural Memory Implementation.
Provides storage and retrieval for learned procedures (skills)
derived from successful task execution patterns.
"""
import logging
import time
from datetime import UTC, datetime
from typing import Any
from uuid import UUID
from sqlalchemy import and_, desc, or_, select, update
from sqlalchemy.ext.asyncio import AsyncSession
from app.models.memory.procedure import Procedure as ProcedureModel
from app.services.memory.config import get_memory_settings
from app.services.memory.types import Procedure, ProcedureCreate, RetrievalResult, Step
logger = logging.getLogger(__name__)
def _model_to_procedure(model: ProcedureModel) -> Procedure:
"""Convert SQLAlchemy model to Procedure dataclass."""
return Procedure(
id=model.id, # type: ignore[arg-type]
project_id=model.project_id, # type: ignore[arg-type]
agent_type_id=model.agent_type_id, # type: ignore[arg-type]
name=model.name, # type: ignore[arg-type]
trigger_pattern=model.trigger_pattern, # type: ignore[arg-type]
steps=model.steps or [], # type: ignore[arg-type]
success_count=model.success_count, # type: ignore[arg-type]
failure_count=model.failure_count, # type: ignore[arg-type]
last_used=model.last_used, # type: ignore[arg-type]
embedding=None, # Don't expose raw embedding
created_at=model.created_at, # type: ignore[arg-type]
updated_at=model.updated_at, # type: ignore[arg-type]
)
class ProceduralMemory:
"""
Procedural Memory Service.
Provides procedure storage and retrieval:
- Record procedures from successful task patterns
- Find matching procedures by trigger pattern
- Track success/failure rates
- Get best procedure for a task type
- Update procedure steps
Performance target: <50ms P95 for matching
"""
def __init__(
self,
session: AsyncSession,
embedding_generator: Any | None = None,
) -> None:
"""
Initialize procedural memory.
Args:
session: Database session
embedding_generator: Optional embedding generator for semantic matching
"""
self._session = session
self._embedding_generator = embedding_generator
self._settings = get_memory_settings()
@classmethod
async def create(
cls,
session: AsyncSession,
embedding_generator: Any | None = None,
) -> "ProceduralMemory":
"""
Factory method to create ProceduralMemory.
Args:
session: Database session
embedding_generator: Optional embedding generator
Returns:
Configured ProceduralMemory instance
"""
return cls(session=session, embedding_generator=embedding_generator)
# =========================================================================
# Procedure Recording
# =========================================================================
async def record_procedure(self, procedure: ProcedureCreate) -> Procedure:
"""
Record a new procedure or update an existing one.
If a procedure with the same name exists in the same scope,
its steps will be updated and success count incremented.
Args:
procedure: Procedure data to record
Returns:
The created or updated procedure
"""
# Check for existing procedure with same name
existing = await self._find_existing_procedure(
project_id=procedure.project_id,
agent_type_id=procedure.agent_type_id,
name=procedure.name,
)
if existing is not None:
# Update existing procedure
return await self._update_existing_procedure(
existing=existing,
new_steps=procedure.steps,
new_trigger=procedure.trigger_pattern,
)
# Create new procedure
now = datetime.now(UTC)
# Generate embedding if possible
embedding = None
if self._embedding_generator is not None:
embedding_text = self._create_embedding_text(procedure)
embedding = await self._embedding_generator.generate(embedding_text)
model = ProcedureModel(
project_id=procedure.project_id,
agent_type_id=procedure.agent_type_id,
name=procedure.name,
trigger_pattern=procedure.trigger_pattern,
steps=procedure.steps,
success_count=1, # New procedures start with 1 success (they worked)
failure_count=0,
last_used=now,
embedding=embedding,
)
self._session.add(model)
await self._session.flush()
await self._session.refresh(model)
logger.info(
f"Recorded new procedure: {procedure.name} with {len(procedure.steps)} steps"
)
return _model_to_procedure(model)
async def _find_existing_procedure(
self,
project_id: UUID | None,
agent_type_id: UUID | None,
name: str,
) -> ProcedureModel | None:
"""Find an existing procedure with the same name in the same scope."""
query = select(ProcedureModel).where(ProcedureModel.name == name)
if project_id is not None:
query = query.where(ProcedureModel.project_id == project_id)
else:
query = query.where(ProcedureModel.project_id.is_(None))
if agent_type_id is not None:
query = query.where(ProcedureModel.agent_type_id == agent_type_id)
else:
query = query.where(ProcedureModel.agent_type_id.is_(None))
result = await self._session.execute(query)
return result.scalar_one_or_none()
async def _update_existing_procedure(
self,
existing: ProcedureModel,
new_steps: list[dict[str, Any]],
new_trigger: str,
) -> Procedure:
"""Update an existing procedure with new steps."""
now = datetime.now(UTC)
# Merge steps intelligently - keep existing order, add new steps
merged_steps = self._merge_steps(
existing.steps or [], # type: ignore[arg-type]
new_steps,
)
stmt = (
update(ProcedureModel)
.where(ProcedureModel.id == existing.id)
.values(
steps=merged_steps,
trigger_pattern=new_trigger,
success_count=ProcedureModel.success_count + 1,
last_used=now,
updated_at=now,
)
.returning(ProcedureModel)
)
result = await self._session.execute(stmt)
updated_model = result.scalar_one()
await self._session.flush()
logger.info(f"Updated existing procedure: {existing.name}")
return _model_to_procedure(updated_model)
def _merge_steps(
self,
existing_steps: list[dict[str, Any]],
new_steps: list[dict[str, Any]],
) -> list[dict[str, Any]]:
"""Merge steps from a new execution with existing steps."""
if not existing_steps:
return new_steps
if not new_steps:
return existing_steps
# For now, use the new steps if they differ significantly
# In production, this could use more sophisticated merging
if len(new_steps) != len(existing_steps):
# If structure changed, prefer newer steps
return new_steps
# Merge step-by-step, preferring new data where available
merged = []
for i, new_step in enumerate(new_steps):
if i < len(existing_steps):
# Merge with existing step
step = {**existing_steps[i], **new_step}
else:
step = new_step
merged.append(step)
return merged
def _create_embedding_text(self, procedure: ProcedureCreate) -> str:
"""Create text for embedding from procedure data."""
steps_text = " ".join(step.get("action", "") for step in procedure.steps)
return f"{procedure.name} {procedure.trigger_pattern} {steps_text}"
# =========================================================================
# Procedure Retrieval
# =========================================================================
async def find_matching(
self,
context: str,
project_id: UUID | None = None,
agent_type_id: UUID | None = None,
limit: int = 5,
) -> list[Procedure]:
"""
Find procedures matching the given context.
Args:
context: Context/trigger to match against
project_id: Optional project to search within
agent_type_id: Optional agent type filter
limit: Maximum results
Returns:
List of matching procedures
"""
result = await self._find_matching_with_metadata(
context=context,
project_id=project_id,
agent_type_id=agent_type_id,
limit=limit,
)
return result.items
async def _find_matching_with_metadata(
self,
context: str,
project_id: UUID | None = None,
agent_type_id: UUID | None = None,
limit: int = 5,
) -> RetrievalResult[Procedure]:
"""Find matching procedures with full result metadata."""
start_time = time.perf_counter()
# Build base query - prioritize by success rate
stmt = (
select(ProcedureModel)
.order_by(
desc(
ProcedureModel.success_count
/ (ProcedureModel.success_count + ProcedureModel.failure_count + 1)
),
desc(ProcedureModel.last_used),
)
.limit(limit)
)
# Apply scope filters
if project_id is not None:
stmt = stmt.where(
or_(
ProcedureModel.project_id == project_id,
ProcedureModel.project_id.is_(None),
)
)
if agent_type_id is not None:
stmt = stmt.where(
or_(
ProcedureModel.agent_type_id == agent_type_id,
ProcedureModel.agent_type_id.is_(None),
)
)
# Text-based matching on trigger pattern and name
# TODO: Implement proper vector similarity search when pgvector is integrated
search_terms = context.lower().split()[:5] # Limit to 5 terms
if search_terms:
conditions = []
for term in search_terms:
term_pattern = f"%{term}%"
conditions.append(
or_(
ProcedureModel.trigger_pattern.ilike(term_pattern),
ProcedureModel.name.ilike(term_pattern),
)
)
if conditions:
stmt = stmt.where(or_(*conditions))
result = await self._session.execute(stmt)
models = list(result.scalars().all())
latency_ms = (time.perf_counter() - start_time) * 1000
return RetrievalResult(
items=[_model_to_procedure(m) for m in models],
total_count=len(models),
query=context,
retrieval_type="procedural",
latency_ms=latency_ms,
metadata={"project_id": str(project_id) if project_id else None},
)
async def get_best_procedure(
self,
task_type: str,
project_id: UUID | None = None,
agent_type_id: UUID | None = None,
min_success_rate: float = 0.5,
min_uses: int = 1,
) -> Procedure | None:
"""
Get the best procedure for a given task type.
Returns the procedure with the highest success rate that
meets the minimum thresholds.
Args:
task_type: Task type to find procedure for
project_id: Optional project scope
agent_type_id: Optional agent type scope
min_success_rate: Minimum required success rate
min_uses: Minimum number of uses required
Returns:
Best matching procedure or None
"""
# Build query for procedures matching task type
stmt = (
select(ProcedureModel)
.where(
and_(
(ProcedureModel.success_count + ProcedureModel.failure_count)
>= min_uses,
or_(
ProcedureModel.trigger_pattern.ilike(f"%{task_type}%"),
ProcedureModel.name.ilike(f"%{task_type}%"),
),
)
)
.order_by(
desc(
ProcedureModel.success_count
/ (ProcedureModel.success_count + ProcedureModel.failure_count + 1)
),
desc(ProcedureModel.last_used),
)
.limit(10)
)
# Apply scope filters
if project_id is not None:
stmt = stmt.where(
or_(
ProcedureModel.project_id == project_id,
ProcedureModel.project_id.is_(None),
)
)
if agent_type_id is not None:
stmt = stmt.where(
or_(
ProcedureModel.agent_type_id == agent_type_id,
ProcedureModel.agent_type_id.is_(None),
)
)
result = await self._session.execute(stmt)
models = list(result.scalars().all())
# Filter by success rate in Python (SQLAlchemy division in WHERE is complex)
for model in models:
success = float(model.success_count)
failure = float(model.failure_count)
total = success + failure
if total > 0 and (success / total) >= min_success_rate:
logger.debug(
f"Found best procedure for '{task_type}': {model.name} "
f"(success_rate={success / total:.2%})"
)
return _model_to_procedure(model)
return None
async def get_by_id(self, procedure_id: UUID) -> Procedure | None:
"""Get a procedure by ID."""
query = select(ProcedureModel).where(ProcedureModel.id == procedure_id)
result = await self._session.execute(query)
model = result.scalar_one_or_none()
return _model_to_procedure(model) if model else None
# =========================================================================
# Outcome Recording
# =========================================================================
async def record_outcome(
self,
procedure_id: UUID,
success: bool,
) -> Procedure:
"""
Record the outcome of using a procedure.
Updates the success or failure count and last_used timestamp.
Args:
procedure_id: Procedure that was used
success: Whether the procedure succeeded
Returns:
Updated procedure
Raises:
ValueError: If procedure not found
"""
query = select(ProcedureModel).where(ProcedureModel.id == procedure_id)
result = await self._session.execute(query)
model = result.scalar_one_or_none()
if model is None:
raise ValueError(f"Procedure not found: {procedure_id}")
now = datetime.now(UTC)
if success:
stmt = (
update(ProcedureModel)
.where(ProcedureModel.id == procedure_id)
.values(
success_count=ProcedureModel.success_count + 1,
last_used=now,
updated_at=now,
)
.returning(ProcedureModel)
)
else:
stmt = (
update(ProcedureModel)
.where(ProcedureModel.id == procedure_id)
.values(
failure_count=ProcedureModel.failure_count + 1,
last_used=now,
updated_at=now,
)
.returning(ProcedureModel)
)
result = await self._session.execute(stmt)
updated_model = result.scalar_one()
await self._session.flush()
outcome = "success" if success else "failure"
logger.info(
f"Recorded {outcome} for procedure {procedure_id}: "
f"success_rate={updated_model.success_rate:.2%}"
)
return _model_to_procedure(updated_model)
# =========================================================================
# Step Management
# =========================================================================
async def update_steps(
self,
procedure_id: UUID,
steps: list[Step],
) -> Procedure:
"""
Update the steps of a procedure.
Args:
procedure_id: Procedure to update
steps: New steps
Returns:
Updated procedure
Raises:
ValueError: If procedure not found
"""
query = select(ProcedureModel).where(ProcedureModel.id == procedure_id)
result = await self._session.execute(query)
model = result.scalar_one_or_none()
if model is None:
raise ValueError(f"Procedure not found: {procedure_id}")
# Convert Step objects to dictionaries
steps_dict = [
{
"order": step.order,
"action": step.action,
"parameters": step.parameters,
"expected_outcome": step.expected_outcome,
"fallback_action": step.fallback_action,
}
for step in steps
]
now = datetime.now(UTC)
stmt = (
update(ProcedureModel)
.where(ProcedureModel.id == procedure_id)
.values(
steps=steps_dict,
updated_at=now,
)
.returning(ProcedureModel)
)
result = await self._session.execute(stmt)
updated_model = result.scalar_one()
await self._session.flush()
logger.info(f"Updated steps for procedure {procedure_id}: {len(steps)} steps")
return _model_to_procedure(updated_model)
# =========================================================================
# Statistics & Management
# =========================================================================
async def get_stats(
self,
project_id: UUID | None = None,
agent_type_id: UUID | None = None,
) -> dict[str, Any]:
"""
Get statistics about procedural memory.
Args:
project_id: Optional project to get stats for
agent_type_id: Optional agent type filter
Returns:
Dictionary with statistics
"""
query = select(ProcedureModel)
if project_id is not None:
query = query.where(
or_(
ProcedureModel.project_id == project_id,
ProcedureModel.project_id.is_(None),
)
)
if agent_type_id is not None:
query = query.where(
or_(
ProcedureModel.agent_type_id == agent_type_id,
ProcedureModel.agent_type_id.is_(None),
)
)
result = await self._session.execute(query)
models = list(result.scalars().all())
if not models:
return {
"total_procedures": 0,
"avg_success_rate": 0.0,
"avg_steps_count": 0.0,
"total_uses": 0,
"high_success_count": 0,
"low_success_count": 0,
}
success_rates = [m.success_rate for m in models]
step_counts = [len(m.steps or []) for m in models]
total_uses = sum(m.total_uses for m in models)
return {
"total_procedures": len(models),
"avg_success_rate": sum(success_rates) / len(success_rates),
"avg_steps_count": sum(step_counts) / len(step_counts),
"total_uses": total_uses,
"high_success_count": sum(1 for r in success_rates if r >= 0.8),
"low_success_count": sum(1 for r in success_rates if r < 0.5),
}
async def count(
self,
project_id: UUID | None = None,
agent_type_id: UUID | None = None,
) -> int:
"""
Count procedures in scope.
Args:
project_id: Optional project to count for
agent_type_id: Optional agent type filter
Returns:
Number of procedures
"""
query = select(ProcedureModel)
if project_id is not None:
query = query.where(
or_(
ProcedureModel.project_id == project_id,
ProcedureModel.project_id.is_(None),
)
)
if agent_type_id is not None:
query = query.where(
or_(
ProcedureModel.agent_type_id == agent_type_id,
ProcedureModel.agent_type_id.is_(None),
)
)
result = await self._session.execute(query)
return len(list(result.scalars().all()))
async def delete(self, procedure_id: UUID) -> bool:
"""
Delete a procedure.
Args:
procedure_id: Procedure to delete
Returns:
True if deleted, False if not found
"""
query = select(ProcedureModel).where(ProcedureModel.id == procedure_id)
result = await self._session.execute(query)
model = result.scalar_one_or_none()
if model is None:
return False
await self._session.delete(model)
await self._session.flush()
logger.info(f"Deleted procedure {procedure_id}")
return True
async def get_procedures_by_success_rate(
self,
min_rate: float = 0.0,
max_rate: float = 1.0,
project_id: UUID | None = None,
limit: int = 20,
) -> list[Procedure]:
"""
Get procedures within a success rate range.
Args:
min_rate: Minimum success rate
max_rate: Maximum success rate
project_id: Optional project scope
limit: Maximum results
Returns:
List of procedures
"""
query = (
select(ProcedureModel)
.order_by(desc(ProcedureModel.last_used))
.limit(limit * 2) # Fetch more since we filter in Python
)
if project_id is not None:
query = query.where(
or_(
ProcedureModel.project_id == project_id,
ProcedureModel.project_id.is_(None),
)
)
result = await self._session.execute(query)
models = list(result.scalars().all())
# Filter by success rate in Python
filtered = [m for m in models if min_rate <= m.success_rate <= max_rate][:limit]
return [_model_to_procedure(m) for m in filtered]

View File

@@ -0,0 +1,2 @@
# tests/unit/services/memory/procedural/__init__.py
"""Unit tests for procedural memory."""

View File

@@ -0,0 +1,427 @@
# tests/unit/services/memory/procedural/test_matching.py
"""Unit tests for procedure matching."""
from datetime import UTC, datetime
from uuid import uuid4
import pytest
from app.services.memory.procedural.matching import (
MatchContext,
MatchResult,
ProcedureMatcher,
get_procedure_matcher,
)
from app.services.memory.types import Procedure
def create_test_procedure(
name: str = "deploy_api",
trigger_pattern: str = "deploy.*api",
success_count: int = 8,
failure_count: int = 2,
) -> Procedure:
"""Create a test procedure for testing."""
now = datetime.now(UTC)
return Procedure(
id=uuid4(),
project_id=None,
agent_type_id=None,
name=name,
trigger_pattern=trigger_pattern,
steps=[
{"order": 1, "action": "build"},
{"order": 2, "action": "test"},
{"order": 3, "action": "deploy"},
],
success_count=success_count,
failure_count=failure_count,
last_used=now,
embedding=None,
created_at=now,
updated_at=now,
)
class TestMatchResult:
"""Tests for MatchResult dataclass."""
def test_to_dict(self) -> None:
"""Test converting match result to dictionary."""
procedure = create_test_procedure()
result = MatchResult(
procedure=procedure,
score=0.85,
matched_terms=["deploy", "api"],
match_type="keyword",
)
data = result.to_dict()
assert "procedure_id" in data
assert "procedure_name" in data
assert data["score"] == 0.85
assert data["matched_terms"] == ["deploy", "api"]
assert data["match_type"] == "keyword"
assert data["success_rate"] == 0.8
class TestMatchContext:
"""Tests for MatchContext dataclass."""
def test_default_values(self) -> None:
"""Test default values."""
context = MatchContext(query="deploy api")
assert context.query == "deploy api"
assert context.task_type is None
assert context.project_id is None
assert context.max_results == 5
assert context.min_score == 0.3
assert context.require_success_rate is None
def test_with_all_values(self) -> None:
"""Test with all values set."""
project_id = uuid4()
context = MatchContext(
query="deploy api",
task_type="deployment",
project_id=project_id,
max_results=10,
min_score=0.5,
require_success_rate=0.7,
)
assert context.query == "deploy api"
assert context.task_type == "deployment"
assert context.project_id == project_id
assert context.max_results == 10
assert context.min_score == 0.5
assert context.require_success_rate == 0.7
class TestProcedureMatcher:
"""Tests for ProcedureMatcher class."""
@pytest.fixture
def matcher(self) -> ProcedureMatcher:
"""Create a procedure matcher."""
return ProcedureMatcher()
@pytest.fixture
def procedures(self) -> list[Procedure]:
"""Create test procedures."""
return [
create_test_procedure(
name="deploy_api",
trigger_pattern="deploy.*api",
success_count=9,
failure_count=1,
),
create_test_procedure(
name="deploy_frontend",
trigger_pattern="deploy.*frontend",
success_count=7,
failure_count=3,
),
create_test_procedure(
name="build_project",
trigger_pattern="build.*project",
success_count=8,
failure_count=2,
),
create_test_procedure(
name="run_tests",
trigger_pattern="test.*run",
success_count=5,
failure_count=5,
),
]
def test_match_exact_name(
self,
matcher: ProcedureMatcher,
procedures: list[Procedure],
) -> None:
"""Test matching with exact name."""
context = MatchContext(query="deploy_api")
results = matcher.match(procedures, context)
assert len(results) > 0
# First result should be deploy_api
assert results[0].procedure.name == "deploy_api"
def test_match_partial_terms(
self,
matcher: ProcedureMatcher,
procedures: list[Procedure],
) -> None:
"""Test matching with partial terms."""
context = MatchContext(query="deploy")
results = matcher.match(procedures, context)
assert len(results) >= 2
# Both deploy procedures should match
names = [r.procedure.name for r in results]
assert "deploy_api" in names
assert "deploy_frontend" in names
def test_match_with_task_type(
self,
matcher: ProcedureMatcher,
procedures: list[Procedure],
) -> None:
"""Test matching with task type."""
context = MatchContext(
query="build something",
task_type="build",
)
results = matcher.match(procedures, context)
assert len(results) > 0
assert results[0].procedure.name == "build_project"
def test_match_respects_min_score(
self,
matcher: ProcedureMatcher,
procedures: list[Procedure],
) -> None:
"""Test that matching respects minimum score."""
context = MatchContext(
query="completely unrelated query xyz",
min_score=0.5,
)
results = matcher.match(procedures, context)
# Should not match anything with high min_score
for result in results:
assert result.score >= 0.5
def test_match_respects_success_rate_requirement(
self,
matcher: ProcedureMatcher,
procedures: list[Procedure],
) -> None:
"""Test that matching respects success rate requirement."""
context = MatchContext(
query="deploy",
require_success_rate=0.7,
)
results = matcher.match(procedures, context)
for result in results:
assert result.procedure.success_rate >= 0.7
def test_match_respects_max_results(
self,
matcher: ProcedureMatcher,
procedures: list[Procedure],
) -> None:
"""Test that matching respects max results."""
context = MatchContext(
query="deploy",
max_results=1,
min_score=0.0,
)
results = matcher.match(procedures, context)
assert len(results) <= 1
def test_match_sorts_by_score(
self,
matcher: ProcedureMatcher,
procedures: list[Procedure],
) -> None:
"""Test that results are sorted by score."""
context = MatchContext(query="deploy", min_score=0.0)
results = matcher.match(procedures, context)
if len(results) > 1:
scores = [r.score for r in results]
assert scores == sorted(scores, reverse=True)
def test_match_empty_procedures(
self,
matcher: ProcedureMatcher,
) -> None:
"""Test matching with empty procedures list."""
context = MatchContext(query="deploy")
results = matcher.match([], context)
assert results == []
class TestProcedureMatcherRankByRelevance:
"""Tests for rank_by_relevance method."""
@pytest.fixture
def matcher(self) -> ProcedureMatcher:
"""Create a procedure matcher."""
return ProcedureMatcher()
def test_rank_by_relevance(
self,
matcher: ProcedureMatcher,
) -> None:
"""Test ranking by relevance."""
procedures = [
create_test_procedure(name="unrelated", trigger_pattern="something else"),
create_test_procedure(name="deploy_api", trigger_pattern="deploy.*api"),
create_test_procedure(
name="deploy_frontend", trigger_pattern="deploy.*frontend"
),
]
ranked = matcher.rank_by_relevance(procedures, "deploy")
# Deploy procedures should be ranked first
assert ranked[0].name in ["deploy_api", "deploy_frontend"]
def test_rank_by_relevance_empty(
self,
matcher: ProcedureMatcher,
) -> None:
"""Test ranking empty list."""
ranked = matcher.rank_by_relevance([], "deploy")
assert ranked == []
class TestProcedureMatcherSuggestProcedures:
"""Tests for suggest_procedures method."""
@pytest.fixture
def matcher(self) -> ProcedureMatcher:
"""Create a procedure matcher."""
return ProcedureMatcher()
@pytest.fixture
def procedures(self) -> list[Procedure]:
"""Create test procedures."""
return [
create_test_procedure(
name="deploy_api",
trigger_pattern="deploy api",
success_count=9,
failure_count=1,
),
create_test_procedure(
name="bad_deploy",
trigger_pattern="deploy bad",
success_count=2,
failure_count=8,
),
]
def test_suggest_procedures(
self,
matcher: ProcedureMatcher,
procedures: list[Procedure],
) -> None:
"""Test suggesting procedures."""
suggestions = matcher.suggest_procedures(
procedures,
"deploy",
min_success_rate=0.5,
)
assert len(suggestions) > 0
# Only high success rate should be suggested
for s in suggestions:
assert s.procedure.success_rate >= 0.5
def test_suggest_procedures_limits_results(
self,
matcher: ProcedureMatcher,
procedures: list[Procedure],
) -> None:
"""Test that suggestions are limited."""
suggestions = matcher.suggest_procedures(
procedures,
"deploy",
max_suggestions=1,
)
assert len(suggestions) <= 1
class TestGetProcedureMatcher:
"""Tests for singleton getter."""
def test_get_procedure_matcher_returns_instance(self) -> None:
"""Test that getter returns instance."""
matcher = get_procedure_matcher()
assert matcher is not None
assert isinstance(matcher, ProcedureMatcher)
def test_get_procedure_matcher_returns_same_instance(self) -> None:
"""Test that getter returns same instance (singleton)."""
matcher1 = get_procedure_matcher()
matcher2 = get_procedure_matcher()
assert matcher1 is matcher2
class TestProcedureMatcherExtractTerms:
"""Tests for term extraction."""
@pytest.fixture
def matcher(self) -> ProcedureMatcher:
"""Create a procedure matcher."""
return ProcedureMatcher()
def test_extract_terms_basic(
self,
matcher: ProcedureMatcher,
) -> None:
"""Test basic term extraction."""
terms = matcher._extract_terms("deploy the api")
assert "deploy" in terms
assert "the" in terms
assert "api" in terms
def test_extract_terms_removes_special_chars(
self,
matcher: ProcedureMatcher,
) -> None:
"""Test that special characters are removed."""
terms = matcher._extract_terms("deploy.api!now")
assert "deploy" in terms
assert "api" in terms
assert "now" in terms
assert "." not in terms
assert "!" not in terms
def test_extract_terms_filters_short(
self,
matcher: ProcedureMatcher,
) -> None:
"""Test that short terms are filtered."""
terms = matcher._extract_terms("a big api")
assert "a" not in terms
assert "big" in terms
assert "api" in terms
def test_extract_terms_lowercases(
self,
matcher: ProcedureMatcher,
) -> None:
"""Test that terms are lowercased."""
terms = matcher._extract_terms("Deploy API")
assert "deploy" in terms
assert "api" in terms
assert "Deploy" not in terms
assert "API" not in terms

View File

@@ -0,0 +1,569 @@
# tests/unit/services/memory/procedural/test_memory.py
"""Unit tests for ProceduralMemory class."""
from datetime import UTC, datetime
from unittest.mock import AsyncMock, MagicMock
from uuid import uuid4
import pytest
from app.services.memory.procedural.memory import ProceduralMemory
from app.services.memory.types import ProcedureCreate, Step
def create_mock_procedure_model(
name="deploy_api",
trigger_pattern="deploy.*api",
project_id=None,
agent_type_id=None,
success_count=5,
failure_count=1,
):
"""Create a mock procedure model for testing."""
mock = MagicMock()
mock.id = uuid4()
mock.project_id = project_id
mock.agent_type_id = agent_type_id
mock.name = name
mock.trigger_pattern = trigger_pattern
mock.steps = [
{"order": 1, "action": "build", "parameters": {}},
{"order": 2, "action": "test", "parameters": {}},
{"order": 3, "action": "deploy", "parameters": {}},
]
mock.success_count = success_count
mock.failure_count = failure_count
mock.last_used = datetime.now(UTC)
mock.embedding = None
mock.created_at = datetime.now(UTC)
mock.updated_at = datetime.now(UTC)
mock.success_rate = (
success_count / (success_count + failure_count)
if (success_count + failure_count) > 0
else 0.0
)
mock.total_uses = success_count + failure_count
return mock
class TestProceduralMemoryInit:
"""Tests for ProceduralMemory initialization."""
def test_init_creates_memory(self) -> None:
"""Test that init creates memory instance."""
mock_session = AsyncMock()
memory = ProceduralMemory(session=mock_session)
assert memory._session is mock_session
def test_init_with_embedding_generator(self) -> None:
"""Test init with embedding generator."""
mock_session = AsyncMock()
mock_embedding_gen = AsyncMock()
memory = ProceduralMemory(
session=mock_session, embedding_generator=mock_embedding_gen
)
assert memory._embedding_generator is mock_embedding_gen
@pytest.mark.asyncio
async def test_create_factory_method(self) -> None:
"""Test create factory method."""
mock_session = AsyncMock()
memory = await ProceduralMemory.create(session=mock_session)
assert memory is not None
assert memory._session is mock_session
class TestProceduralMemoryRecordProcedure:
"""Tests for procedure recording methods."""
@pytest.fixture
def mock_session(self) -> AsyncMock:
"""Create a mock database session."""
session = AsyncMock()
session.add = MagicMock()
session.flush = AsyncMock()
session.refresh = AsyncMock()
# Mock no existing procedure found
mock_result = MagicMock()
mock_result.scalar_one_or_none.return_value = None
session.execute.return_value = mock_result
return session
@pytest.fixture
def memory(self, mock_session: AsyncMock) -> ProceduralMemory:
"""Create a ProceduralMemory instance."""
return ProceduralMemory(session=mock_session)
@pytest.mark.asyncio
async def test_record_new_procedure(
self,
memory: ProceduralMemory,
mock_session: AsyncMock,
) -> None:
"""Test recording a new procedure."""
procedure_data = ProcedureCreate(
name="build_project",
trigger_pattern="build.*project",
steps=[
{"order": 1, "action": "npm install"},
{"order": 2, "action": "npm run build"},
],
project_id=uuid4(),
)
result = await memory.record_procedure(procedure_data)
assert result.name == "build_project"
assert result.trigger_pattern == "build.*project"
mock_session.add.assert_called_once()
mock_session.flush.assert_called_once()
@pytest.mark.asyncio
async def test_record_updates_existing(
self,
memory: ProceduralMemory,
mock_session: AsyncMock,
) -> None:
"""Test that recording duplicate procedure updates existing."""
# Mock existing procedure found
existing_mock = create_mock_procedure_model()
find_result = MagicMock()
find_result.scalar_one_or_none.return_value = existing_mock
# Mock update result
updated_mock = create_mock_procedure_model(success_count=6)
update_result = MagicMock()
update_result.scalar_one.return_value = updated_mock
mock_session.execute.side_effect = [find_result, update_result]
procedure_data = ProcedureCreate(
name="deploy_api",
trigger_pattern="deploy.*api",
steps=[{"order": 1, "action": "deploy"}],
)
_ = await memory.record_procedure(procedure_data)
# Should have called execute twice (find + update)
assert mock_session.execute.call_count == 2
class TestProceduralMemoryFindMatching:
"""Tests for procedure matching methods."""
@pytest.fixture
def mock_session(self) -> AsyncMock:
"""Create a mock database session."""
session = AsyncMock()
mock_result = MagicMock()
mock_result.scalars.return_value.all.return_value = []
session.execute.return_value = mock_result
return session
@pytest.fixture
def memory(self, mock_session: AsyncMock) -> ProceduralMemory:
"""Create a ProceduralMemory instance."""
return ProceduralMemory(session=mock_session)
@pytest.mark.asyncio
async def test_find_matching(
self,
memory: ProceduralMemory,
) -> None:
"""Test finding matching procedures."""
results = await memory.find_matching("deploy api")
assert isinstance(results, list)
@pytest.mark.asyncio
async def test_find_matching_with_project_filter(
self,
memory: ProceduralMemory,
) -> None:
"""Test finding matching procedures with project filter."""
project_id = uuid4()
results = await memory.find_matching(
"deploy api",
project_id=project_id,
)
assert isinstance(results, list)
@pytest.mark.asyncio
async def test_find_matching_returns_results(
self,
memory: ProceduralMemory,
mock_session: AsyncMock,
) -> None:
"""Test that find_matching returns results."""
procedures = [
create_mock_procedure_model(name="deploy_api"),
create_mock_procedure_model(name="deploy_frontend"),
]
mock_result = MagicMock()
mock_result.scalars.return_value.all.return_value = procedures
mock_session.execute.return_value = mock_result
results = await memory.find_matching("deploy")
assert len(results) == 2
class TestProceduralMemoryGetBestProcedure:
"""Tests for get_best_procedure method."""
@pytest.fixture
def mock_session(self) -> AsyncMock:
"""Create a mock database session."""
session = AsyncMock()
return session
@pytest.fixture
def memory(self, mock_session: AsyncMock) -> ProceduralMemory:
"""Create a ProceduralMemory instance."""
return ProceduralMemory(session=mock_session)
@pytest.mark.asyncio
async def test_get_best_procedure_none(
self,
memory: ProceduralMemory,
mock_session: AsyncMock,
) -> None:
"""Test get_best_procedure returns None when no match."""
mock_result = MagicMock()
mock_result.scalars.return_value.all.return_value = []
mock_session.execute.return_value = mock_result
result = await memory.get_best_procedure("unknown_task")
assert result is None
@pytest.mark.asyncio
async def test_get_best_procedure_returns_highest_success_rate(
self,
memory: ProceduralMemory,
mock_session: AsyncMock,
) -> None:
"""Test get_best_procedure returns highest success rate."""
low_success = create_mock_procedure_model(
name="deploy_v1", success_count=3, failure_count=7
)
high_success = create_mock_procedure_model(
name="deploy_v2", success_count=9, failure_count=1
)
mock_result = MagicMock()
mock_result.scalars.return_value.all.return_value = [high_success, low_success]
mock_session.execute.return_value = mock_result
result = await memory.get_best_procedure("deploy")
assert result is not None
assert result.name == "deploy_v2"
class TestProceduralMemoryRecordOutcome:
"""Tests for outcome recording."""
@pytest.fixture
def mock_session(self) -> AsyncMock:
"""Create a mock database session."""
session = AsyncMock()
return session
@pytest.fixture
def memory(self, mock_session: AsyncMock) -> ProceduralMemory:
"""Create a ProceduralMemory instance."""
return ProceduralMemory(session=mock_session)
@pytest.mark.asyncio
async def test_record_outcome_success(
self,
memory: ProceduralMemory,
mock_session: AsyncMock,
) -> None:
"""Test recording a successful outcome."""
existing_mock = create_mock_procedure_model()
# First query: find procedure
find_result = MagicMock()
find_result.scalar_one_or_none.return_value = existing_mock
# Second query: update
updated_mock = create_mock_procedure_model(success_count=6)
update_result = MagicMock()
update_result.scalar_one.return_value = updated_mock
mock_session.execute.side_effect = [find_result, update_result]
result = await memory.record_outcome(existing_mock.id, success=True)
assert result.success_count == 6
@pytest.mark.asyncio
async def test_record_outcome_failure(
self,
memory: ProceduralMemory,
mock_session: AsyncMock,
) -> None:
"""Test recording a failure outcome."""
existing_mock = create_mock_procedure_model()
# First query: find procedure
find_result = MagicMock()
find_result.scalar_one_or_none.return_value = existing_mock
# Second query: update
updated_mock = create_mock_procedure_model(failure_count=2)
update_result = MagicMock()
update_result.scalar_one.return_value = updated_mock
mock_session.execute.side_effect = [find_result, update_result]
result = await memory.record_outcome(existing_mock.id, success=False)
assert result.failure_count == 2
@pytest.mark.asyncio
async def test_record_outcome_not_found(
self,
memory: ProceduralMemory,
mock_session: AsyncMock,
) -> None:
"""Test recording outcome for non-existent procedure raises error."""
mock_result = MagicMock()
mock_result.scalar_one_or_none.return_value = None
mock_session.execute.return_value = mock_result
with pytest.raises(ValueError, match="Procedure not found"):
await memory.record_outcome(uuid4(), success=True)
class TestProceduralMemoryUpdateSteps:
"""Tests for step updates."""
@pytest.fixture
def mock_session(self) -> AsyncMock:
"""Create a mock database session."""
session = AsyncMock()
return session
@pytest.fixture
def memory(self, mock_session: AsyncMock) -> ProceduralMemory:
"""Create a ProceduralMemory instance."""
return ProceduralMemory(session=mock_session)
@pytest.mark.asyncio
async def test_update_steps(
self,
memory: ProceduralMemory,
mock_session: AsyncMock,
) -> None:
"""Test updating steps."""
existing_mock = create_mock_procedure_model()
# First query: find procedure
find_result = MagicMock()
find_result.scalar_one_or_none.return_value = existing_mock
# Second query: update
updated_mock = create_mock_procedure_model()
updated_mock.steps = [
{"order": 1, "action": "new_step_1", "parameters": {}},
{"order": 2, "action": "new_step_2", "parameters": {}},
]
update_result = MagicMock()
update_result.scalar_one.return_value = updated_mock
mock_session.execute.side_effect = [find_result, update_result]
new_steps = [
Step(order=1, action="new_step_1"),
Step(order=2, action="new_step_2"),
]
result = await memory.update_steps(existing_mock.id, new_steps)
assert len(result.steps) == 2
@pytest.mark.asyncio
async def test_update_steps_not_found(
self,
memory: ProceduralMemory,
mock_session: AsyncMock,
) -> None:
"""Test updating steps for non-existent procedure raises error."""
mock_result = MagicMock()
mock_result.scalar_one_or_none.return_value = None
mock_session.execute.return_value = mock_result
with pytest.raises(ValueError, match="Procedure not found"):
await memory.update_steps(uuid4(), [Step(order=1, action="test")])
class TestProceduralMemoryStats:
"""Tests for statistics methods."""
@pytest.fixture
def mock_session(self) -> AsyncMock:
"""Create a mock database session."""
session = AsyncMock()
return session
@pytest.fixture
def memory(self, mock_session: AsyncMock) -> ProceduralMemory:
"""Create a ProceduralMemory instance."""
return ProceduralMemory(session=mock_session)
@pytest.mark.asyncio
async def test_get_stats_empty(
self,
memory: ProceduralMemory,
mock_session: AsyncMock,
) -> None:
"""Test getting stats for empty project."""
mock_result = MagicMock()
mock_result.scalars.return_value.all.return_value = []
mock_session.execute.return_value = mock_result
stats = await memory.get_stats(uuid4())
assert stats["total_procedures"] == 0
assert stats["avg_success_rate"] == 0.0
@pytest.mark.asyncio
async def test_get_stats_with_data(
self,
memory: ProceduralMemory,
mock_session: AsyncMock,
) -> None:
"""Test getting stats with data."""
procedures = [
create_mock_procedure_model(success_count=8, failure_count=2),
create_mock_procedure_model(success_count=6, failure_count=4),
]
mock_result = MagicMock()
mock_result.scalars.return_value.all.return_value = procedures
mock_session.execute.return_value = mock_result
stats = await memory.get_stats(uuid4())
assert stats["total_procedures"] == 2
assert stats["total_uses"] == 20 # (8+2) + (6+4)
@pytest.mark.asyncio
async def test_count(
self,
memory: ProceduralMemory,
mock_session: AsyncMock,
) -> None:
"""Test counting procedures."""
procedures = [create_mock_procedure_model() for _ in range(5)]
mock_result = MagicMock()
mock_result.scalars.return_value.all.return_value = procedures
mock_session.execute.return_value = mock_result
count = await memory.count(uuid4())
assert count == 5
class TestProceduralMemoryDelete:
"""Tests for delete operations."""
@pytest.fixture
def mock_session(self) -> AsyncMock:
"""Create a mock database session."""
session = AsyncMock()
return session
@pytest.fixture
def memory(self, mock_session: AsyncMock) -> ProceduralMemory:
"""Create a ProceduralMemory instance."""
return ProceduralMemory(session=mock_session)
@pytest.mark.asyncio
async def test_delete(
self,
memory: ProceduralMemory,
mock_session: AsyncMock,
) -> None:
"""Test deleting a procedure."""
existing_mock = create_mock_procedure_model()
mock_result = MagicMock()
mock_result.scalar_one_or_none.return_value = existing_mock
mock_session.execute.return_value = mock_result
mock_session.delete = AsyncMock()
result = await memory.delete(existing_mock.id)
assert result is True
mock_session.delete.assert_called_once()
@pytest.mark.asyncio
async def test_delete_not_found(
self,
memory: ProceduralMemory,
mock_session: AsyncMock,
) -> None:
"""Test deleting non-existent procedure."""
mock_result = MagicMock()
mock_result.scalar_one_or_none.return_value = None
mock_session.execute.return_value = mock_result
result = await memory.delete(uuid4())
assert result is False
class TestProceduralMemoryGetById:
"""Tests for get_by_id method."""
@pytest.fixture
def mock_session(self) -> AsyncMock:
"""Create a mock database session."""
session = AsyncMock()
return session
@pytest.fixture
def memory(self, mock_session: AsyncMock) -> ProceduralMemory:
"""Create a ProceduralMemory instance."""
return ProceduralMemory(session=mock_session)
@pytest.mark.asyncio
async def test_get_by_id(
self,
memory: ProceduralMemory,
mock_session: AsyncMock,
) -> None:
"""Test getting procedure by ID."""
existing_mock = create_mock_procedure_model()
mock_result = MagicMock()
mock_result.scalar_one_or_none.return_value = existing_mock
mock_session.execute.return_value = mock_result
result = await memory.get_by_id(existing_mock.id)
assert result is not None
assert result.name == "deploy_api"
@pytest.mark.asyncio
async def test_get_by_id_not_found(
self,
memory: ProceduralMemory,
mock_session: AsyncMock,
) -> None:
"""Test get_by_id returns None when not found."""
mock_result = MagicMock()
mock_result.scalar_one_or_none.return_value = None
mock_session.execute.return_value = mock_result
result = await memory.get_by_id(uuid4())
assert result is None