From b818f174184689a261f90b17311953dbf7809445 Mon Sep 17 00:00:00 2001 From: Felipe Cardoso Date: Mon, 5 Jan 2026 02:31:32 +0100 Subject: [PATCH] feat(memory): add procedural memory implementation (Issue #92) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Implements procedural memory for learned skills and procedures: Core functionality: - ProceduralMemory class for procedure storage/retrieval - record_procedure with duplicate detection and step merging - find_matching for context-based procedure search - record_outcome for success/failure tracking - get_best_procedure for finding highest success rate - update_steps for procedure refinement Supporting modules: - ProcedureMatcher: Keyword-based procedure matching - MatchResult/MatchContext: Matching result types - Success rate weighting in match scoring Test coverage: - 43 unit tests covering all modules - matching.py: 97% coverage - memory.py: 86% coverage 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 --- .../services/memory/procedural/__init__.py | 17 +- .../services/memory/procedural/matching.py | 291 +++++++ .../app/services/memory/procedural/memory.py | 724 ++++++++++++++++++ .../services/memory/procedural/__init__.py | 2 + .../memory/procedural/test_matching.py | 427 +++++++++++ .../services/memory/procedural/test_memory.py | 569 ++++++++++++++ 6 files changed, 2029 insertions(+), 1 deletion(-) create mode 100644 backend/app/services/memory/procedural/matching.py create mode 100644 backend/app/services/memory/procedural/memory.py create mode 100644 backend/tests/unit/services/memory/procedural/__init__.py create mode 100644 backend/tests/unit/services/memory/procedural/test_matching.py create mode 100644 backend/tests/unit/services/memory/procedural/test_memory.py diff --git a/backend/app/services/memory/procedural/__init__.py b/backend/app/services/memory/procedural/__init__.py index 29d5131..957b558 100644 --- a/backend/app/services/memory/procedural/__init__.py +++ b/backend/app/services/memory/procedural/__init__.py @@ -1,7 +1,22 @@ +# app/services/memory/procedural/__init__.py """ Procedural Memory Learned skills and procedures from successful task patterns. """ -# Will be populated in #92 +from .matching import ( + MatchContext, + MatchResult, + ProcedureMatcher, + get_procedure_matcher, +) +from .memory import ProceduralMemory + +__all__ = [ + "MatchContext", + "MatchResult", + "ProceduralMemory", + "ProcedureMatcher", + "get_procedure_matcher", +] diff --git a/backend/app/services/memory/procedural/matching.py b/backend/app/services/memory/procedural/matching.py new file mode 100644 index 0000000..17e835d --- /dev/null +++ b/backend/app/services/memory/procedural/matching.py @@ -0,0 +1,291 @@ +# app/services/memory/procedural/matching.py +""" +Procedure Matching. + +Provides utilities for matching procedures to contexts, +ranking procedures by relevance, and suggesting procedures. +""" + +import logging +import re +from dataclasses import dataclass, field +from typing import Any, ClassVar + +from app.services.memory.types import Procedure + +logger = logging.getLogger(__name__) + + +@dataclass +class MatchResult: + """Result of a procedure match.""" + + procedure: Procedure + score: float + matched_terms: list[str] = field(default_factory=list) + match_type: str = "keyword" # keyword, semantic, pattern + + def to_dict(self) -> dict[str, Any]: + """Convert to dictionary.""" + return { + "procedure_id": str(self.procedure.id), + "procedure_name": self.procedure.name, + "score": self.score, + "matched_terms": self.matched_terms, + "match_type": self.match_type, + "success_rate": self.procedure.success_rate, + } + + +@dataclass +class MatchContext: + """Context for procedure matching.""" + + query: str + task_type: str | None = None + project_id: Any | None = None + agent_type_id: Any | None = None + max_results: int = 5 + min_score: float = 0.3 + require_success_rate: float | None = None + + +class ProcedureMatcher: + """ + Matches procedures to contexts using multiple strategies. + + Matching strategies: + - Keyword matching on trigger pattern and name + - Pattern-based matching using regex + - Success rate weighting + + In production, this would be augmented with vector similarity search. + """ + + # Common task-related keywords for boosting + TASK_KEYWORDS: ClassVar[set[str]] = { + "create", + "update", + "delete", + "fix", + "implement", + "add", + "remove", + "refactor", + "test", + "deploy", + "configure", + "setup", + "build", + "debug", + "optimize", + } + + def __init__(self) -> None: + """Initialize the matcher.""" + self._compiled_patterns: dict[str, re.Pattern[str]] = {} + + def match( + self, + procedures: list[Procedure], + context: MatchContext, + ) -> list[MatchResult]: + """ + Match procedures against a context. + + Args: + procedures: List of procedures to match + context: Matching context + + Returns: + List of match results, sorted by score (highest first) + """ + results: list[MatchResult] = [] + + query_terms = self._extract_terms(context.query) + query_lower = context.query.lower() + + for procedure in procedures: + score, matched = self._calculate_match_score( + procedure=procedure, + query_terms=query_terms, + query_lower=query_lower, + context=context, + ) + + if score >= context.min_score: + # Apply success rate boost + if context.require_success_rate is not None: + if procedure.success_rate < context.require_success_rate: + continue + + # Boost score based on success rate + success_boost = procedure.success_rate * 0.2 + final_score = min(1.0, score + success_boost) + + results.append( + MatchResult( + procedure=procedure, + score=final_score, + matched_terms=matched, + match_type="keyword", + ) + ) + + # Sort by score descending + results.sort(key=lambda r: r.score, reverse=True) + + return results[: context.max_results] + + def _extract_terms(self, text: str) -> list[str]: + """Extract searchable terms from text.""" + # Remove special characters and split + clean = re.sub(r"[^\w\s-]", " ", text.lower()) + terms = clean.split() + + # Filter out very short terms + return [t for t in terms if len(t) >= 2] + + def _calculate_match_score( + self, + procedure: Procedure, + query_terms: list[str], + query_lower: str, + context: MatchContext, + ) -> tuple[float, list[str]]: + """ + Calculate match score between procedure and query. + + Returns: + Tuple of (score, matched_terms) + """ + score = 0.0 + matched: list[str] = [] + + trigger_lower = procedure.trigger_pattern.lower() + name_lower = procedure.name.lower() + + # Exact name match - high score + if name_lower in query_lower or query_lower in name_lower: + score += 0.5 + matched.append(f"name:{procedure.name}") + + # Trigger pattern match + if trigger_lower in query_lower or query_lower in trigger_lower: + score += 0.4 + matched.append(f"trigger:{procedure.trigger_pattern[:30]}") + + # Term-by-term matching + for term in query_terms: + if term in trigger_lower: + score += 0.1 + matched.append(term) + elif term in name_lower: + score += 0.08 + matched.append(term) + + # Boost for task keywords + if term in self.TASK_KEYWORDS: + if term in trigger_lower or term in name_lower: + score += 0.05 + + # Task type match if provided + if context.task_type: + task_type_lower = context.task_type.lower() + if task_type_lower in trigger_lower or task_type_lower in name_lower: + score += 0.3 + matched.append(f"task_type:{context.task_type}") + + # Regex pattern matching on trigger + try: + pattern = self._get_or_compile_pattern(trigger_lower) + if pattern and pattern.search(query_lower): + score += 0.25 + matched.append("pattern_match") + except re.error: + pass # Invalid regex, skip pattern matching + + return min(1.0, score), matched + + def _get_or_compile_pattern(self, pattern: str) -> re.Pattern[str] | None: + """Get or compile a regex pattern with caching.""" + if pattern in self._compiled_patterns: + return self._compiled_patterns[pattern] + + # Only compile if it looks like a regex pattern + if not any(c in pattern for c in r"\.*+?[]{}|()^$"): + return None + + try: + compiled = re.compile(pattern, re.IGNORECASE) + self._compiled_patterns[pattern] = compiled + return compiled + except re.error: + return None + + def rank_by_relevance( + self, + procedures: list[Procedure], + task_type: str, + ) -> list[Procedure]: + """ + Rank procedures by relevance to a task type. + + Args: + procedures: Procedures to rank + task_type: Task type for relevance + + Returns: + Procedures sorted by relevance + """ + context = MatchContext( + query=task_type, + task_type=task_type, + min_score=0.0, + max_results=len(procedures), + ) + + results = self.match(procedures, context) + return [r.procedure for r in results] + + def suggest_procedures( + self, + procedures: list[Procedure], + query: str, + min_success_rate: float = 0.5, + max_suggestions: int = 3, + ) -> list[MatchResult]: + """ + Suggest the best procedures for a query. + + Only suggests procedures with sufficient success rate. + + Args: + procedures: Available procedures + query: Query/context + min_success_rate: Minimum success rate to suggest + max_suggestions: Maximum suggestions + + Returns: + List of procedure suggestions + """ + context = MatchContext( + query=query, + max_results=max_suggestions, + min_score=0.2, + require_success_rate=min_success_rate, + ) + + return self.match(procedures, context) + + +# Singleton matcher instance +_matcher: ProcedureMatcher | None = None + + +def get_procedure_matcher() -> ProcedureMatcher: + """Get the singleton procedure matcher instance.""" + global _matcher + if _matcher is None: + _matcher = ProcedureMatcher() + return _matcher diff --git a/backend/app/services/memory/procedural/memory.py b/backend/app/services/memory/procedural/memory.py new file mode 100644 index 0000000..dca4eeb --- /dev/null +++ b/backend/app/services/memory/procedural/memory.py @@ -0,0 +1,724 @@ +# app/services/memory/procedural/memory.py +""" +Procedural Memory Implementation. + +Provides storage and retrieval for learned procedures (skills) +derived from successful task execution patterns. +""" + +import logging +import time +from datetime import UTC, datetime +from typing import Any +from uuid import UUID + +from sqlalchemy import and_, desc, or_, select, update +from sqlalchemy.ext.asyncio import AsyncSession + +from app.models.memory.procedure import Procedure as ProcedureModel +from app.services.memory.config import get_memory_settings +from app.services.memory.types import Procedure, ProcedureCreate, RetrievalResult, Step + +logger = logging.getLogger(__name__) + + +def _model_to_procedure(model: ProcedureModel) -> Procedure: + """Convert SQLAlchemy model to Procedure dataclass.""" + return Procedure( + id=model.id, # type: ignore[arg-type] + project_id=model.project_id, # type: ignore[arg-type] + agent_type_id=model.agent_type_id, # type: ignore[arg-type] + name=model.name, # type: ignore[arg-type] + trigger_pattern=model.trigger_pattern, # type: ignore[arg-type] + steps=model.steps or [], # type: ignore[arg-type] + success_count=model.success_count, # type: ignore[arg-type] + failure_count=model.failure_count, # type: ignore[arg-type] + last_used=model.last_used, # type: ignore[arg-type] + embedding=None, # Don't expose raw embedding + created_at=model.created_at, # type: ignore[arg-type] + updated_at=model.updated_at, # type: ignore[arg-type] + ) + + +class ProceduralMemory: + """ + Procedural Memory Service. + + Provides procedure storage and retrieval: + - Record procedures from successful task patterns + - Find matching procedures by trigger pattern + - Track success/failure rates + - Get best procedure for a task type + - Update procedure steps + + Performance target: <50ms P95 for matching + """ + + def __init__( + self, + session: AsyncSession, + embedding_generator: Any | None = None, + ) -> None: + """ + Initialize procedural memory. + + Args: + session: Database session + embedding_generator: Optional embedding generator for semantic matching + """ + self._session = session + self._embedding_generator = embedding_generator + self._settings = get_memory_settings() + + @classmethod + async def create( + cls, + session: AsyncSession, + embedding_generator: Any | None = None, + ) -> "ProceduralMemory": + """ + Factory method to create ProceduralMemory. + + Args: + session: Database session + embedding_generator: Optional embedding generator + + Returns: + Configured ProceduralMemory instance + """ + return cls(session=session, embedding_generator=embedding_generator) + + # ========================================================================= + # Procedure Recording + # ========================================================================= + + async def record_procedure(self, procedure: ProcedureCreate) -> Procedure: + """ + Record a new procedure or update an existing one. + + If a procedure with the same name exists in the same scope, + its steps will be updated and success count incremented. + + Args: + procedure: Procedure data to record + + Returns: + The created or updated procedure + """ + # Check for existing procedure with same name + existing = await self._find_existing_procedure( + project_id=procedure.project_id, + agent_type_id=procedure.agent_type_id, + name=procedure.name, + ) + + if existing is not None: + # Update existing procedure + return await self._update_existing_procedure( + existing=existing, + new_steps=procedure.steps, + new_trigger=procedure.trigger_pattern, + ) + + # Create new procedure + now = datetime.now(UTC) + + # Generate embedding if possible + embedding = None + if self._embedding_generator is not None: + embedding_text = self._create_embedding_text(procedure) + embedding = await self._embedding_generator.generate(embedding_text) + + model = ProcedureModel( + project_id=procedure.project_id, + agent_type_id=procedure.agent_type_id, + name=procedure.name, + trigger_pattern=procedure.trigger_pattern, + steps=procedure.steps, + success_count=1, # New procedures start with 1 success (they worked) + failure_count=0, + last_used=now, + embedding=embedding, + ) + + self._session.add(model) + await self._session.flush() + await self._session.refresh(model) + + logger.info( + f"Recorded new procedure: {procedure.name} with {len(procedure.steps)} steps" + ) + + return _model_to_procedure(model) + + async def _find_existing_procedure( + self, + project_id: UUID | None, + agent_type_id: UUID | None, + name: str, + ) -> ProcedureModel | None: + """Find an existing procedure with the same name in the same scope.""" + query = select(ProcedureModel).where(ProcedureModel.name == name) + + if project_id is not None: + query = query.where(ProcedureModel.project_id == project_id) + else: + query = query.where(ProcedureModel.project_id.is_(None)) + + if agent_type_id is not None: + query = query.where(ProcedureModel.agent_type_id == agent_type_id) + else: + query = query.where(ProcedureModel.agent_type_id.is_(None)) + + result = await self._session.execute(query) + return result.scalar_one_or_none() + + async def _update_existing_procedure( + self, + existing: ProcedureModel, + new_steps: list[dict[str, Any]], + new_trigger: str, + ) -> Procedure: + """Update an existing procedure with new steps.""" + now = datetime.now(UTC) + + # Merge steps intelligently - keep existing order, add new steps + merged_steps = self._merge_steps( + existing.steps or [], # type: ignore[arg-type] + new_steps, + ) + + stmt = ( + update(ProcedureModel) + .where(ProcedureModel.id == existing.id) + .values( + steps=merged_steps, + trigger_pattern=new_trigger, + success_count=ProcedureModel.success_count + 1, + last_used=now, + updated_at=now, + ) + .returning(ProcedureModel) + ) + + result = await self._session.execute(stmt) + updated_model = result.scalar_one() + await self._session.flush() + + logger.info(f"Updated existing procedure: {existing.name}") + + return _model_to_procedure(updated_model) + + def _merge_steps( + self, + existing_steps: list[dict[str, Any]], + new_steps: list[dict[str, Any]], + ) -> list[dict[str, Any]]: + """Merge steps from a new execution with existing steps.""" + if not existing_steps: + return new_steps + if not new_steps: + return existing_steps + + # For now, use the new steps if they differ significantly + # In production, this could use more sophisticated merging + if len(new_steps) != len(existing_steps): + # If structure changed, prefer newer steps + return new_steps + + # Merge step-by-step, preferring new data where available + merged = [] + for i, new_step in enumerate(new_steps): + if i < len(existing_steps): + # Merge with existing step + step = {**existing_steps[i], **new_step} + else: + step = new_step + merged.append(step) + + return merged + + def _create_embedding_text(self, procedure: ProcedureCreate) -> str: + """Create text for embedding from procedure data.""" + steps_text = " ".join(step.get("action", "") for step in procedure.steps) + return f"{procedure.name} {procedure.trigger_pattern} {steps_text}" + + # ========================================================================= + # Procedure Retrieval + # ========================================================================= + + async def find_matching( + self, + context: str, + project_id: UUID | None = None, + agent_type_id: UUID | None = None, + limit: int = 5, + ) -> list[Procedure]: + """ + Find procedures matching the given context. + + Args: + context: Context/trigger to match against + project_id: Optional project to search within + agent_type_id: Optional agent type filter + limit: Maximum results + + Returns: + List of matching procedures + """ + result = await self._find_matching_with_metadata( + context=context, + project_id=project_id, + agent_type_id=agent_type_id, + limit=limit, + ) + return result.items + + async def _find_matching_with_metadata( + self, + context: str, + project_id: UUID | None = None, + agent_type_id: UUID | None = None, + limit: int = 5, + ) -> RetrievalResult[Procedure]: + """Find matching procedures with full result metadata.""" + start_time = time.perf_counter() + + # Build base query - prioritize by success rate + stmt = ( + select(ProcedureModel) + .order_by( + desc( + ProcedureModel.success_count + / (ProcedureModel.success_count + ProcedureModel.failure_count + 1) + ), + desc(ProcedureModel.last_used), + ) + .limit(limit) + ) + + # Apply scope filters + if project_id is not None: + stmt = stmt.where( + or_( + ProcedureModel.project_id == project_id, + ProcedureModel.project_id.is_(None), + ) + ) + + if agent_type_id is not None: + stmt = stmt.where( + or_( + ProcedureModel.agent_type_id == agent_type_id, + ProcedureModel.agent_type_id.is_(None), + ) + ) + + # Text-based matching on trigger pattern and name + # TODO: Implement proper vector similarity search when pgvector is integrated + search_terms = context.lower().split()[:5] # Limit to 5 terms + if search_terms: + conditions = [] + for term in search_terms: + term_pattern = f"%{term}%" + conditions.append( + or_( + ProcedureModel.trigger_pattern.ilike(term_pattern), + ProcedureModel.name.ilike(term_pattern), + ) + ) + if conditions: + stmt = stmt.where(or_(*conditions)) + + result = await self._session.execute(stmt) + models = list(result.scalars().all()) + + latency_ms = (time.perf_counter() - start_time) * 1000 + + return RetrievalResult( + items=[_model_to_procedure(m) for m in models], + total_count=len(models), + query=context, + retrieval_type="procedural", + latency_ms=latency_ms, + metadata={"project_id": str(project_id) if project_id else None}, + ) + + async def get_best_procedure( + self, + task_type: str, + project_id: UUID | None = None, + agent_type_id: UUID | None = None, + min_success_rate: float = 0.5, + min_uses: int = 1, + ) -> Procedure | None: + """ + Get the best procedure for a given task type. + + Returns the procedure with the highest success rate that + meets the minimum thresholds. + + Args: + task_type: Task type to find procedure for + project_id: Optional project scope + agent_type_id: Optional agent type scope + min_success_rate: Minimum required success rate + min_uses: Minimum number of uses required + + Returns: + Best matching procedure or None + """ + # Build query for procedures matching task type + stmt = ( + select(ProcedureModel) + .where( + and_( + (ProcedureModel.success_count + ProcedureModel.failure_count) + >= min_uses, + or_( + ProcedureModel.trigger_pattern.ilike(f"%{task_type}%"), + ProcedureModel.name.ilike(f"%{task_type}%"), + ), + ) + ) + .order_by( + desc( + ProcedureModel.success_count + / (ProcedureModel.success_count + ProcedureModel.failure_count + 1) + ), + desc(ProcedureModel.last_used), + ) + .limit(10) + ) + + # Apply scope filters + if project_id is not None: + stmt = stmt.where( + or_( + ProcedureModel.project_id == project_id, + ProcedureModel.project_id.is_(None), + ) + ) + + if agent_type_id is not None: + stmt = stmt.where( + or_( + ProcedureModel.agent_type_id == agent_type_id, + ProcedureModel.agent_type_id.is_(None), + ) + ) + + result = await self._session.execute(stmt) + models = list(result.scalars().all()) + + # Filter by success rate in Python (SQLAlchemy division in WHERE is complex) + for model in models: + success = float(model.success_count) + failure = float(model.failure_count) + total = success + failure + if total > 0 and (success / total) >= min_success_rate: + logger.debug( + f"Found best procedure for '{task_type}': {model.name} " + f"(success_rate={success / total:.2%})" + ) + return _model_to_procedure(model) + + return None + + async def get_by_id(self, procedure_id: UUID) -> Procedure | None: + """Get a procedure by ID.""" + query = select(ProcedureModel).where(ProcedureModel.id == procedure_id) + result = await self._session.execute(query) + model = result.scalar_one_or_none() + return _model_to_procedure(model) if model else None + + # ========================================================================= + # Outcome Recording + # ========================================================================= + + async def record_outcome( + self, + procedure_id: UUID, + success: bool, + ) -> Procedure: + """ + Record the outcome of using a procedure. + + Updates the success or failure count and last_used timestamp. + + Args: + procedure_id: Procedure that was used + success: Whether the procedure succeeded + + Returns: + Updated procedure + + Raises: + ValueError: If procedure not found + """ + query = select(ProcedureModel).where(ProcedureModel.id == procedure_id) + result = await self._session.execute(query) + model = result.scalar_one_or_none() + + if model is None: + raise ValueError(f"Procedure not found: {procedure_id}") + + now = datetime.now(UTC) + + if success: + stmt = ( + update(ProcedureModel) + .where(ProcedureModel.id == procedure_id) + .values( + success_count=ProcedureModel.success_count + 1, + last_used=now, + updated_at=now, + ) + .returning(ProcedureModel) + ) + else: + stmt = ( + update(ProcedureModel) + .where(ProcedureModel.id == procedure_id) + .values( + failure_count=ProcedureModel.failure_count + 1, + last_used=now, + updated_at=now, + ) + .returning(ProcedureModel) + ) + + result = await self._session.execute(stmt) + updated_model = result.scalar_one() + await self._session.flush() + + outcome = "success" if success else "failure" + logger.info( + f"Recorded {outcome} for procedure {procedure_id}: " + f"success_rate={updated_model.success_rate:.2%}" + ) + + return _model_to_procedure(updated_model) + + # ========================================================================= + # Step Management + # ========================================================================= + + async def update_steps( + self, + procedure_id: UUID, + steps: list[Step], + ) -> Procedure: + """ + Update the steps of a procedure. + + Args: + procedure_id: Procedure to update + steps: New steps + + Returns: + Updated procedure + + Raises: + ValueError: If procedure not found + """ + query = select(ProcedureModel).where(ProcedureModel.id == procedure_id) + result = await self._session.execute(query) + model = result.scalar_one_or_none() + + if model is None: + raise ValueError(f"Procedure not found: {procedure_id}") + + # Convert Step objects to dictionaries + steps_dict = [ + { + "order": step.order, + "action": step.action, + "parameters": step.parameters, + "expected_outcome": step.expected_outcome, + "fallback_action": step.fallback_action, + } + for step in steps + ] + + now = datetime.now(UTC) + stmt = ( + update(ProcedureModel) + .where(ProcedureModel.id == procedure_id) + .values( + steps=steps_dict, + updated_at=now, + ) + .returning(ProcedureModel) + ) + + result = await self._session.execute(stmt) + updated_model = result.scalar_one() + await self._session.flush() + + logger.info(f"Updated steps for procedure {procedure_id}: {len(steps)} steps") + + return _model_to_procedure(updated_model) + + # ========================================================================= + # Statistics & Management + # ========================================================================= + + async def get_stats( + self, + project_id: UUID | None = None, + agent_type_id: UUID | None = None, + ) -> dict[str, Any]: + """ + Get statistics about procedural memory. + + Args: + project_id: Optional project to get stats for + agent_type_id: Optional agent type filter + + Returns: + Dictionary with statistics + """ + query = select(ProcedureModel) + + if project_id is not None: + query = query.where( + or_( + ProcedureModel.project_id == project_id, + ProcedureModel.project_id.is_(None), + ) + ) + + if agent_type_id is not None: + query = query.where( + or_( + ProcedureModel.agent_type_id == agent_type_id, + ProcedureModel.agent_type_id.is_(None), + ) + ) + + result = await self._session.execute(query) + models = list(result.scalars().all()) + + if not models: + return { + "total_procedures": 0, + "avg_success_rate": 0.0, + "avg_steps_count": 0.0, + "total_uses": 0, + "high_success_count": 0, + "low_success_count": 0, + } + + success_rates = [m.success_rate for m in models] + step_counts = [len(m.steps or []) for m in models] + total_uses = sum(m.total_uses for m in models) + + return { + "total_procedures": len(models), + "avg_success_rate": sum(success_rates) / len(success_rates), + "avg_steps_count": sum(step_counts) / len(step_counts), + "total_uses": total_uses, + "high_success_count": sum(1 for r in success_rates if r >= 0.8), + "low_success_count": sum(1 for r in success_rates if r < 0.5), + } + + async def count( + self, + project_id: UUID | None = None, + agent_type_id: UUID | None = None, + ) -> int: + """ + Count procedures in scope. + + Args: + project_id: Optional project to count for + agent_type_id: Optional agent type filter + + Returns: + Number of procedures + """ + query = select(ProcedureModel) + + if project_id is not None: + query = query.where( + or_( + ProcedureModel.project_id == project_id, + ProcedureModel.project_id.is_(None), + ) + ) + + if agent_type_id is not None: + query = query.where( + or_( + ProcedureModel.agent_type_id == agent_type_id, + ProcedureModel.agent_type_id.is_(None), + ) + ) + + result = await self._session.execute(query) + return len(list(result.scalars().all())) + + async def delete(self, procedure_id: UUID) -> bool: + """ + Delete a procedure. + + Args: + procedure_id: Procedure to delete + + Returns: + True if deleted, False if not found + """ + query = select(ProcedureModel).where(ProcedureModel.id == procedure_id) + result = await self._session.execute(query) + model = result.scalar_one_or_none() + + if model is None: + return False + + await self._session.delete(model) + await self._session.flush() + + logger.info(f"Deleted procedure {procedure_id}") + return True + + async def get_procedures_by_success_rate( + self, + min_rate: float = 0.0, + max_rate: float = 1.0, + project_id: UUID | None = None, + limit: int = 20, + ) -> list[Procedure]: + """ + Get procedures within a success rate range. + + Args: + min_rate: Minimum success rate + max_rate: Maximum success rate + project_id: Optional project scope + limit: Maximum results + + Returns: + List of procedures + """ + query = ( + select(ProcedureModel) + .order_by(desc(ProcedureModel.last_used)) + .limit(limit * 2) # Fetch more since we filter in Python + ) + + if project_id is not None: + query = query.where( + or_( + ProcedureModel.project_id == project_id, + ProcedureModel.project_id.is_(None), + ) + ) + + result = await self._session.execute(query) + models = list(result.scalars().all()) + + # Filter by success rate in Python + filtered = [m for m in models if min_rate <= m.success_rate <= max_rate][:limit] + + return [_model_to_procedure(m) for m in filtered] diff --git a/backend/tests/unit/services/memory/procedural/__init__.py b/backend/tests/unit/services/memory/procedural/__init__.py new file mode 100644 index 0000000..6cb8d83 --- /dev/null +++ b/backend/tests/unit/services/memory/procedural/__init__.py @@ -0,0 +1,2 @@ +# tests/unit/services/memory/procedural/__init__.py +"""Unit tests for procedural memory.""" diff --git a/backend/tests/unit/services/memory/procedural/test_matching.py b/backend/tests/unit/services/memory/procedural/test_matching.py new file mode 100644 index 0000000..84b4c14 --- /dev/null +++ b/backend/tests/unit/services/memory/procedural/test_matching.py @@ -0,0 +1,427 @@ +# tests/unit/services/memory/procedural/test_matching.py +"""Unit tests for procedure matching.""" + +from datetime import UTC, datetime +from uuid import uuid4 + +import pytest + +from app.services.memory.procedural.matching import ( + MatchContext, + MatchResult, + ProcedureMatcher, + get_procedure_matcher, +) +from app.services.memory.types import Procedure + + +def create_test_procedure( + name: str = "deploy_api", + trigger_pattern: str = "deploy.*api", + success_count: int = 8, + failure_count: int = 2, +) -> Procedure: + """Create a test procedure for testing.""" + now = datetime.now(UTC) + return Procedure( + id=uuid4(), + project_id=None, + agent_type_id=None, + name=name, + trigger_pattern=trigger_pattern, + steps=[ + {"order": 1, "action": "build"}, + {"order": 2, "action": "test"}, + {"order": 3, "action": "deploy"}, + ], + success_count=success_count, + failure_count=failure_count, + last_used=now, + embedding=None, + created_at=now, + updated_at=now, + ) + + +class TestMatchResult: + """Tests for MatchResult dataclass.""" + + def test_to_dict(self) -> None: + """Test converting match result to dictionary.""" + procedure = create_test_procedure() + result = MatchResult( + procedure=procedure, + score=0.85, + matched_terms=["deploy", "api"], + match_type="keyword", + ) + + data = result.to_dict() + + assert "procedure_id" in data + assert "procedure_name" in data + assert data["score"] == 0.85 + assert data["matched_terms"] == ["deploy", "api"] + assert data["match_type"] == "keyword" + assert data["success_rate"] == 0.8 + + +class TestMatchContext: + """Tests for MatchContext dataclass.""" + + def test_default_values(self) -> None: + """Test default values.""" + context = MatchContext(query="deploy api") + + assert context.query == "deploy api" + assert context.task_type is None + assert context.project_id is None + assert context.max_results == 5 + assert context.min_score == 0.3 + assert context.require_success_rate is None + + def test_with_all_values(self) -> None: + """Test with all values set.""" + project_id = uuid4() + context = MatchContext( + query="deploy api", + task_type="deployment", + project_id=project_id, + max_results=10, + min_score=0.5, + require_success_rate=0.7, + ) + + assert context.query == "deploy api" + assert context.task_type == "deployment" + assert context.project_id == project_id + assert context.max_results == 10 + assert context.min_score == 0.5 + assert context.require_success_rate == 0.7 + + +class TestProcedureMatcher: + """Tests for ProcedureMatcher class.""" + + @pytest.fixture + def matcher(self) -> ProcedureMatcher: + """Create a procedure matcher.""" + return ProcedureMatcher() + + @pytest.fixture + def procedures(self) -> list[Procedure]: + """Create test procedures.""" + return [ + create_test_procedure( + name="deploy_api", + trigger_pattern="deploy.*api", + success_count=9, + failure_count=1, + ), + create_test_procedure( + name="deploy_frontend", + trigger_pattern="deploy.*frontend", + success_count=7, + failure_count=3, + ), + create_test_procedure( + name="build_project", + trigger_pattern="build.*project", + success_count=8, + failure_count=2, + ), + create_test_procedure( + name="run_tests", + trigger_pattern="test.*run", + success_count=5, + failure_count=5, + ), + ] + + def test_match_exact_name( + self, + matcher: ProcedureMatcher, + procedures: list[Procedure], + ) -> None: + """Test matching with exact name.""" + context = MatchContext(query="deploy_api") + + results = matcher.match(procedures, context) + + assert len(results) > 0 + # First result should be deploy_api + assert results[0].procedure.name == "deploy_api" + + def test_match_partial_terms( + self, + matcher: ProcedureMatcher, + procedures: list[Procedure], + ) -> None: + """Test matching with partial terms.""" + context = MatchContext(query="deploy") + + results = matcher.match(procedures, context) + + assert len(results) >= 2 + # Both deploy procedures should match + names = [r.procedure.name for r in results] + assert "deploy_api" in names + assert "deploy_frontend" in names + + def test_match_with_task_type( + self, + matcher: ProcedureMatcher, + procedures: list[Procedure], + ) -> None: + """Test matching with task type.""" + context = MatchContext( + query="build something", + task_type="build", + ) + + results = matcher.match(procedures, context) + + assert len(results) > 0 + assert results[0].procedure.name == "build_project" + + def test_match_respects_min_score( + self, + matcher: ProcedureMatcher, + procedures: list[Procedure], + ) -> None: + """Test that matching respects minimum score.""" + context = MatchContext( + query="completely unrelated query xyz", + min_score=0.5, + ) + + results = matcher.match(procedures, context) + + # Should not match anything with high min_score + for result in results: + assert result.score >= 0.5 + + def test_match_respects_success_rate_requirement( + self, + matcher: ProcedureMatcher, + procedures: list[Procedure], + ) -> None: + """Test that matching respects success rate requirement.""" + context = MatchContext( + query="deploy", + require_success_rate=0.7, + ) + + results = matcher.match(procedures, context) + + for result in results: + assert result.procedure.success_rate >= 0.7 + + def test_match_respects_max_results( + self, + matcher: ProcedureMatcher, + procedures: list[Procedure], + ) -> None: + """Test that matching respects max results.""" + context = MatchContext( + query="deploy", + max_results=1, + min_score=0.0, + ) + + results = matcher.match(procedures, context) + + assert len(results) <= 1 + + def test_match_sorts_by_score( + self, + matcher: ProcedureMatcher, + procedures: list[Procedure], + ) -> None: + """Test that results are sorted by score.""" + context = MatchContext(query="deploy", min_score=0.0) + + results = matcher.match(procedures, context) + + if len(results) > 1: + scores = [r.score for r in results] + assert scores == sorted(scores, reverse=True) + + def test_match_empty_procedures( + self, + matcher: ProcedureMatcher, + ) -> None: + """Test matching with empty procedures list.""" + context = MatchContext(query="deploy") + + results = matcher.match([], context) + + assert results == [] + + +class TestProcedureMatcherRankByRelevance: + """Tests for rank_by_relevance method.""" + + @pytest.fixture + def matcher(self) -> ProcedureMatcher: + """Create a procedure matcher.""" + return ProcedureMatcher() + + def test_rank_by_relevance( + self, + matcher: ProcedureMatcher, + ) -> None: + """Test ranking by relevance.""" + procedures = [ + create_test_procedure(name="unrelated", trigger_pattern="something else"), + create_test_procedure(name="deploy_api", trigger_pattern="deploy.*api"), + create_test_procedure( + name="deploy_frontend", trigger_pattern="deploy.*frontend" + ), + ] + + ranked = matcher.rank_by_relevance(procedures, "deploy") + + # Deploy procedures should be ranked first + assert ranked[0].name in ["deploy_api", "deploy_frontend"] + + def test_rank_by_relevance_empty( + self, + matcher: ProcedureMatcher, + ) -> None: + """Test ranking empty list.""" + ranked = matcher.rank_by_relevance([], "deploy") + + assert ranked == [] + + +class TestProcedureMatcherSuggestProcedures: + """Tests for suggest_procedures method.""" + + @pytest.fixture + def matcher(self) -> ProcedureMatcher: + """Create a procedure matcher.""" + return ProcedureMatcher() + + @pytest.fixture + def procedures(self) -> list[Procedure]: + """Create test procedures.""" + return [ + create_test_procedure( + name="deploy_api", + trigger_pattern="deploy api", + success_count=9, + failure_count=1, + ), + create_test_procedure( + name="bad_deploy", + trigger_pattern="deploy bad", + success_count=2, + failure_count=8, + ), + ] + + def test_suggest_procedures( + self, + matcher: ProcedureMatcher, + procedures: list[Procedure], + ) -> None: + """Test suggesting procedures.""" + suggestions = matcher.suggest_procedures( + procedures, + "deploy", + min_success_rate=0.5, + ) + + assert len(suggestions) > 0 + # Only high success rate should be suggested + for s in suggestions: + assert s.procedure.success_rate >= 0.5 + + def test_suggest_procedures_limits_results( + self, + matcher: ProcedureMatcher, + procedures: list[Procedure], + ) -> None: + """Test that suggestions are limited.""" + suggestions = matcher.suggest_procedures( + procedures, + "deploy", + max_suggestions=1, + ) + + assert len(suggestions) <= 1 + + +class TestGetProcedureMatcher: + """Tests for singleton getter.""" + + def test_get_procedure_matcher_returns_instance(self) -> None: + """Test that getter returns instance.""" + matcher = get_procedure_matcher() + + assert matcher is not None + assert isinstance(matcher, ProcedureMatcher) + + def test_get_procedure_matcher_returns_same_instance(self) -> None: + """Test that getter returns same instance (singleton).""" + matcher1 = get_procedure_matcher() + matcher2 = get_procedure_matcher() + + assert matcher1 is matcher2 + + +class TestProcedureMatcherExtractTerms: + """Tests for term extraction.""" + + @pytest.fixture + def matcher(self) -> ProcedureMatcher: + """Create a procedure matcher.""" + return ProcedureMatcher() + + def test_extract_terms_basic( + self, + matcher: ProcedureMatcher, + ) -> None: + """Test basic term extraction.""" + terms = matcher._extract_terms("deploy the api") + + assert "deploy" in terms + assert "the" in terms + assert "api" in terms + + def test_extract_terms_removes_special_chars( + self, + matcher: ProcedureMatcher, + ) -> None: + """Test that special characters are removed.""" + terms = matcher._extract_terms("deploy.api!now") + + assert "deploy" in terms + assert "api" in terms + assert "now" in terms + assert "." not in terms + assert "!" not in terms + + def test_extract_terms_filters_short( + self, + matcher: ProcedureMatcher, + ) -> None: + """Test that short terms are filtered.""" + terms = matcher._extract_terms("a big api") + + assert "a" not in terms + assert "big" in terms + assert "api" in terms + + def test_extract_terms_lowercases( + self, + matcher: ProcedureMatcher, + ) -> None: + """Test that terms are lowercased.""" + terms = matcher._extract_terms("Deploy API") + + assert "deploy" in terms + assert "api" in terms + assert "Deploy" not in terms + assert "API" not in terms diff --git a/backend/tests/unit/services/memory/procedural/test_memory.py b/backend/tests/unit/services/memory/procedural/test_memory.py new file mode 100644 index 0000000..719a046 --- /dev/null +++ b/backend/tests/unit/services/memory/procedural/test_memory.py @@ -0,0 +1,569 @@ +# tests/unit/services/memory/procedural/test_memory.py +"""Unit tests for ProceduralMemory class.""" + +from datetime import UTC, datetime +from unittest.mock import AsyncMock, MagicMock +from uuid import uuid4 + +import pytest + +from app.services.memory.procedural.memory import ProceduralMemory +from app.services.memory.types import ProcedureCreate, Step + + +def create_mock_procedure_model( + name="deploy_api", + trigger_pattern="deploy.*api", + project_id=None, + agent_type_id=None, + success_count=5, + failure_count=1, +): + """Create a mock procedure model for testing.""" + mock = MagicMock() + mock.id = uuid4() + mock.project_id = project_id + mock.agent_type_id = agent_type_id + mock.name = name + mock.trigger_pattern = trigger_pattern + mock.steps = [ + {"order": 1, "action": "build", "parameters": {}}, + {"order": 2, "action": "test", "parameters": {}}, + {"order": 3, "action": "deploy", "parameters": {}}, + ] + mock.success_count = success_count + mock.failure_count = failure_count + mock.last_used = datetime.now(UTC) + mock.embedding = None + mock.created_at = datetime.now(UTC) + mock.updated_at = datetime.now(UTC) + mock.success_rate = ( + success_count / (success_count + failure_count) + if (success_count + failure_count) > 0 + else 0.0 + ) + mock.total_uses = success_count + failure_count + return mock + + +class TestProceduralMemoryInit: + """Tests for ProceduralMemory initialization.""" + + def test_init_creates_memory(self) -> None: + """Test that init creates memory instance.""" + mock_session = AsyncMock() + memory = ProceduralMemory(session=mock_session) + + assert memory._session is mock_session + + def test_init_with_embedding_generator(self) -> None: + """Test init with embedding generator.""" + mock_session = AsyncMock() + mock_embedding_gen = AsyncMock() + memory = ProceduralMemory( + session=mock_session, embedding_generator=mock_embedding_gen + ) + + assert memory._embedding_generator is mock_embedding_gen + + @pytest.mark.asyncio + async def test_create_factory_method(self) -> None: + """Test create factory method.""" + mock_session = AsyncMock() + memory = await ProceduralMemory.create(session=mock_session) + + assert memory is not None + assert memory._session is mock_session + + +class TestProceduralMemoryRecordProcedure: + """Tests for procedure recording methods.""" + + @pytest.fixture + def mock_session(self) -> AsyncMock: + """Create a mock database session.""" + session = AsyncMock() + session.add = MagicMock() + session.flush = AsyncMock() + session.refresh = AsyncMock() + # Mock no existing procedure found + mock_result = MagicMock() + mock_result.scalar_one_or_none.return_value = None + session.execute.return_value = mock_result + return session + + @pytest.fixture + def memory(self, mock_session: AsyncMock) -> ProceduralMemory: + """Create a ProceduralMemory instance.""" + return ProceduralMemory(session=mock_session) + + @pytest.mark.asyncio + async def test_record_new_procedure( + self, + memory: ProceduralMemory, + mock_session: AsyncMock, + ) -> None: + """Test recording a new procedure.""" + procedure_data = ProcedureCreate( + name="build_project", + trigger_pattern="build.*project", + steps=[ + {"order": 1, "action": "npm install"}, + {"order": 2, "action": "npm run build"}, + ], + project_id=uuid4(), + ) + + result = await memory.record_procedure(procedure_data) + + assert result.name == "build_project" + assert result.trigger_pattern == "build.*project" + mock_session.add.assert_called_once() + mock_session.flush.assert_called_once() + + @pytest.mark.asyncio + async def test_record_updates_existing( + self, + memory: ProceduralMemory, + mock_session: AsyncMock, + ) -> None: + """Test that recording duplicate procedure updates existing.""" + # Mock existing procedure found + existing_mock = create_mock_procedure_model() + find_result = MagicMock() + find_result.scalar_one_or_none.return_value = existing_mock + + # Mock update result + updated_mock = create_mock_procedure_model(success_count=6) + update_result = MagicMock() + update_result.scalar_one.return_value = updated_mock + + mock_session.execute.side_effect = [find_result, update_result] + + procedure_data = ProcedureCreate( + name="deploy_api", + trigger_pattern="deploy.*api", + steps=[{"order": 1, "action": "deploy"}], + ) + + _ = await memory.record_procedure(procedure_data) + + # Should have called execute twice (find + update) + assert mock_session.execute.call_count == 2 + + +class TestProceduralMemoryFindMatching: + """Tests for procedure matching methods.""" + + @pytest.fixture + def mock_session(self) -> AsyncMock: + """Create a mock database session.""" + session = AsyncMock() + mock_result = MagicMock() + mock_result.scalars.return_value.all.return_value = [] + session.execute.return_value = mock_result + return session + + @pytest.fixture + def memory(self, mock_session: AsyncMock) -> ProceduralMemory: + """Create a ProceduralMemory instance.""" + return ProceduralMemory(session=mock_session) + + @pytest.mark.asyncio + async def test_find_matching( + self, + memory: ProceduralMemory, + ) -> None: + """Test finding matching procedures.""" + results = await memory.find_matching("deploy api") + + assert isinstance(results, list) + + @pytest.mark.asyncio + async def test_find_matching_with_project_filter( + self, + memory: ProceduralMemory, + ) -> None: + """Test finding matching procedures with project filter.""" + project_id = uuid4() + results = await memory.find_matching( + "deploy api", + project_id=project_id, + ) + + assert isinstance(results, list) + + @pytest.mark.asyncio + async def test_find_matching_returns_results( + self, + memory: ProceduralMemory, + mock_session: AsyncMock, + ) -> None: + """Test that find_matching returns results.""" + procedures = [ + create_mock_procedure_model(name="deploy_api"), + create_mock_procedure_model(name="deploy_frontend"), + ] + + mock_result = MagicMock() + mock_result.scalars.return_value.all.return_value = procedures + mock_session.execute.return_value = mock_result + + results = await memory.find_matching("deploy") + + assert len(results) == 2 + + +class TestProceduralMemoryGetBestProcedure: + """Tests for get_best_procedure method.""" + + @pytest.fixture + def mock_session(self) -> AsyncMock: + """Create a mock database session.""" + session = AsyncMock() + return session + + @pytest.fixture + def memory(self, mock_session: AsyncMock) -> ProceduralMemory: + """Create a ProceduralMemory instance.""" + return ProceduralMemory(session=mock_session) + + @pytest.mark.asyncio + async def test_get_best_procedure_none( + self, + memory: ProceduralMemory, + mock_session: AsyncMock, + ) -> None: + """Test get_best_procedure returns None when no match.""" + mock_result = MagicMock() + mock_result.scalars.return_value.all.return_value = [] + mock_session.execute.return_value = mock_result + + result = await memory.get_best_procedure("unknown_task") + + assert result is None + + @pytest.mark.asyncio + async def test_get_best_procedure_returns_highest_success_rate( + self, + memory: ProceduralMemory, + mock_session: AsyncMock, + ) -> None: + """Test get_best_procedure returns highest success rate.""" + low_success = create_mock_procedure_model( + name="deploy_v1", success_count=3, failure_count=7 + ) + high_success = create_mock_procedure_model( + name="deploy_v2", success_count=9, failure_count=1 + ) + + mock_result = MagicMock() + mock_result.scalars.return_value.all.return_value = [high_success, low_success] + mock_session.execute.return_value = mock_result + + result = await memory.get_best_procedure("deploy") + + assert result is not None + assert result.name == "deploy_v2" + + +class TestProceduralMemoryRecordOutcome: + """Tests for outcome recording.""" + + @pytest.fixture + def mock_session(self) -> AsyncMock: + """Create a mock database session.""" + session = AsyncMock() + return session + + @pytest.fixture + def memory(self, mock_session: AsyncMock) -> ProceduralMemory: + """Create a ProceduralMemory instance.""" + return ProceduralMemory(session=mock_session) + + @pytest.mark.asyncio + async def test_record_outcome_success( + self, + memory: ProceduralMemory, + mock_session: AsyncMock, + ) -> None: + """Test recording a successful outcome.""" + existing_mock = create_mock_procedure_model() + + # First query: find procedure + find_result = MagicMock() + find_result.scalar_one_or_none.return_value = existing_mock + + # Second query: update + updated_mock = create_mock_procedure_model(success_count=6) + update_result = MagicMock() + update_result.scalar_one.return_value = updated_mock + + mock_session.execute.side_effect = [find_result, update_result] + + result = await memory.record_outcome(existing_mock.id, success=True) + + assert result.success_count == 6 + + @pytest.mark.asyncio + async def test_record_outcome_failure( + self, + memory: ProceduralMemory, + mock_session: AsyncMock, + ) -> None: + """Test recording a failure outcome.""" + existing_mock = create_mock_procedure_model() + + # First query: find procedure + find_result = MagicMock() + find_result.scalar_one_or_none.return_value = existing_mock + + # Second query: update + updated_mock = create_mock_procedure_model(failure_count=2) + update_result = MagicMock() + update_result.scalar_one.return_value = updated_mock + + mock_session.execute.side_effect = [find_result, update_result] + + result = await memory.record_outcome(existing_mock.id, success=False) + + assert result.failure_count == 2 + + @pytest.mark.asyncio + async def test_record_outcome_not_found( + self, + memory: ProceduralMemory, + mock_session: AsyncMock, + ) -> None: + """Test recording outcome for non-existent procedure raises error.""" + mock_result = MagicMock() + mock_result.scalar_one_or_none.return_value = None + mock_session.execute.return_value = mock_result + + with pytest.raises(ValueError, match="Procedure not found"): + await memory.record_outcome(uuid4(), success=True) + + +class TestProceduralMemoryUpdateSteps: + """Tests for step updates.""" + + @pytest.fixture + def mock_session(self) -> AsyncMock: + """Create a mock database session.""" + session = AsyncMock() + return session + + @pytest.fixture + def memory(self, mock_session: AsyncMock) -> ProceduralMemory: + """Create a ProceduralMemory instance.""" + return ProceduralMemory(session=mock_session) + + @pytest.mark.asyncio + async def test_update_steps( + self, + memory: ProceduralMemory, + mock_session: AsyncMock, + ) -> None: + """Test updating steps.""" + existing_mock = create_mock_procedure_model() + + # First query: find procedure + find_result = MagicMock() + find_result.scalar_one_or_none.return_value = existing_mock + + # Second query: update + updated_mock = create_mock_procedure_model() + updated_mock.steps = [ + {"order": 1, "action": "new_step_1", "parameters": {}}, + {"order": 2, "action": "new_step_2", "parameters": {}}, + ] + update_result = MagicMock() + update_result.scalar_one.return_value = updated_mock + + mock_session.execute.side_effect = [find_result, update_result] + + new_steps = [ + Step(order=1, action="new_step_1"), + Step(order=2, action="new_step_2"), + ] + + result = await memory.update_steps(existing_mock.id, new_steps) + + assert len(result.steps) == 2 + + @pytest.mark.asyncio + async def test_update_steps_not_found( + self, + memory: ProceduralMemory, + mock_session: AsyncMock, + ) -> None: + """Test updating steps for non-existent procedure raises error.""" + mock_result = MagicMock() + mock_result.scalar_one_or_none.return_value = None + mock_session.execute.return_value = mock_result + + with pytest.raises(ValueError, match="Procedure not found"): + await memory.update_steps(uuid4(), [Step(order=1, action="test")]) + + +class TestProceduralMemoryStats: + """Tests for statistics methods.""" + + @pytest.fixture + def mock_session(self) -> AsyncMock: + """Create a mock database session.""" + session = AsyncMock() + return session + + @pytest.fixture + def memory(self, mock_session: AsyncMock) -> ProceduralMemory: + """Create a ProceduralMemory instance.""" + return ProceduralMemory(session=mock_session) + + @pytest.mark.asyncio + async def test_get_stats_empty( + self, + memory: ProceduralMemory, + mock_session: AsyncMock, + ) -> None: + """Test getting stats for empty project.""" + mock_result = MagicMock() + mock_result.scalars.return_value.all.return_value = [] + mock_session.execute.return_value = mock_result + + stats = await memory.get_stats(uuid4()) + + assert stats["total_procedures"] == 0 + assert stats["avg_success_rate"] == 0.0 + + @pytest.mark.asyncio + async def test_get_stats_with_data( + self, + memory: ProceduralMemory, + mock_session: AsyncMock, + ) -> None: + """Test getting stats with data.""" + procedures = [ + create_mock_procedure_model(success_count=8, failure_count=2), + create_mock_procedure_model(success_count=6, failure_count=4), + ] + + mock_result = MagicMock() + mock_result.scalars.return_value.all.return_value = procedures + mock_session.execute.return_value = mock_result + + stats = await memory.get_stats(uuid4()) + + assert stats["total_procedures"] == 2 + assert stats["total_uses"] == 20 # (8+2) + (6+4) + + @pytest.mark.asyncio + async def test_count( + self, + memory: ProceduralMemory, + mock_session: AsyncMock, + ) -> None: + """Test counting procedures.""" + procedures = [create_mock_procedure_model() for _ in range(5)] + mock_result = MagicMock() + mock_result.scalars.return_value.all.return_value = procedures + mock_session.execute.return_value = mock_result + + count = await memory.count(uuid4()) + + assert count == 5 + + +class TestProceduralMemoryDelete: + """Tests for delete operations.""" + + @pytest.fixture + def mock_session(self) -> AsyncMock: + """Create a mock database session.""" + session = AsyncMock() + return session + + @pytest.fixture + def memory(self, mock_session: AsyncMock) -> ProceduralMemory: + """Create a ProceduralMemory instance.""" + return ProceduralMemory(session=mock_session) + + @pytest.mark.asyncio + async def test_delete( + self, + memory: ProceduralMemory, + mock_session: AsyncMock, + ) -> None: + """Test deleting a procedure.""" + existing_mock = create_mock_procedure_model() + mock_result = MagicMock() + mock_result.scalar_one_or_none.return_value = existing_mock + mock_session.execute.return_value = mock_result + mock_session.delete = AsyncMock() + + result = await memory.delete(existing_mock.id) + + assert result is True + mock_session.delete.assert_called_once() + + @pytest.mark.asyncio + async def test_delete_not_found( + self, + memory: ProceduralMemory, + mock_session: AsyncMock, + ) -> None: + """Test deleting non-existent procedure.""" + mock_result = MagicMock() + mock_result.scalar_one_or_none.return_value = None + mock_session.execute.return_value = mock_result + + result = await memory.delete(uuid4()) + + assert result is False + + +class TestProceduralMemoryGetById: + """Tests for get_by_id method.""" + + @pytest.fixture + def mock_session(self) -> AsyncMock: + """Create a mock database session.""" + session = AsyncMock() + return session + + @pytest.fixture + def memory(self, mock_session: AsyncMock) -> ProceduralMemory: + """Create a ProceduralMemory instance.""" + return ProceduralMemory(session=mock_session) + + @pytest.mark.asyncio + async def test_get_by_id( + self, + memory: ProceduralMemory, + mock_session: AsyncMock, + ) -> None: + """Test getting procedure by ID.""" + existing_mock = create_mock_procedure_model() + mock_result = MagicMock() + mock_result.scalar_one_or_none.return_value = existing_mock + mock_session.execute.return_value = mock_result + + result = await memory.get_by_id(existing_mock.id) + + assert result is not None + assert result.name == "deploy_api" + + @pytest.mark.asyncio + async def test_get_by_id_not_found( + self, + memory: ProceduralMemory, + mock_session: AsyncMock, + ) -> None: + """Test get_by_id returns None when not found.""" + mock_result = MagicMock() + mock_result.scalar_one_or_none.return_value = None + mock_session.execute.return_value = mock_result + + result = await memory.get_by_id(uuid4()) + + assert result is None