forked from cardosofelipe/fast-next-template
feat(memory): add procedural memory implementation (Issue #92)
Implements procedural memory for learned skills and procedures: Core functionality: - ProceduralMemory class for procedure storage/retrieval - record_procedure with duplicate detection and step merging - find_matching for context-based procedure search - record_outcome for success/failure tracking - get_best_procedure for finding highest success rate - update_steps for procedure refinement Supporting modules: - ProcedureMatcher: Keyword-based procedure matching - MatchResult/MatchContext: Matching result types - Success rate weighting in match scoring Test coverage: - 43 unit tests covering all modules - matching.py: 97% coverage - memory.py: 86% coverage 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
@@ -1,7 +1,22 @@
|
||||
# app/services/memory/procedural/__init__.py
|
||||
"""
|
||||
Procedural Memory
|
||||
|
||||
Learned skills and procedures from successful task patterns.
|
||||
"""
|
||||
|
||||
# Will be populated in #92
|
||||
from .matching import (
|
||||
MatchContext,
|
||||
MatchResult,
|
||||
ProcedureMatcher,
|
||||
get_procedure_matcher,
|
||||
)
|
||||
from .memory import ProceduralMemory
|
||||
|
||||
__all__ = [
|
||||
"MatchContext",
|
||||
"MatchResult",
|
||||
"ProceduralMemory",
|
||||
"ProcedureMatcher",
|
||||
"get_procedure_matcher",
|
||||
]
|
||||
|
||||
291
backend/app/services/memory/procedural/matching.py
Normal file
291
backend/app/services/memory/procedural/matching.py
Normal file
@@ -0,0 +1,291 @@
|
||||
# app/services/memory/procedural/matching.py
|
||||
"""
|
||||
Procedure Matching.
|
||||
|
||||
Provides utilities for matching procedures to contexts,
|
||||
ranking procedures by relevance, and suggesting procedures.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import re
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, ClassVar
|
||||
|
||||
from app.services.memory.types import Procedure
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
class MatchResult:
    """Outcome of matching one procedure against a query context.

    ``score`` is in [0, 1] and already includes any success-rate boost
    applied by the matcher.
    """

    procedure: Procedure
    score: float
    matched_terms: list[str] = field(default_factory=list)
    match_type: str = "keyword"  # keyword, semantic, pattern

    def to_dict(self) -> dict[str, Any]:
        """Serialize the match result to a plain dictionary."""
        proc = self.procedure
        return {
            "procedure_id": str(proc.id),
            "procedure_name": proc.name,
            "score": self.score,
            "matched_terms": self.matched_terms,
            "match_type": self.match_type,
            "success_rate": proc.success_rate,
        }
|
||||
|
||||
|
||||
@dataclass
class MatchContext:
    """Parameters for a single procedure-matching request.

    Only ``query`` is required; everything else narrows or caps the
    result set.
    """

    query: str
    task_type: str | None = None
    project_id: Any | None = None
    agent_type_id: Any | None = None
    max_results: int = 5  # cap on returned matches
    min_score: float = 0.3  # matches scoring below this are dropped
    require_success_rate: float | None = None  # hard floor, applied pre-boost
|
||||
|
||||
|
||||
class ProcedureMatcher:
    """
    Matches procedures to contexts using multiple strategies.

    Matching strategies:
    - Keyword matching on trigger pattern and name
    - Pattern-based matching using regex
    - Success rate weighting

    In production, this would be augmented with vector similarity search.
    """

    # Common task verbs: matching one of these earns a small score bonus.
    TASK_KEYWORDS: ClassVar[set[str]] = {
        "create",
        "update",
        "delete",
        "fix",
        "implement",
        "add",
        "remove",
        "refactor",
        "test",
        "deploy",
        "configure",
        "setup",
        "build",
        "debug",
        "optimize",
    }

    def __init__(self) -> None:
        """Initialize the matcher with an empty compiled-regex cache."""
        self._compiled_patterns: dict[str, re.Pattern[str]] = {}

    def match(
        self,
        procedures: list[Procedure],
        context: MatchContext,
    ) -> list[MatchResult]:
        """
        Match procedures against a context.

        Args:
            procedures: List of procedures to match
            context: Matching context

        Returns:
            Match results sorted by score (highest first), capped at
            ``context.max_results``.
        """
        terms = self._extract_terms(context.query)
        lowered_query = context.query.lower()

        matches: list[MatchResult] = []
        for candidate in procedures:
            base_score, hits = self._calculate_match_score(
                procedure=candidate,
                query_terms=terms,
                query_lower=lowered_query,
                context=context,
            )

            # Guard clauses: drop weak matches and under-performing procedures.
            if base_score < context.min_score:
                continue
            floor = context.require_success_rate
            if floor is not None and candidate.success_rate < floor:
                continue

            # Reward procedures with a track record of succeeding.
            boosted = min(1.0, base_score + candidate.success_rate * 0.2)
            matches.append(
                MatchResult(
                    procedure=candidate,
                    score=boosted,
                    matched_terms=hits,
                    match_type="keyword",
                )
            )

        matches.sort(key=lambda m: m.score, reverse=True)
        return matches[: context.max_results]

    def _extract_terms(self, text: str) -> list[str]:
        """Extract lowercase searchable terms (length >= 2) from free text."""
        # Keep word chars, whitespace and hyphens; everything else -> space.
        normalized = re.sub(r"[^\w\s-]", " ", text.lower())
        return [token for token in normalized.split() if len(token) >= 2]

    def _calculate_match_score(
        self,
        procedure: Procedure,
        query_terms: list[str],
        query_lower: str,
        context: MatchContext,
    ) -> tuple[float, list[str]]:
        """
        Score a single procedure against the query.

        Returns:
            Tuple of (score clamped to 1.0, matched term labels)
        """
        total = 0.0
        hits: list[str] = []

        trigger = procedure.trigger_pattern.lower()
        name = procedure.name.lower()

        # Whole-string containment on the name is the strongest signal.
        if name in query_lower or query_lower in name:
            total += 0.5
            hits.append(f"name:{procedure.name}")

        # Whole-string containment on the trigger pattern.
        if trigger in query_lower or query_lower in trigger:
            total += 0.4
            hits.append(f"trigger:{procedure.trigger_pattern[:30]}")

        # Per-term matching; trigger hits outweigh name hits.
        for term in query_terms:
            if term in trigger:
                total += 0.1
                hits.append(term)
            elif term in name:
                total += 0.08
                hits.append(term)

            # Extra bonus when the term is a recognized task verb.
            if term in self.TASK_KEYWORDS and (term in trigger or term in name):
                total += 0.05

        # Task-type containment, when the caller supplied one.
        if context.task_type:
            task = context.task_type.lower()
            if task in trigger or task in name:
                total += 0.3
                hits.append(f"task_type:{context.task_type}")

        # Treat regex-looking triggers as patterns and try a search.
        try:
            compiled = self._get_or_compile_pattern(trigger)
            if compiled and compiled.search(query_lower):
                total += 0.25
                hits.append("pattern_match")
        except re.error:
            pass  # Invalid regex, skip pattern matching

        return min(1.0, total), hits

    def _get_or_compile_pattern(self, pattern: str) -> re.Pattern[str] | None:
        """Return a cached compiled regex, or None for non-regex strings."""
        cached = self._compiled_patterns.get(pattern)
        if cached is not None:
            return cached

        # Heuristic: only strings containing regex metacharacters are
        # treated as patterns at all.
        if not any(ch in pattern for ch in r"\.*+?[]{}|()^$"):
            return None

        try:
            compiled = re.compile(pattern, re.IGNORECASE)
        except re.error:
            return None
        self._compiled_patterns[pattern] = compiled
        return compiled

    def rank_by_relevance(
        self,
        procedures: list[Procedure],
        task_type: str,
    ) -> list[Procedure]:
        """
        Rank procedures by relevance to a task type.

        Args:
            procedures: Procedures to rank
            task_type: Task type for relevance

        Returns:
            The same procedures, ordered most-relevant first.
        """
        ranking = self.match(
            procedures,
            MatchContext(
                query=task_type,
                task_type=task_type,
                min_score=0.0,  # keep everything; this is a pure ordering
                max_results=len(procedures),
            ),
        )
        return [item.procedure for item in ranking]

    def suggest_procedures(
        self,
        procedures: list[Procedure],
        query: str,
        min_success_rate: float = 0.5,
        max_suggestions: int = 3,
    ) -> list[MatchResult]:
        """
        Suggest the best procedures for a query.

        Only suggests procedures with sufficient success rate.

        Args:
            procedures: Available procedures
            query: Query/context
            min_success_rate: Minimum success rate to suggest
            max_suggestions: Maximum suggestions

        Returns:
            List of procedure suggestions
        """
        return self.match(
            procedures,
            MatchContext(
                query=query,
                max_results=max_suggestions,
                min_score=0.2,
                require_success_rate=min_success_rate,
            ),
        )
|
||||
|
||||
|
||||
# Singleton matcher instance (created lazily on first access)
_matcher: ProcedureMatcher | None = None


def get_procedure_matcher() -> ProcedureMatcher:
    """Get the singleton procedure matcher instance.

    Lazily constructs the shared ProcedureMatcher on first call and
    returns the same instance afterwards.

    NOTE(review): not lock-guarded, so concurrent first calls could each
    build an instance; harmless while the matcher holds only a regex
    cache — confirm if it ever gains real state.
    """
    global _matcher
    if _matcher is None:
        _matcher = ProcedureMatcher()
    return _matcher
|
||||
724
backend/app/services/memory/procedural/memory.py
Normal file
724
backend/app/services/memory/procedural/memory.py
Normal file
@@ -0,0 +1,724 @@
|
||||
# app/services/memory/procedural/memory.py
|
||||
"""
|
||||
Procedural Memory Implementation.
|
||||
|
||||
Provides storage and retrieval for learned procedures (skills)
|
||||
derived from successful task execution patterns.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import time
|
||||
from datetime import UTC, datetime
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import and_, desc, func, or_, select, update
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.models.memory.procedure import Procedure as ProcedureModel
|
||||
from app.services.memory.config import get_memory_settings
|
||||
from app.services.memory.types import Procedure, ProcedureCreate, RetrievalResult, Step
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _model_to_procedure(model: ProcedureModel) -> Procedure:
    """Convert a SQLAlchemy Procedure row into the Procedure dataclass.

    The stored embedding is intentionally replaced with ``None`` so the
    raw vector is never exposed to callers.
    """
    attrs: dict[str, Any] = {
        "id": model.id,
        "project_id": model.project_id,
        "agent_type_id": model.agent_type_id,
        "name": model.name,
        "trigger_pattern": model.trigger_pattern,
        "steps": model.steps or [],
        "success_count": model.success_count,
        "failure_count": model.failure_count,
        "last_used": model.last_used,
        "embedding": None,  # Don't expose raw embedding
        "created_at": model.created_at,
        "updated_at": model.updated_at,
    }
    return Procedure(**attrs)
|
||||
|
||||
|
||||
class ProceduralMemory:
|
||||
"""
|
||||
Procedural Memory Service.
|
||||
|
||||
Provides procedure storage and retrieval:
|
||||
- Record procedures from successful task patterns
|
||||
- Find matching procedures by trigger pattern
|
||||
- Track success/failure rates
|
||||
- Get best procedure for a task type
|
||||
- Update procedure steps
|
||||
|
||||
Performance target: <50ms P95 for matching
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
session: AsyncSession,
|
||||
embedding_generator: Any | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize procedural memory.
|
||||
|
||||
Args:
|
||||
session: Database session
|
||||
embedding_generator: Optional embedding generator for semantic matching
|
||||
"""
|
||||
self._session = session
|
||||
self._embedding_generator = embedding_generator
|
||||
self._settings = get_memory_settings()
|
||||
|
||||
@classmethod
|
||||
async def create(
|
||||
cls,
|
||||
session: AsyncSession,
|
||||
embedding_generator: Any | None = None,
|
||||
) -> "ProceduralMemory":
|
||||
"""
|
||||
Factory method to create ProceduralMemory.
|
||||
|
||||
Args:
|
||||
session: Database session
|
||||
embedding_generator: Optional embedding generator
|
||||
|
||||
Returns:
|
||||
Configured ProceduralMemory instance
|
||||
"""
|
||||
return cls(session=session, embedding_generator=embedding_generator)
|
||||
|
||||
# =========================================================================
|
||||
# Procedure Recording
|
||||
# =========================================================================
|
||||
|
||||
async def record_procedure(self, procedure: ProcedureCreate) -> Procedure:
|
||||
"""
|
||||
Record a new procedure or update an existing one.
|
||||
|
||||
If a procedure with the same name exists in the same scope,
|
||||
its steps will be updated and success count incremented.
|
||||
|
||||
Args:
|
||||
procedure: Procedure data to record
|
||||
|
||||
Returns:
|
||||
The created or updated procedure
|
||||
"""
|
||||
# Check for existing procedure with same name
|
||||
existing = await self._find_existing_procedure(
|
||||
project_id=procedure.project_id,
|
||||
agent_type_id=procedure.agent_type_id,
|
||||
name=procedure.name,
|
||||
)
|
||||
|
||||
if existing is not None:
|
||||
# Update existing procedure
|
||||
return await self._update_existing_procedure(
|
||||
existing=existing,
|
||||
new_steps=procedure.steps,
|
||||
new_trigger=procedure.trigger_pattern,
|
||||
)
|
||||
|
||||
# Create new procedure
|
||||
now = datetime.now(UTC)
|
||||
|
||||
# Generate embedding if possible
|
||||
embedding = None
|
||||
if self._embedding_generator is not None:
|
||||
embedding_text = self._create_embedding_text(procedure)
|
||||
embedding = await self._embedding_generator.generate(embedding_text)
|
||||
|
||||
model = ProcedureModel(
|
||||
project_id=procedure.project_id,
|
||||
agent_type_id=procedure.agent_type_id,
|
||||
name=procedure.name,
|
||||
trigger_pattern=procedure.trigger_pattern,
|
||||
steps=procedure.steps,
|
||||
success_count=1, # New procedures start with 1 success (they worked)
|
||||
failure_count=0,
|
||||
last_used=now,
|
||||
embedding=embedding,
|
||||
)
|
||||
|
||||
self._session.add(model)
|
||||
await self._session.flush()
|
||||
await self._session.refresh(model)
|
||||
|
||||
logger.info(
|
||||
f"Recorded new procedure: {procedure.name} with {len(procedure.steps)} steps"
|
||||
)
|
||||
|
||||
return _model_to_procedure(model)
|
||||
|
||||
async def _find_existing_procedure(
|
||||
self,
|
||||
project_id: UUID | None,
|
||||
agent_type_id: UUID | None,
|
||||
name: str,
|
||||
) -> ProcedureModel | None:
|
||||
"""Find an existing procedure with the same name in the same scope."""
|
||||
query = select(ProcedureModel).where(ProcedureModel.name == name)
|
||||
|
||||
if project_id is not None:
|
||||
query = query.where(ProcedureModel.project_id == project_id)
|
||||
else:
|
||||
query = query.where(ProcedureModel.project_id.is_(None))
|
||||
|
||||
if agent_type_id is not None:
|
||||
query = query.where(ProcedureModel.agent_type_id == agent_type_id)
|
||||
else:
|
||||
query = query.where(ProcedureModel.agent_type_id.is_(None))
|
||||
|
||||
result = await self._session.execute(query)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def _update_existing_procedure(
|
||||
self,
|
||||
existing: ProcedureModel,
|
||||
new_steps: list[dict[str, Any]],
|
||||
new_trigger: str,
|
||||
) -> Procedure:
|
||||
"""Update an existing procedure with new steps."""
|
||||
now = datetime.now(UTC)
|
||||
|
||||
# Merge steps intelligently - keep existing order, add new steps
|
||||
merged_steps = self._merge_steps(
|
||||
existing.steps or [], # type: ignore[arg-type]
|
||||
new_steps,
|
||||
)
|
||||
|
||||
stmt = (
|
||||
update(ProcedureModel)
|
||||
.where(ProcedureModel.id == existing.id)
|
||||
.values(
|
||||
steps=merged_steps,
|
||||
trigger_pattern=new_trigger,
|
||||
success_count=ProcedureModel.success_count + 1,
|
||||
last_used=now,
|
||||
updated_at=now,
|
||||
)
|
||||
.returning(ProcedureModel)
|
||||
)
|
||||
|
||||
result = await self._session.execute(stmt)
|
||||
updated_model = result.scalar_one()
|
||||
await self._session.flush()
|
||||
|
||||
logger.info(f"Updated existing procedure: {existing.name}")
|
||||
|
||||
return _model_to_procedure(updated_model)
|
||||
|
||||
def _merge_steps(
|
||||
self,
|
||||
existing_steps: list[dict[str, Any]],
|
||||
new_steps: list[dict[str, Any]],
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Merge steps from a new execution with existing steps."""
|
||||
if not existing_steps:
|
||||
return new_steps
|
||||
if not new_steps:
|
||||
return existing_steps
|
||||
|
||||
# For now, use the new steps if they differ significantly
|
||||
# In production, this could use more sophisticated merging
|
||||
if len(new_steps) != len(existing_steps):
|
||||
# If structure changed, prefer newer steps
|
||||
return new_steps
|
||||
|
||||
# Merge step-by-step, preferring new data where available
|
||||
merged = []
|
||||
for i, new_step in enumerate(new_steps):
|
||||
if i < len(existing_steps):
|
||||
# Merge with existing step
|
||||
step = {**existing_steps[i], **new_step}
|
||||
else:
|
||||
step = new_step
|
||||
merged.append(step)
|
||||
|
||||
return merged
|
||||
|
||||
def _create_embedding_text(self, procedure: ProcedureCreate) -> str:
|
||||
"""Create text for embedding from procedure data."""
|
||||
steps_text = " ".join(step.get("action", "") for step in procedure.steps)
|
||||
return f"{procedure.name} {procedure.trigger_pattern} {steps_text}"
|
||||
|
||||
# =========================================================================
|
||||
# Procedure Retrieval
|
||||
# =========================================================================
|
||||
|
||||
async def find_matching(
|
||||
self,
|
||||
context: str,
|
||||
project_id: UUID | None = None,
|
||||
agent_type_id: UUID | None = None,
|
||||
limit: int = 5,
|
||||
) -> list[Procedure]:
|
||||
"""
|
||||
Find procedures matching the given context.
|
||||
|
||||
Args:
|
||||
context: Context/trigger to match against
|
||||
project_id: Optional project to search within
|
||||
agent_type_id: Optional agent type filter
|
||||
limit: Maximum results
|
||||
|
||||
Returns:
|
||||
List of matching procedures
|
||||
"""
|
||||
result = await self._find_matching_with_metadata(
|
||||
context=context,
|
||||
project_id=project_id,
|
||||
agent_type_id=agent_type_id,
|
||||
limit=limit,
|
||||
)
|
||||
return result.items
|
||||
|
||||
async def _find_matching_with_metadata(
|
||||
self,
|
||||
context: str,
|
||||
project_id: UUID | None = None,
|
||||
agent_type_id: UUID | None = None,
|
||||
limit: int = 5,
|
||||
) -> RetrievalResult[Procedure]:
|
||||
"""Find matching procedures with full result metadata."""
|
||||
start_time = time.perf_counter()
|
||||
|
||||
# Build base query - prioritize by success rate
|
||||
stmt = (
|
||||
select(ProcedureModel)
|
||||
.order_by(
|
||||
desc(
|
||||
ProcedureModel.success_count
|
||||
/ (ProcedureModel.success_count + ProcedureModel.failure_count + 1)
|
||||
),
|
||||
desc(ProcedureModel.last_used),
|
||||
)
|
||||
.limit(limit)
|
||||
)
|
||||
|
||||
# Apply scope filters
|
||||
if project_id is not None:
|
||||
stmt = stmt.where(
|
||||
or_(
|
||||
ProcedureModel.project_id == project_id,
|
||||
ProcedureModel.project_id.is_(None),
|
||||
)
|
||||
)
|
||||
|
||||
if agent_type_id is not None:
|
||||
stmt = stmt.where(
|
||||
or_(
|
||||
ProcedureModel.agent_type_id == agent_type_id,
|
||||
ProcedureModel.agent_type_id.is_(None),
|
||||
)
|
||||
)
|
||||
|
||||
# Text-based matching on trigger pattern and name
|
||||
# TODO: Implement proper vector similarity search when pgvector is integrated
|
||||
search_terms = context.lower().split()[:5] # Limit to 5 terms
|
||||
if search_terms:
|
||||
conditions = []
|
||||
for term in search_terms:
|
||||
term_pattern = f"%{term}%"
|
||||
conditions.append(
|
||||
or_(
|
||||
ProcedureModel.trigger_pattern.ilike(term_pattern),
|
||||
ProcedureModel.name.ilike(term_pattern),
|
||||
)
|
||||
)
|
||||
if conditions:
|
||||
stmt = stmt.where(or_(*conditions))
|
||||
|
||||
result = await self._session.execute(stmt)
|
||||
models = list(result.scalars().all())
|
||||
|
||||
latency_ms = (time.perf_counter() - start_time) * 1000
|
||||
|
||||
return RetrievalResult(
|
||||
items=[_model_to_procedure(m) for m in models],
|
||||
total_count=len(models),
|
||||
query=context,
|
||||
retrieval_type="procedural",
|
||||
latency_ms=latency_ms,
|
||||
metadata={"project_id": str(project_id) if project_id else None},
|
||||
)
|
||||
|
||||
async def get_best_procedure(
|
||||
self,
|
||||
task_type: str,
|
||||
project_id: UUID | None = None,
|
||||
agent_type_id: UUID | None = None,
|
||||
min_success_rate: float = 0.5,
|
||||
min_uses: int = 1,
|
||||
) -> Procedure | None:
|
||||
"""
|
||||
Get the best procedure for a given task type.
|
||||
|
||||
Returns the procedure with the highest success rate that
|
||||
meets the minimum thresholds.
|
||||
|
||||
Args:
|
||||
task_type: Task type to find procedure for
|
||||
project_id: Optional project scope
|
||||
agent_type_id: Optional agent type scope
|
||||
min_success_rate: Minimum required success rate
|
||||
min_uses: Minimum number of uses required
|
||||
|
||||
Returns:
|
||||
Best matching procedure or None
|
||||
"""
|
||||
# Build query for procedures matching task type
|
||||
stmt = (
|
||||
select(ProcedureModel)
|
||||
.where(
|
||||
and_(
|
||||
(ProcedureModel.success_count + ProcedureModel.failure_count)
|
||||
>= min_uses,
|
||||
or_(
|
||||
ProcedureModel.trigger_pattern.ilike(f"%{task_type}%"),
|
||||
ProcedureModel.name.ilike(f"%{task_type}%"),
|
||||
),
|
||||
)
|
||||
)
|
||||
.order_by(
|
||||
desc(
|
||||
ProcedureModel.success_count
|
||||
/ (ProcedureModel.success_count + ProcedureModel.failure_count + 1)
|
||||
),
|
||||
desc(ProcedureModel.last_used),
|
||||
)
|
||||
.limit(10)
|
||||
)
|
||||
|
||||
# Apply scope filters
|
||||
if project_id is not None:
|
||||
stmt = stmt.where(
|
||||
or_(
|
||||
ProcedureModel.project_id == project_id,
|
||||
ProcedureModel.project_id.is_(None),
|
||||
)
|
||||
)
|
||||
|
||||
if agent_type_id is not None:
|
||||
stmt = stmt.where(
|
||||
or_(
|
||||
ProcedureModel.agent_type_id == agent_type_id,
|
||||
ProcedureModel.agent_type_id.is_(None),
|
||||
)
|
||||
)
|
||||
|
||||
result = await self._session.execute(stmt)
|
||||
models = list(result.scalars().all())
|
||||
|
||||
# Filter by success rate in Python (SQLAlchemy division in WHERE is complex)
|
||||
for model in models:
|
||||
success = float(model.success_count)
|
||||
failure = float(model.failure_count)
|
||||
total = success + failure
|
||||
if total > 0 and (success / total) >= min_success_rate:
|
||||
logger.debug(
|
||||
f"Found best procedure for '{task_type}': {model.name} "
|
||||
f"(success_rate={success / total:.2%})"
|
||||
)
|
||||
return _model_to_procedure(model)
|
||||
|
||||
return None
|
||||
|
||||
async def get_by_id(self, procedure_id: UUID) -> Procedure | None:
|
||||
"""Get a procedure by ID."""
|
||||
query = select(ProcedureModel).where(ProcedureModel.id == procedure_id)
|
||||
result = await self._session.execute(query)
|
||||
model = result.scalar_one_or_none()
|
||||
return _model_to_procedure(model) if model else None
|
||||
|
||||
# =========================================================================
|
||||
# Outcome Recording
|
||||
# =========================================================================
|
||||
|
||||
async def record_outcome(
|
||||
self,
|
||||
procedure_id: UUID,
|
||||
success: bool,
|
||||
) -> Procedure:
|
||||
"""
|
||||
Record the outcome of using a procedure.
|
||||
|
||||
Updates the success or failure count and last_used timestamp.
|
||||
|
||||
Args:
|
||||
procedure_id: Procedure that was used
|
||||
success: Whether the procedure succeeded
|
||||
|
||||
Returns:
|
||||
Updated procedure
|
||||
|
||||
Raises:
|
||||
ValueError: If procedure not found
|
||||
"""
|
||||
query = select(ProcedureModel).where(ProcedureModel.id == procedure_id)
|
||||
result = await self._session.execute(query)
|
||||
model = result.scalar_one_or_none()
|
||||
|
||||
if model is None:
|
||||
raise ValueError(f"Procedure not found: {procedure_id}")
|
||||
|
||||
now = datetime.now(UTC)
|
||||
|
||||
if success:
|
||||
stmt = (
|
||||
update(ProcedureModel)
|
||||
.where(ProcedureModel.id == procedure_id)
|
||||
.values(
|
||||
success_count=ProcedureModel.success_count + 1,
|
||||
last_used=now,
|
||||
updated_at=now,
|
||||
)
|
||||
.returning(ProcedureModel)
|
||||
)
|
||||
else:
|
||||
stmt = (
|
||||
update(ProcedureModel)
|
||||
.where(ProcedureModel.id == procedure_id)
|
||||
.values(
|
||||
failure_count=ProcedureModel.failure_count + 1,
|
||||
last_used=now,
|
||||
updated_at=now,
|
||||
)
|
||||
.returning(ProcedureModel)
|
||||
)
|
||||
|
||||
result = await self._session.execute(stmt)
|
||||
updated_model = result.scalar_one()
|
||||
await self._session.flush()
|
||||
|
||||
outcome = "success" if success else "failure"
|
||||
logger.info(
|
||||
f"Recorded {outcome} for procedure {procedure_id}: "
|
||||
f"success_rate={updated_model.success_rate:.2%}"
|
||||
)
|
||||
|
||||
return _model_to_procedure(updated_model)
|
||||
|
||||
# =========================================================================
|
||||
# Step Management
|
||||
# =========================================================================
|
||||
|
||||
async def update_steps(
|
||||
self,
|
||||
procedure_id: UUID,
|
||||
steps: list[Step],
|
||||
) -> Procedure:
|
||||
"""
|
||||
Update the steps of a procedure.
|
||||
|
||||
Args:
|
||||
procedure_id: Procedure to update
|
||||
steps: New steps
|
||||
|
||||
Returns:
|
||||
Updated procedure
|
||||
|
||||
Raises:
|
||||
ValueError: If procedure not found
|
||||
"""
|
||||
query = select(ProcedureModel).where(ProcedureModel.id == procedure_id)
|
||||
result = await self._session.execute(query)
|
||||
model = result.scalar_one_or_none()
|
||||
|
||||
if model is None:
|
||||
raise ValueError(f"Procedure not found: {procedure_id}")
|
||||
|
||||
# Convert Step objects to dictionaries
|
||||
steps_dict = [
|
||||
{
|
||||
"order": step.order,
|
||||
"action": step.action,
|
||||
"parameters": step.parameters,
|
||||
"expected_outcome": step.expected_outcome,
|
||||
"fallback_action": step.fallback_action,
|
||||
}
|
||||
for step in steps
|
||||
]
|
||||
|
||||
now = datetime.now(UTC)
|
||||
stmt = (
|
||||
update(ProcedureModel)
|
||||
.where(ProcedureModel.id == procedure_id)
|
||||
.values(
|
||||
steps=steps_dict,
|
||||
updated_at=now,
|
||||
)
|
||||
.returning(ProcedureModel)
|
||||
)
|
||||
|
||||
result = await self._session.execute(stmt)
|
||||
updated_model = result.scalar_one()
|
||||
await self._session.flush()
|
||||
|
||||
logger.info(f"Updated steps for procedure {procedure_id}: {len(steps)} steps")
|
||||
|
||||
return _model_to_procedure(updated_model)
|
||||
|
||||
# =========================================================================
|
||||
# Statistics & Management
|
||||
# =========================================================================
|
||||
|
||||
async def get_stats(
|
||||
self,
|
||||
project_id: UUID | None = None,
|
||||
agent_type_id: UUID | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Get statistics about procedural memory.
|
||||
|
||||
Args:
|
||||
project_id: Optional project to get stats for
|
||||
agent_type_id: Optional agent type filter
|
||||
|
||||
Returns:
|
||||
Dictionary with statistics
|
||||
"""
|
||||
query = select(ProcedureModel)
|
||||
|
||||
if project_id is not None:
|
||||
query = query.where(
|
||||
or_(
|
||||
ProcedureModel.project_id == project_id,
|
||||
ProcedureModel.project_id.is_(None),
|
||||
)
|
||||
)
|
||||
|
||||
if agent_type_id is not None:
|
||||
query = query.where(
|
||||
or_(
|
||||
ProcedureModel.agent_type_id == agent_type_id,
|
||||
ProcedureModel.agent_type_id.is_(None),
|
||||
)
|
||||
)
|
||||
|
||||
result = await self._session.execute(query)
|
||||
models = list(result.scalars().all())
|
||||
|
||||
if not models:
|
||||
return {
|
||||
"total_procedures": 0,
|
||||
"avg_success_rate": 0.0,
|
||||
"avg_steps_count": 0.0,
|
||||
"total_uses": 0,
|
||||
"high_success_count": 0,
|
||||
"low_success_count": 0,
|
||||
}
|
||||
|
||||
success_rates = [m.success_rate for m in models]
|
||||
step_counts = [len(m.steps or []) for m in models]
|
||||
total_uses = sum(m.total_uses for m in models)
|
||||
|
||||
return {
|
||||
"total_procedures": len(models),
|
||||
"avg_success_rate": sum(success_rates) / len(success_rates),
|
||||
"avg_steps_count": sum(step_counts) / len(step_counts),
|
||||
"total_uses": total_uses,
|
||||
"high_success_count": sum(1 for r in success_rates if r >= 0.8),
|
||||
"low_success_count": sum(1 for r in success_rates if r < 0.5),
|
||||
}
|
||||
|
||||
async def count(
|
||||
self,
|
||||
project_id: UUID | None = None,
|
||||
agent_type_id: UUID | None = None,
|
||||
) -> int:
|
||||
"""
|
||||
Count procedures in scope.
|
||||
|
||||
Args:
|
||||
project_id: Optional project to count for
|
||||
agent_type_id: Optional agent type filter
|
||||
|
||||
Returns:
|
||||
Number of procedures
|
||||
"""
|
||||
query = select(ProcedureModel)
|
||||
|
||||
if project_id is not None:
|
||||
query = query.where(
|
||||
or_(
|
||||
ProcedureModel.project_id == project_id,
|
||||
ProcedureModel.project_id.is_(None),
|
||||
)
|
||||
)
|
||||
|
||||
if agent_type_id is not None:
|
||||
query = query.where(
|
||||
or_(
|
||||
ProcedureModel.agent_type_id == agent_type_id,
|
||||
ProcedureModel.agent_type_id.is_(None),
|
||||
)
|
||||
)
|
||||
|
||||
result = await self._session.execute(query)
|
||||
return len(list(result.scalars().all()))
|
||||
|
||||
async def delete(self, procedure_id: UUID) -> bool:
|
||||
"""
|
||||
Delete a procedure.
|
||||
|
||||
Args:
|
||||
procedure_id: Procedure to delete
|
||||
|
||||
Returns:
|
||||
True if deleted, False if not found
|
||||
"""
|
||||
query = select(ProcedureModel).where(ProcedureModel.id == procedure_id)
|
||||
result = await self._session.execute(query)
|
||||
model = result.scalar_one_or_none()
|
||||
|
||||
if model is None:
|
||||
return False
|
||||
|
||||
await self._session.delete(model)
|
||||
await self._session.flush()
|
||||
|
||||
logger.info(f"Deleted procedure {procedure_id}")
|
||||
return True
|
||||
|
||||
async def get_procedures_by_success_rate(
    self,
    min_rate: float = 0.0,
    max_rate: float = 1.0,
    project_id: UUID | None = None,
    limit: int = 20,
) -> list[Procedure]:
    """
    Get procedures within a success rate range.

    ``success_rate`` is computed on the model (Python side), so rate
    filtering has to happen after the rows are loaded.

    Args:
        min_rate: Minimum success rate (inclusive)
        max_rate: Maximum success rate (inclusive)
        project_id: Optional project scope
        limit: Maximum results

    Returns:
        Up to ``limit`` procedures, most recently used first
    """
    query = select(ProcedureModel).order_by(desc(ProcedureModel.last_used))

    if project_id is not None:
        query = query.where(
            or_(
                ProcedureModel.project_id == project_id,
                ProcedureModel.project_id.is_(None),
            )
        )

    # BUG FIX: previously only ``limit * 2`` rows were fetched before the
    # Python-side rate filter, which could return fewer than ``limit``
    # matches even when more existed in the table. Only cap the SQL query
    # when the rate range is non-restrictive (nothing will be filtered out);
    # otherwise load the scoped rows and filter in Python.
    if min_rate <= 0.0 and max_rate >= 1.0:
        query = query.limit(limit)

    result = await self._session.execute(query)
    models = result.scalars().all()

    matching = [m for m in models if min_rate <= m.success_rate <= max_rate]
    return [_model_to_procedure(m) for m in matching[:limit]]
|
||||
@@ -0,0 +1,2 @@
|
||||
# tests/unit/services/memory/procedural/__init__.py
|
||||
"""Unit tests for procedural memory."""
|
||||
427
backend/tests/unit/services/memory/procedural/test_matching.py
Normal file
427
backend/tests/unit/services/memory/procedural/test_matching.py
Normal file
@@ -0,0 +1,427 @@
|
||||
# tests/unit/services/memory/procedural/test_matching.py
|
||||
"""Unit tests for procedure matching."""
|
||||
|
||||
from datetime import UTC, datetime
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
from app.services.memory.procedural.matching import (
|
||||
MatchContext,
|
||||
MatchResult,
|
||||
ProcedureMatcher,
|
||||
get_procedure_matcher,
|
||||
)
|
||||
from app.services.memory.types import Procedure
|
||||
|
||||
|
||||
def create_test_procedure(
    name: str = "deploy_api",
    trigger_pattern: str = "deploy.*api",
    success_count: int = 8,
    failure_count: int = 2,
) -> Procedure:
    """Build a Procedure fixture with a standard three-step body."""
    timestamp = datetime.now(UTC)
    standard_steps = [
        {"order": position, "action": action}
        for position, action in enumerate(("build", "test", "deploy"), start=1)
    ]
    return Procedure(
        id=uuid4(),
        project_id=None,
        agent_type_id=None,
        name=name,
        trigger_pattern=trigger_pattern,
        steps=standard_steps,
        success_count=success_count,
        failure_count=failure_count,
        last_used=timestamp,
        embedding=None,
        created_at=timestamp,
        updated_at=timestamp,
    )
|
||||
|
||||
|
||||
class TestMatchResult:
    """Tests for MatchResult dataclass."""

    def test_to_dict(self) -> None:
        """Serializing a match result should expose all scoring fields."""
        match = MatchResult(
            procedure=create_test_procedure(),
            score=0.85,
            matched_terms=["deploy", "api"],
            match_type="keyword",
        )

        payload = match.to_dict()

        for key in ("procedure_id", "procedure_name"):
            assert key in payload
        assert payload["score"] == 0.85
        assert payload["matched_terms"] == ["deploy", "api"]
        assert payload["match_type"] == "keyword"
        # Fixture defaults: 8 successes out of 10 total uses.
        assert payload["success_rate"] == 0.8
|
||||
|
||||
|
||||
class TestMatchContext:
    """Tests for MatchContext dataclass."""

    def test_default_values(self) -> None:
        """Only the query is required; everything else defaults."""
        ctx = MatchContext(query="deploy api")

        assert ctx.query == "deploy api"
        assert ctx.task_type is None
        assert ctx.project_id is None
        assert ctx.max_results == 5
        assert ctx.min_score == 0.3
        assert ctx.require_success_rate is None

    def test_with_all_values(self) -> None:
        """Every field should round-trip when set explicitly."""
        scope = uuid4()
        ctx = MatchContext(
            query="deploy api",
            task_type="deployment",
            project_id=scope,
            max_results=10,
            min_score=0.5,
            require_success_rate=0.7,
        )

        assert ctx.query == "deploy api"
        assert ctx.task_type == "deployment"
        assert ctx.project_id == scope
        assert ctx.max_results == 10
        assert ctx.min_score == 0.5
        assert ctx.require_success_rate == 0.7
|
||||
|
||||
|
||||
class TestProcedureMatcher:
    """Tests for ProcedureMatcher class."""

    @pytest.fixture
    def matcher(self) -> ProcedureMatcher:
        """Create a procedure matcher."""
        return ProcedureMatcher()

    @pytest.fixture
    def procedures(self) -> list[Procedure]:
        """Create test procedures spanning several success rates."""
        specs = [
            ("deploy_api", "deploy.*api", 9, 1),
            ("deploy_frontend", "deploy.*frontend", 7, 3),
            ("build_project", "build.*project", 8, 2),
            ("run_tests", "test.*run", 5, 5),
        ]
        return [
            create_test_procedure(
                name=name,
                trigger_pattern=pattern,
                success_count=wins,
                failure_count=losses,
            )
            for name, pattern, wins, losses in specs
        ]

    def test_match_exact_name(
        self, matcher: ProcedureMatcher, procedures: list[Procedure]
    ) -> None:
        """An exact-name query should rank that procedure first."""
        hits = matcher.match(procedures, MatchContext(query="deploy_api"))

        assert len(hits) > 0
        assert hits[0].procedure.name == "deploy_api"

    def test_match_partial_terms(
        self, matcher: ProcedureMatcher, procedures: list[Procedure]
    ) -> None:
        """A shared term should match every procedure containing it."""
        hits = matcher.match(procedures, MatchContext(query="deploy"))

        assert len(hits) >= 2
        matched_names = {h.procedure.name for h in hits}
        assert {"deploy_api", "deploy_frontend"} <= matched_names

    def test_match_with_task_type(
        self, matcher: ProcedureMatcher, procedures: list[Procedure]
    ) -> None:
        """The task type should steer matching toward build procedures."""
        ctx = MatchContext(query="build something", task_type="build")

        hits = matcher.match(procedures, ctx)

        assert len(hits) > 0
        assert hits[0].procedure.name == "build_project"

    def test_match_respects_min_score(
        self, matcher: ProcedureMatcher, procedures: list[Procedure]
    ) -> None:
        """No returned match may fall below the score floor."""
        ctx = MatchContext(query="completely unrelated query xyz", min_score=0.5)

        for hit in matcher.match(procedures, ctx):
            assert hit.score >= 0.5

    def test_match_respects_success_rate_requirement(
        self, matcher: ProcedureMatcher, procedures: list[Procedure]
    ) -> None:
        """Matches below the required success rate must be excluded."""
        ctx = MatchContext(query="deploy", require_success_rate=0.7)

        for hit in matcher.match(procedures, ctx):
            assert hit.procedure.success_rate >= 0.7

    def test_match_respects_max_results(
        self, matcher: ProcedureMatcher, procedures: list[Procedure]
    ) -> None:
        """The result list must not exceed max_results."""
        ctx = MatchContext(query="deploy", max_results=1, min_score=0.0)

        assert len(matcher.match(procedures, ctx)) <= 1

    def test_match_sorts_by_score(
        self, matcher: ProcedureMatcher, procedures: list[Procedure]
    ) -> None:
        """Results must come back in descending score order."""
        hits = matcher.match(procedures, MatchContext(query="deploy", min_score=0.0))

        scores = [h.score for h in hits]
        assert scores == sorted(scores, reverse=True)

    def test_match_empty_procedures(self, matcher: ProcedureMatcher) -> None:
        """Matching against no procedures should yield nothing."""
        assert matcher.match([], MatchContext(query="deploy")) == []
|
||||
|
||||
|
||||
class TestProcedureMatcherRankByRelevance:
    """Tests for rank_by_relevance method."""

    @pytest.fixture
    def matcher(self) -> ProcedureMatcher:
        """Create a procedure matcher."""
        return ProcedureMatcher()

    def test_rank_by_relevance(self, matcher: ProcedureMatcher) -> None:
        """Query-relevant procedures should bubble to the front."""
        candidates = [
            create_test_procedure(name="unrelated", trigger_pattern="something else"),
            create_test_procedure(name="deploy_api", trigger_pattern="deploy.*api"),
            create_test_procedure(
                name="deploy_frontend", trigger_pattern="deploy.*frontend"
            ),
        ]

        ordered = matcher.rank_by_relevance(candidates, "deploy")

        # Either deploy procedure is acceptable at the top; the unrelated
        # one must not be.
        assert ordered[0].name in ["deploy_api", "deploy_frontend"]

    def test_rank_by_relevance_empty(self, matcher: ProcedureMatcher) -> None:
        """Ranking nothing should return nothing."""
        assert matcher.rank_by_relevance([], "deploy") == []
|
||||
|
||||
|
||||
class TestProcedureMatcherSuggestProcedures:
    """Tests for suggest_procedures method."""

    @pytest.fixture
    def matcher(self) -> ProcedureMatcher:
        """Create a procedure matcher."""
        return ProcedureMatcher()

    @pytest.fixture
    def procedures(self) -> list[Procedure]:
        """One reliable and one unreliable deploy procedure."""
        return [
            create_test_procedure(
                name="deploy_api",
                trigger_pattern="deploy api",
                success_count=9,
                failure_count=1,
            ),
            create_test_procedure(
                name="bad_deploy",
                trigger_pattern="deploy bad",
                success_count=2,
                failure_count=8,
            ),
        ]

    def test_suggest_procedures(
        self, matcher: ProcedureMatcher, procedures: list[Procedure]
    ) -> None:
        """Only procedures meeting the success-rate floor are suggested."""
        offered = matcher.suggest_procedures(
            procedures,
            "deploy",
            min_success_rate=0.5,
        )

        assert len(offered) > 0
        assert all(item.procedure.success_rate >= 0.5 for item in offered)

    def test_suggest_procedures_limits_results(
        self, matcher: ProcedureMatcher, procedures: list[Procedure]
    ) -> None:
        """The suggestion count must honor max_suggestions."""
        offered = matcher.suggest_procedures(
            procedures,
            "deploy",
            max_suggestions=1,
        )

        assert len(offered) <= 1
|
||||
|
||||
|
||||
class TestGetProcedureMatcher:
    """Tests for singleton getter."""

    def test_get_procedure_matcher_returns_instance(self) -> None:
        """The getter must hand back a ProcedureMatcher."""
        result = get_procedure_matcher()

        assert result is not None
        assert isinstance(result, ProcedureMatcher)

    def test_get_procedure_matcher_returns_same_instance(self) -> None:
        """Repeated calls must reuse one shared instance (singleton)."""
        first, second = get_procedure_matcher(), get_procedure_matcher()

        assert first is second
|
||||
|
||||
|
||||
class TestProcedureMatcherExtractTerms:
    """Tests for term extraction."""

    @pytest.fixture
    def matcher(self) -> ProcedureMatcher:
        """Create a procedure matcher."""
        return ProcedureMatcher()

    def test_extract_terms_basic(self, matcher: ProcedureMatcher) -> None:
        """Whitespace-separated words should all be extracted."""
        extracted = matcher._extract_terms("deploy the api")

        assert {"deploy", "the", "api"} <= set(extracted)

    def test_extract_terms_removes_special_chars(
        self, matcher: ProcedureMatcher
    ) -> None:
        """Punctuation splits words and never survives as a term."""
        extracted = matcher._extract_terms("deploy.api!now")

        for word in ("deploy", "api", "now"):
            assert word in extracted
        for symbol in (".", "!"):
            assert symbol not in extracted

    def test_extract_terms_filters_short(self, matcher: ProcedureMatcher) -> None:
        """Single-character words should be dropped."""
        extracted = matcher._extract_terms("a big api")

        assert "a" not in extracted
        assert "big" in extracted
        assert "api" in extracted

    def test_extract_terms_lowercases(self, matcher: ProcedureMatcher) -> None:
        """Extraction should normalize casing to lowercase."""
        extracted = matcher._extract_terms("Deploy API")

        assert "deploy" in extracted
        assert "api" in extracted
        assert "Deploy" not in extracted
        assert "API" not in extracted
|
||||
569
backend/tests/unit/services/memory/procedural/test_memory.py
Normal file
569
backend/tests/unit/services/memory/procedural/test_memory.py
Normal file
@@ -0,0 +1,569 @@
|
||||
# tests/unit/services/memory/procedural/test_memory.py
|
||||
"""Unit tests for ProceduralMemory class."""
|
||||
|
||||
from datetime import UTC, datetime
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
from app.services.memory.procedural.memory import ProceduralMemory
|
||||
from app.services.memory.types import ProcedureCreate, Step
|
||||
|
||||
|
||||
def create_mock_procedure_model(
|
||||
name="deploy_api",
|
||||
trigger_pattern="deploy.*api",
|
||||
project_id=None,
|
||||
agent_type_id=None,
|
||||
success_count=5,
|
||||
failure_count=1,
|
||||
):
|
||||
"""Create a mock procedure model for testing."""
|
||||
mock = MagicMock()
|
||||
mock.id = uuid4()
|
||||
mock.project_id = project_id
|
||||
mock.agent_type_id = agent_type_id
|
||||
mock.name = name
|
||||
mock.trigger_pattern = trigger_pattern
|
||||
mock.steps = [
|
||||
{"order": 1, "action": "build", "parameters": {}},
|
||||
{"order": 2, "action": "test", "parameters": {}},
|
||||
{"order": 3, "action": "deploy", "parameters": {}},
|
||||
]
|
||||
mock.success_count = success_count
|
||||
mock.failure_count = failure_count
|
||||
mock.last_used = datetime.now(UTC)
|
||||
mock.embedding = None
|
||||
mock.created_at = datetime.now(UTC)
|
||||
mock.updated_at = datetime.now(UTC)
|
||||
mock.success_rate = (
|
||||
success_count / (success_count + failure_count)
|
||||
if (success_count + failure_count) > 0
|
||||
else 0.0
|
||||
)
|
||||
mock.total_uses = success_count + failure_count
|
||||
return mock
|
||||
|
||||
|
||||
class TestProceduralMemoryInit:
    """Tests for ProceduralMemory initialization."""

    def test_init_creates_memory(self) -> None:
        """Constructing with a session should store that session."""
        session = AsyncMock()

        instance = ProceduralMemory(session=session)

        assert instance._session is session

    def test_init_with_embedding_generator(self) -> None:
        """An embedding generator passed at init should be retained."""
        session = AsyncMock()
        generator = AsyncMock()

        instance = ProceduralMemory(
            session=session, embedding_generator=generator
        )

        assert instance._embedding_generator is generator

    @pytest.mark.asyncio
    async def test_create_factory_method(self) -> None:
        """The async factory should yield a usable instance."""
        session = AsyncMock()

        instance = await ProceduralMemory.create(session=session)

        assert instance is not None
        assert instance._session is session
|
||||
|
||||
|
||||
class TestProceduralMemoryRecordProcedure:
    """Tests for procedure recording methods."""

    @pytest.fixture
    def mock_session(self) -> AsyncMock:
        """Create a mock database session."""
        session = AsyncMock()
        # add() is synchronous on an async session, so it gets a MagicMock.
        session.add = MagicMock()
        session.flush = AsyncMock()
        session.refresh = AsyncMock()
        # Mock no existing procedure found
        mock_result = MagicMock()
        mock_result.scalar_one_or_none.return_value = None
        session.execute.return_value = mock_result
        return session

    @pytest.fixture
    def memory(self, mock_session: AsyncMock) -> ProceduralMemory:
        """Create a ProceduralMemory instance."""
        return ProceduralMemory(session=mock_session)

    @pytest.mark.asyncio
    async def test_record_new_procedure(
        self,
        memory: ProceduralMemory,
        mock_session: AsyncMock,
    ) -> None:
        """Test recording a new procedure."""
        procedure_data = ProcedureCreate(
            name="build_project",
            trigger_pattern="build.*project",
            steps=[
                {"order": 1, "action": "npm install"},
                {"order": 2, "action": "npm run build"},
            ],
            project_id=uuid4(),
        )

        result = await memory.record_procedure(procedure_data)

        assert result.name == "build_project"
        assert result.trigger_pattern == "build.*project"
        # A brand-new procedure must be inserted exactly once and flushed.
        mock_session.add.assert_called_once()
        mock_session.flush.assert_called_once()

    @pytest.mark.asyncio
    async def test_record_updates_existing(
        self,
        memory: ProceduralMemory,
        mock_session: AsyncMock,
    ) -> None:
        """Test that recording duplicate procedure updates existing."""
        # Mock existing procedure found
        existing_mock = create_mock_procedure_model()
        find_result = MagicMock()
        find_result.scalar_one_or_none.return_value = existing_mock

        # Mock update result
        updated_mock = create_mock_procedure_model(success_count=6)
        update_result = MagicMock()
        update_result.scalar_one.return_value = updated_mock

        # NOTE: side_effect order encodes the expected call sequence —
        # record_procedure should SELECT for a duplicate first, then update.
        mock_session.execute.side_effect = [find_result, update_result]

        procedure_data = ProcedureCreate(
            name="deploy_api",
            trigger_pattern="deploy.*api",
            steps=[{"order": 1, "action": "deploy"}],
        )

        _ = await memory.record_procedure(procedure_data)

        # Should have called execute twice (find + update)
        assert mock_session.execute.call_count == 2
|
||||
|
||||
|
||||
class TestProceduralMemoryFindMatching:
    """Tests for procedure matching methods."""

    @pytest.fixture
    def mock_session(self) -> AsyncMock:
        """Session whose queries return no rows by default."""
        session = AsyncMock()
        empty = MagicMock()
        empty.scalars.return_value.all.return_value = []
        session.execute.return_value = empty
        return session

    @pytest.fixture
    def memory(self, mock_session: AsyncMock) -> ProceduralMemory:
        """Memory instance wired to the mock session."""
        return ProceduralMemory(session=mock_session)

    @pytest.mark.asyncio
    async def test_find_matching(self, memory: ProceduralMemory) -> None:
        """A plain query should always come back as a list."""
        matches = await memory.find_matching("deploy api")

        assert isinstance(matches, list)

    @pytest.mark.asyncio
    async def test_find_matching_with_project_filter(
        self,
        memory: ProceduralMemory,
    ) -> None:
        """Scoping by project should still return a list."""
        matches = await memory.find_matching(
            "deploy api",
            project_id=uuid4(),
        )

        assert isinstance(matches, list)

    @pytest.mark.asyncio
    async def test_find_matching_returns_results(
        self,
        memory: ProceduralMemory,
        mock_session: AsyncMock,
    ) -> None:
        """Rows returned by the session should surface as matches."""
        rows = [
            create_mock_procedure_model(name="deploy_api"),
            create_mock_procedure_model(name="deploy_frontend"),
        ]
        populated = MagicMock()
        populated.scalars.return_value.all.return_value = rows
        mock_session.execute.return_value = populated

        matches = await memory.find_matching("deploy")

        assert len(matches) == 2
|
||||
|
||||
|
||||
class TestProceduralMemoryGetBestProcedure:
    """Tests for get_best_procedure method."""

    @pytest.fixture
    def mock_session(self) -> AsyncMock:
        """Bare async session mock."""
        return AsyncMock()

    @pytest.fixture
    def memory(self, mock_session: AsyncMock) -> ProceduralMemory:
        """Memory instance wired to the mock session."""
        return ProceduralMemory(session=mock_session)

    @pytest.mark.asyncio
    async def test_get_best_procedure_none(
        self,
        memory: ProceduralMemory,
        mock_session: AsyncMock,
    ) -> None:
        """With no candidates there is no best procedure."""
        empty = MagicMock()
        empty.scalars.return_value.all.return_value = []
        mock_session.execute.return_value = empty

        best = await memory.get_best_procedure("unknown_task")

        assert best is None

    @pytest.mark.asyncio
    async def test_get_best_procedure_returns_highest_success_rate(
        self,
        memory: ProceduralMemory,
        mock_session: AsyncMock,
    ) -> None:
        """The candidate with the best track record wins."""
        weak = create_mock_procedure_model(
            name="deploy_v1", success_count=3, failure_count=7
        )
        strong = create_mock_procedure_model(
            name="deploy_v2", success_count=9, failure_count=1
        )
        populated = MagicMock()
        populated.scalars.return_value.all.return_value = [strong, weak]
        mock_session.execute.return_value = populated

        best = await memory.get_best_procedure("deploy")

        assert best is not None
        assert best.name == "deploy_v2"
|
||||
|
||||
|
||||
class TestProceduralMemoryRecordOutcome:
    """Tests for outcome recording."""

    @pytest.fixture
    def mock_session(self) -> AsyncMock:
        """Create a mock database session."""
        session = AsyncMock()
        return session

    @pytest.fixture
    def memory(self, mock_session: AsyncMock) -> ProceduralMemory:
        """Create a ProceduralMemory instance."""
        return ProceduralMemory(session=mock_session)

    @pytest.mark.asyncio
    async def test_record_outcome_success(
        self,
        memory: ProceduralMemory,
        mock_session: AsyncMock,
    ) -> None:
        """Test recording a successful outcome."""
        existing_mock = create_mock_procedure_model()

        # First query: find procedure
        find_result = MagicMock()
        find_result.scalar_one_or_none.return_value = existing_mock

        # Second query: update
        updated_mock = create_mock_procedure_model(success_count=6)
        update_result = MagicMock()
        update_result.scalar_one.return_value = updated_mock

        # NOTE: side_effect order encodes the expected call sequence
        # (SELECT first, then the update statement).
        mock_session.execute.side_effect = [find_result, update_result]

        result = await memory.record_outcome(existing_mock.id, success=True)

        assert result.success_count == 6

    @pytest.mark.asyncio
    async def test_record_outcome_failure(
        self,
        memory: ProceduralMemory,
        mock_session: AsyncMock,
    ) -> None:
        """Test recording a failure outcome."""
        existing_mock = create_mock_procedure_model()

        # First query: find procedure
        find_result = MagicMock()
        find_result.scalar_one_or_none.return_value = existing_mock

        # Second query: update
        updated_mock = create_mock_procedure_model(failure_count=2)
        update_result = MagicMock()
        update_result.scalar_one.return_value = updated_mock

        mock_session.execute.side_effect = [find_result, update_result]

        result = await memory.record_outcome(existing_mock.id, success=False)

        assert result.failure_count == 2

    @pytest.mark.asyncio
    async def test_record_outcome_not_found(
        self,
        memory: ProceduralMemory,
        mock_session: AsyncMock,
    ) -> None:
        """Test recording outcome for non-existent procedure raises error."""
        mock_result = MagicMock()
        mock_result.scalar_one_or_none.return_value = None
        mock_session.execute.return_value = mock_result

        # Missing rows are signalled with ValueError, not a None return.
        with pytest.raises(ValueError, match="Procedure not found"):
            await memory.record_outcome(uuid4(), success=True)
|
||||
|
||||
|
||||
class TestProceduralMemoryUpdateSteps:
    """Tests for step updates."""

    @pytest.fixture
    def mock_session(self) -> AsyncMock:
        """Create a mock database session."""
        session = AsyncMock()
        return session

    @pytest.fixture
    def memory(self, mock_session: AsyncMock) -> ProceduralMemory:
        """Create a ProceduralMemory instance."""
        return ProceduralMemory(session=mock_session)

    @pytest.mark.asyncio
    async def test_update_steps(
        self,
        memory: ProceduralMemory,
        mock_session: AsyncMock,
    ) -> None:
        """Test updating steps."""
        existing_mock = create_mock_procedure_model()

        # First query: find procedure
        find_result = MagicMock()
        find_result.scalar_one_or_none.return_value = existing_mock

        # Second query: update
        updated_mock = create_mock_procedure_model()
        updated_mock.steps = [
            {"order": 1, "action": "new_step_1", "parameters": {}},
            {"order": 2, "action": "new_step_2", "parameters": {}},
        ]
        update_result = MagicMock()
        update_result.scalar_one.return_value = updated_mock

        # NOTE: side_effect order encodes the expected call sequence
        # (SELECT first, then the update statement).
        mock_session.execute.side_effect = [find_result, update_result]

        new_steps = [
            Step(order=1, action="new_step_1"),
            Step(order=2, action="new_step_2"),
        ]

        result = await memory.update_steps(existing_mock.id, new_steps)

        assert len(result.steps) == 2

    @pytest.mark.asyncio
    async def test_update_steps_not_found(
        self,
        memory: ProceduralMemory,
        mock_session: AsyncMock,
    ) -> None:
        """Test updating steps for non-existent procedure raises error."""
        mock_result = MagicMock()
        mock_result.scalar_one_or_none.return_value = None
        mock_session.execute.return_value = mock_result

        # Missing rows are signalled with ValueError, not a None return.
        with pytest.raises(ValueError, match="Procedure not found"):
            await memory.update_steps(uuid4(), [Step(order=1, action="test")])
|
||||
|
||||
|
||||
class TestProceduralMemoryStats:
    """Tests for statistics methods."""

    @pytest.fixture
    def mock_session(self) -> AsyncMock:
        """Bare async session mock."""
        return AsyncMock()

    @pytest.fixture
    def memory(self, mock_session: AsyncMock) -> ProceduralMemory:
        """Memory instance wired to the mock session."""
        return ProceduralMemory(session=mock_session)

    @staticmethod
    def _rows_result(rows: list) -> MagicMock:
        """Build an execute() result whose scalars().all() yields rows."""
        outcome = MagicMock()
        outcome.scalars.return_value.all.return_value = rows
        return outcome

    @pytest.mark.asyncio
    async def test_get_stats_empty(
        self,
        memory: ProceduralMemory,
        mock_session: AsyncMock,
    ) -> None:
        """An empty project reports zeroed statistics."""
        mock_session.execute.return_value = self._rows_result([])

        stats = await memory.get_stats(uuid4())

        assert stats["total_procedures"] == 0
        assert stats["avg_success_rate"] == 0.0

    @pytest.mark.asyncio
    async def test_get_stats_with_data(
        self,
        memory: ProceduralMemory,
        mock_session: AsyncMock,
    ) -> None:
        """Aggregates should reflect the stored procedures."""
        rows = [
            create_mock_procedure_model(success_count=8, failure_count=2),
            create_mock_procedure_model(success_count=6, failure_count=4),
        ]
        mock_session.execute.return_value = self._rows_result(rows)

        stats = await memory.get_stats(uuid4())

        assert stats["total_procedures"] == 2
        assert stats["total_uses"] == 20  # (8+2) + (6+4)

    @pytest.mark.asyncio
    async def test_count(
        self,
        memory: ProceduralMemory,
        mock_session: AsyncMock,
    ) -> None:
        """count() should report how many rows the query returned."""
        rows = [create_mock_procedure_model() for _ in range(5)]
        mock_session.execute.return_value = self._rows_result(rows)

        assert await memory.count(uuid4()) == 5
|
||||
|
||||
|
||||
class TestProceduralMemoryDelete:
    """Tests for delete operations."""

    @pytest.fixture
    def mock_session(self) -> AsyncMock:
        """Bare async session mock."""
        return AsyncMock()

    @pytest.fixture
    def memory(self, mock_session: AsyncMock) -> ProceduralMemory:
        """Memory instance wired to the mock session."""
        return ProceduralMemory(session=mock_session)

    @pytest.mark.asyncio
    async def test_delete(
        self,
        memory: ProceduralMemory,
        mock_session: AsyncMock,
    ) -> None:
        """Deleting an existing procedure succeeds and removes the row."""
        stored = create_mock_procedure_model()
        lookup = MagicMock()
        lookup.scalar_one_or_none.return_value = stored
        mock_session.execute.return_value = lookup
        mock_session.delete = AsyncMock()

        outcome = await memory.delete(stored.id)

        assert outcome is True
        mock_session.delete.assert_called_once()

    @pytest.mark.asyncio
    async def test_delete_not_found(
        self,
        memory: ProceduralMemory,
        mock_session: AsyncMock,
    ) -> None:
        """Deleting an unknown id reports failure."""
        lookup = MagicMock()
        lookup.scalar_one_or_none.return_value = None
        mock_session.execute.return_value = lookup

        assert await memory.delete(uuid4()) is False
|
||||
|
||||
|
||||
class TestProceduralMemoryGetById:
    """Tests for get_by_id method."""

    @pytest.fixture
    def mock_session(self) -> AsyncMock:
        """Bare async session mock."""
        return AsyncMock()

    @pytest.fixture
    def memory(self, mock_session: AsyncMock) -> ProceduralMemory:
        """Memory instance wired to the mock session."""
        return ProceduralMemory(session=mock_session)

    @pytest.mark.asyncio
    async def test_get_by_id(
        self,
        memory: ProceduralMemory,
        mock_session: AsyncMock,
    ) -> None:
        """A stored procedure is returned by its id."""
        stored = create_mock_procedure_model()
        lookup = MagicMock()
        lookup.scalar_one_or_none.return_value = stored
        mock_session.execute.return_value = lookup

        fetched = await memory.get_by_id(stored.id)

        assert fetched is not None
        assert fetched.name == "deploy_api"

    @pytest.mark.asyncio
    async def test_get_by_id_not_found(
        self,
        memory: ProceduralMemory,
        mock_session: AsyncMock,
    ) -> None:
        """Unknown ids resolve to None."""
        lookup = MagicMock()
        lookup.scalar_one_or_none.return_value = None
        mock_session.execute.return_value = lookup

        assert await memory.get_by_id(uuid4()) is None
|
||||
Reference in New Issue
Block a user