Felipe Cardoso 192237e69b fix(memory): unify Outcome enum and add ABANDONED support
- Add ABANDONED value to core Outcome enum in types.py
- Replace duplicate OutcomeType class in mcp/tools.py with alias to Outcome
- Simplify mcp/service.py to use outcome directly (no more silent mapping)
- Add migration 0006 to extend PostgreSQL episode_outcome enum
- Add missing constraints to migration 0005 (ix_facts_unique_triple_global)

This fixes the semantic issue where ABANDONED outcomes were silently
converted to FAILURE, losing information about task abandonment.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-06 01:46:48 +01:00

# app/services/memory/mcp/service.py
"""
MCP Tool Service for Agent Memory System.
Executes memory tool calls from agents, routing to appropriate memory services.
All tools are scoped to project/agent context for proper isolation.
"""
import fnmatch
import logging
from collections import OrderedDict
from dataclasses import dataclass
from datetime import UTC, datetime, timedelta
from typing import Any
from uuid import UUID
from pydantic import ValidationError
from sqlalchemy.ext.asyncio import AsyncSession
from app.services.memory.episodic import EpisodicMemory
from app.services.memory.procedural import ProceduralMemory
from app.services.memory.semantic import SemanticMemory
from app.services.memory.types import (
EpisodeCreate,
FactCreate,
Outcome,
ProcedureCreate,
)
from app.services.memory.working import WorkingMemory
from .tools import (
AnalysisType,
ForgetArgs,
GetMemoryStatsArgs,
MemoryType,
RecallArgs,
RecordOutcomeArgs,
ReflectArgs,
RememberArgs,
SearchProceduresArgs,
get_tool_definition,
)
logger = logging.getLogger(__name__)
@dataclass
class ToolContext:
"""Context for tool execution - provides scoping information."""
project_id: UUID
agent_instance_id: UUID | None = None
agent_type_id: UUID | None = None
session_id: str | None = None
user_id: UUID | None = None
@dataclass
class ToolResult:
"""Result from a tool execution."""
success: bool
data: Any = None
error: str | None = None
error_code: str | None = None
execution_time_ms: float = 0.0
def to_dict(self) -> dict[str, Any]:
"""Convert to dictionary."""
return {
"success": self.success,
"data": self.data,
"error": self.error,
"error_code": self.error_code,
"execution_time_ms": self.execution_time_ms,
}
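    # Illustrative only: with the defaults above,
    #   ToolResult(success=True, data={"n": 1}).to_dict()
    # yields {"success": True, "data": {"n": 1}, "error": None,
    #         "error_code": None, "execution_time_ms": 0.0}.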
class MemoryToolService:
"""
Service for executing memory-related MCP tool calls.
All operations are scoped to the project/agent context provided.
This service coordinates between different memory types.
"""
# Maximum number of working memory sessions to cache (LRU eviction)
MAX_WORKING_SESSIONS = 1000
def __init__(
self,
session: AsyncSession,
embedding_generator: Any | None = None,
) -> None:
"""
Initialize the memory tool service.
Args:
session: Database session for memory operations
embedding_generator: Optional embedding generator for semantic search
"""
self._session = session
self._embedding_generator = embedding_generator
# Lazy-initialized memory services with LRU eviction for working memory
self._working: OrderedDict[str, WorkingMemory] = OrderedDict()
self._episodic: EpisodicMemory | None = None
self._semantic: SemanticMemory | None = None
self._procedural: ProceduralMemory | None = None
async def _get_working(
self,
session_id: str,
project_id: UUID | None = None,
agent_instance_id: UUID | None = None,
) -> WorkingMemory:
"""Get or create working memory for a session with LRU eviction."""
if session_id in self._working:
# Move to end (most recently used)
self._working.move_to_end(session_id)
return self._working[session_id]
# Evict oldest entries if at capacity
while len(self._working) >= self.MAX_WORKING_SESSIONS:
oldest_id, oldest_memory = self._working.popitem(last=False)
try:
await oldest_memory.close()
except Exception as e:
logger.warning(f"Error closing evicted working memory {oldest_id}: {e}")
# Create new working memory
working = await WorkingMemory.for_session(
session_id=session_id,
project_id=str(project_id) if project_id else None,
agent_instance_id=str(agent_instance_id) if agent_instance_id else None,
)
self._working[session_id] = working
return working
async def _get_episodic(self) -> EpisodicMemory:
"""Get or create episodic memory service."""
if self._episodic is None:
self._episodic = await EpisodicMemory.create(
self._session,
self._embedding_generator,
)
return self._episodic
async def _get_semantic(self) -> SemanticMemory:
"""Get or create semantic memory service."""
if self._semantic is None:
self._semantic = await SemanticMemory.create(
self._session,
self._embedding_generator,
)
return self._semantic
async def _get_procedural(self) -> ProceduralMemory:
"""Get or create procedural memory service."""
if self._procedural is None:
self._procedural = await ProceduralMemory.create(
self._session,
self._embedding_generator,
)
return self._procedural
async def execute_tool(
self,
tool_name: str,
arguments: dict[str, Any],
context: ToolContext,
) -> ToolResult:
"""
Execute a memory tool.
Args:
tool_name: Name of the tool to execute
arguments: Tool arguments
context: Execution context with project/agent scoping
Returns:
Result of the tool execution
"""
start_time = datetime.now(UTC)
# Get tool definition
tool_def = get_tool_definition(tool_name)
if tool_def is None:
return ToolResult(
success=False,
error=f"Unknown tool: {tool_name}",
error_code="UNKNOWN_TOOL",
)
# Validate arguments
try:
validated_args = tool_def.validate_args(arguments)
except ValidationError as e:
return ToolResult(
success=False,
error=f"Invalid arguments: {e}",
error_code="VALIDATION_ERROR",
)
# Execute the tool
try:
result = await self._dispatch_tool(tool_name, validated_args, context)
execution_time = (datetime.now(UTC) - start_time).total_seconds() * 1000
return ToolResult(
success=True,
data=result,
execution_time_ms=execution_time,
)
except PermissionError as e:
return ToolResult(
success=False,
error=str(e),
error_code="PERMISSION_DENIED",
)
except ValueError as e:
return ToolResult(
success=False,
error=str(e),
error_code="INVALID_VALUE",
)
except Exception as e:
logger.exception(f"Tool execution failed: {tool_name}")
execution_time = (datetime.now(UTC) - start_time).total_seconds() * 1000
return ToolResult(
success=False,
error=f"Execution failed: {type(e).__name__}",
error_code="EXECUTION_ERROR",
execution_time_ms=execution_time,
)
async def _dispatch_tool(
self,
tool_name: str,
args: Any,
context: ToolContext,
) -> Any:
"""Dispatch to the appropriate tool handler."""
handlers: dict[str, Any] = {
"remember": self._execute_remember,
"recall": self._execute_recall,
"forget": self._execute_forget,
"reflect": self._execute_reflect,
"get_memory_stats": self._execute_get_memory_stats,
"search_procedures": self._execute_search_procedures,
"record_outcome": self._execute_record_outcome,
}
handler = handlers.get(tool_name)
if handler is None:
raise ValueError(f"No handler for tool: {tool_name}")
return await handler(args, context)
# =========================================================================
# Tool Handlers
# =========================================================================
async def _execute_remember(
self,
args: RememberArgs,
context: ToolContext,
) -> dict[str, Any]:
"""Execute the 'remember' tool."""
memory_type = args.memory_type
if memory_type == MemoryType.WORKING:
return await self._remember_working(args, context)
elif memory_type == MemoryType.EPISODIC:
return await self._remember_episodic(args, context)
elif memory_type == MemoryType.SEMANTIC:
return await self._remember_semantic(args, context)
elif memory_type == MemoryType.PROCEDURAL:
return await self._remember_procedural(args, context)
else:
raise ValueError(f"Unknown memory type: {memory_type}")
async def _remember_working(
self,
args: RememberArgs,
context: ToolContext,
) -> dict[str, Any]:
"""Store in working memory."""
if not args.key:
raise ValueError("Key is required for working memory")
if not context.session_id:
raise ValueError("Session ID is required for working memory")
working = await self._get_working(
context.session_id,
context.project_id,
context.agent_instance_id,
)
await working.set(
key=args.key,
value=args.content,
ttl_seconds=args.ttl_seconds,
)
return {
"stored": True,
"memory_type": "working",
"key": args.key,
"ttl_seconds": args.ttl_seconds,
}
async def _remember_episodic(
self,
args: RememberArgs,
context: ToolContext,
) -> dict[str, Any]:
"""Store in episodic memory."""
episodic = await self._get_episodic()
episode_data = EpisodeCreate(
project_id=context.project_id,
agent_instance_id=context.agent_instance_id,
session_id=context.session_id or "unknown",
            task_type=(args.metadata or {}).get("task_type", "manual_entry"),
task_description=args.content[:500],
outcome=Outcome.SUCCESS, # Default, can be updated later
outcome_details="Manual memory entry",
actions=[
{
"type": "manual_entry",
"content": args.content,
"importance": args.importance,
"metadata": args.metadata,
}
],
context_summary=str(args.metadata) if args.metadata else "",
duration_seconds=0.0,
tokens_used=0,
importance_score=args.importance,
)
episode = await episodic.record_episode(episode_data)
return {
"stored": True,
"memory_type": "episodic",
"episode_id": str(episode.id),
"importance": args.importance,
}
async def _remember_semantic(
self,
args: RememberArgs,
context: ToolContext,
) -> dict[str, Any]:
"""Store in semantic memory."""
if not args.subject or not args.predicate or not args.object_value:
raise ValueError(
"Subject, predicate, and object_value are required for semantic memory"
)
semantic = await self._get_semantic()
fact_data = FactCreate(
project_id=context.project_id,
subject=args.subject,
predicate=args.predicate,
object=args.object_value,
confidence=args.importance,
)
fact = await semantic.store_fact(fact_data)
return {
"stored": True,
"memory_type": "semantic",
"fact_id": str(fact.id),
"triple": f"{args.subject} {args.predicate} {args.object_value}",
}
async def _remember_procedural(
self,
args: RememberArgs,
context: ToolContext,
) -> dict[str, Any]:
"""Store in procedural memory."""
if not args.trigger:
raise ValueError("Trigger is required for procedural memory")
if not args.steps:
raise ValueError("Steps are required for procedural memory")
procedural = await self._get_procedural()
procedure_data = ProcedureCreate(
project_id=context.project_id,
agent_type_id=context.agent_type_id,
name=args.content[:100], # Use content as name
trigger_pattern=args.trigger,
steps=args.steps,
)
procedure = await procedural.record_procedure(procedure_data)
return {
"stored": True,
"memory_type": "procedural",
"procedure_id": str(procedure.id),
"trigger": args.trigger,
"steps_count": len(args.steps),
}
async def _execute_recall(
self,
args: RecallArgs,
context: ToolContext,
) -> dict[str, Any]:
"""Execute the 'recall' tool - retrieve memories."""
results: list[dict[str, Any]] = []
for memory_type in args.memory_types:
if memory_type == MemoryType.WORKING:
if context.session_id:
working = await self._get_working(
context.session_id,
context.project_id,
context.agent_instance_id,
)
                    # Get all keys, then match the query against key names only
                    # (values are not searched)
all_keys = await working.list_keys()
for key in all_keys:
if args.query.lower() in key.lower():
value = await working.get(key)
if value is not None:
results.append(
{
"type": "working",
"key": key,
"content": str(value),
"relevance": 1.0,
}
)
elif memory_type == MemoryType.EPISODIC:
episodic = await self._get_episodic()
episodes = await episodic.search_similar(
project_id=context.project_id,
query=args.query,
limit=args.limit,
agent_instance_id=context.agent_instance_id,
)
for episode in episodes:
results.append(
{
"type": "episodic",
"id": str(episode.id),
"summary": episode.task_description,
"outcome": episode.outcome.value
if episode.outcome
else None,
"occurred_at": episode.occurred_at.isoformat(),
"relevance": episode.importance_score,
}
)
elif memory_type == MemoryType.SEMANTIC:
semantic = await self._get_semantic()
facts = await semantic.search_facts(
query=args.query,
project_id=context.project_id,
limit=args.limit,
min_confidence=args.min_relevance,
)
for fact in facts:
results.append(
{
"type": "semantic",
"id": str(fact.id),
"subject": fact.subject,
"predicate": fact.predicate,
"object": fact.object,
"confidence": fact.confidence,
"relevance": fact.confidence,
}
)
elif memory_type == MemoryType.PROCEDURAL:
procedural = await self._get_procedural()
procedures = await procedural.find_matching(
context=args.query,
project_id=context.project_id,
agent_type_id=context.agent_type_id,
limit=args.limit,
)
for proc in procedures:
results.append(
{
"type": "procedural",
"id": str(proc.id),
"name": proc.name,
"trigger": proc.trigger_pattern,
"success_rate": proc.success_rate,
"steps_count": len(proc.steps) if proc.steps else 0,
"relevance": proc.success_rate,
}
)
# Sort by relevance and limit
results.sort(key=lambda x: x.get("relevance", 0), reverse=True)
results = results[: args.limit]
return {
"query": args.query,
"total_results": len(results),
"results": results,
}
async def _execute_forget(
self,
args: ForgetArgs,
context: ToolContext,
) -> dict[str, Any]:
"""Execute the 'forget' tool - remove memories."""
deleted_count = 0
memory_type = args.memory_type
if memory_type == MemoryType.WORKING:
if not context.session_id:
raise ValueError("Session ID required for working memory")
working = await self._get_working(
context.session_id,
context.project_id,
context.agent_instance_id,
)
if args.key:
deleted = await working.delete(args.key)
deleted_count = 1 if deleted else 0
elif args.pattern:
if not args.confirm_bulk:
raise ValueError("confirm_bulk must be True for pattern deletion")
# Get all keys matching pattern
all_keys = await working.list_keys()
for key in all_keys:
if fnmatch.fnmatch(key, args.pattern):
await working.delete(key)
deleted_count += 1
elif memory_type == MemoryType.EPISODIC:
if args.memory_id:
episodic = await self._get_episodic()
deleted = await episodic.delete(UUID(args.memory_id))
deleted_count = 1 if deleted else 0
else:
raise ValueError("memory_id required for episodic deletion")
elif memory_type == MemoryType.SEMANTIC:
if args.memory_id:
semantic = await self._get_semantic()
deleted = await semantic.delete(UUID(args.memory_id))
deleted_count = 1 if deleted else 0
else:
raise ValueError("memory_id required for semantic deletion")
elif memory_type == MemoryType.PROCEDURAL:
if args.memory_id:
procedural = await self._get_procedural()
deleted = await procedural.delete(UUID(args.memory_id))
deleted_count = 1 if deleted else 0
else:
raise ValueError("memory_id required for procedural deletion")
return {
"deleted": deleted_count > 0,
"deleted_count": deleted_count,
"memory_type": memory_type.value,
}
async def _execute_reflect(
self,
args: ReflectArgs,
context: ToolContext,
) -> dict[str, Any]:
"""Execute the 'reflect' tool - analyze patterns."""
analysis_type = args.analysis_type
episodic = await self._get_episodic()
# Get recent episodes for analysis
since = datetime.now(UTC) - timedelta(days=30)
recent_episodes = await episodic.get_recent(
project_id=context.project_id,
limit=args.max_items * 5, # Get more for analysis
since=since,
)
if analysis_type == AnalysisType.RECENT_PATTERNS:
return await self._analyze_recent_patterns(recent_episodes, args)
elif analysis_type == AnalysisType.SUCCESS_FACTORS:
return await self._analyze_success_factors(recent_episodes, args)
elif analysis_type == AnalysisType.FAILURE_PATTERNS:
return await self._analyze_failure_patterns(recent_episodes, args)
elif analysis_type == AnalysisType.COMMON_PROCEDURES:
procedural = await self._get_procedural()
return await self._analyze_common_procedures(
procedural, context.project_id, args
)
elif analysis_type == AnalysisType.LEARNING_PROGRESS:
semantic = await self._get_semantic()
return await self._analyze_learning_progress(
semantic, context.project_id, args
)
else:
raise ValueError(f"Unknown analysis type: {analysis_type}")
async def _analyze_recent_patterns(
self,
episodes: list[Any],
args: ReflectArgs,
) -> dict[str, Any]:
"""Analyze patterns in recent episodes."""
# Group by task type
task_types: dict[str, int] = {}
outcomes: dict[str, int] = {}
for ep in episodes:
if ep.task_type:
task_types[ep.task_type] = task_types.get(ep.task_type, 0) + 1
if ep.outcome:
outcome_val = (
ep.outcome.value
if hasattr(ep.outcome, "value")
else str(ep.outcome)
)
outcomes[outcome_val] = outcomes.get(outcome_val, 0) + 1
# Sort by frequency
top_tasks = sorted(task_types.items(), key=lambda x: x[1], reverse=True)[
: args.max_items
]
outcome_dist = dict(outcomes)
examples = []
if args.include_examples:
for ep in episodes[: min(3, args.max_items)]:
examples.append(
{
"summary": ep.task_description,
"task_type": ep.task_type,
"outcome": ep.outcome.value if ep.outcome else None,
}
)
return {
"analysis_type": "recent_patterns",
"total_episodes": len(episodes),
"top_task_types": top_tasks,
"outcome_distribution": outcome_dist,
"examples": examples,
"insights": self._generate_pattern_insights(top_tasks, outcome_dist),
}
async def _analyze_success_factors(
self,
episodes: list[Any],
args: ReflectArgs,
) -> dict[str, Any]:
"""Analyze what leads to success."""
successful = [ep for ep in episodes if ep.outcome == Outcome.SUCCESS]
all_count = len(episodes)
success_count = len(successful)
# Find common patterns in successful episodes
task_success: dict[str, dict[str, int]] = {}
for ep in episodes:
if ep.task_type:
if ep.task_type not in task_success:
task_success[ep.task_type] = {"success": 0, "total": 0}
task_success[ep.task_type]["total"] += 1
if ep.outcome == Outcome.SUCCESS:
task_success[ep.task_type]["success"] += 1
# Calculate success rates
success_rates = {
task: data["success"] / data["total"]
for task, data in task_success.items()
if data["total"] >= 2
}
top_success = sorted(success_rates.items(), key=lambda x: x[1], reverse=True)[
: args.max_items
]
examples = []
if args.include_examples:
for ep in successful[: min(3, args.max_items)]:
examples.append(
{
"summary": ep.task_description,
"task_type": ep.task_type,
"lessons": ep.lessons_learned,
}
)
return {
"analysis_type": "success_factors",
"overall_success_rate": success_count / all_count if all_count > 0 else 0,
"tasks_by_success_rate": top_success,
"successful_examples": examples,
"recommendations": self._generate_success_recommendations(
top_success, success_count, all_count
),
}
async def _analyze_failure_patterns(
self,
episodes: list[Any],
args: ReflectArgs,
) -> dict[str, Any]:
"""Analyze what causes failures."""
failed = [ep for ep in episodes if ep.outcome == Outcome.FAILURE]
# Group failures by task type
failure_by_task: dict[str, list[Any]] = {}
for ep in failed:
task = ep.task_type or "unknown"
if task not in failure_by_task:
failure_by_task[task] = []
failure_by_task[task].append(ep)
# Most common failure types
failure_counts = {task: len(eps) for task, eps in failure_by_task.items()}
top_failures = sorted(failure_counts.items(), key=lambda x: x[1], reverse=True)[
: args.max_items
]
examples = []
if args.include_examples:
for ep in failed[: min(3, args.max_items)]:
examples.append(
{
"summary": ep.task_description,
"task_type": ep.task_type,
"lessons": ep.lessons_learned,
"error": ep.outcome_details,
}
)
return {
"analysis_type": "failure_patterns",
"total_failures": len(failed),
"failures_by_task": top_failures,
"failure_examples": examples,
"prevention_tips": self._generate_failure_prevention(top_failures),
}
async def _analyze_common_procedures(
self,
procedural: ProceduralMemory,
project_id: UUID,
args: ReflectArgs,
) -> dict[str, Any]:
"""Analyze most commonly used procedures."""
# Get procedures by matching against empty context for broad search
# This gets procedures sorted by success rate (see find_matching implementation)
procedures = await procedural.find_matching(
context="", # Empty context gets all procedures
project_id=project_id,
limit=args.max_items * 3,
)
# Sort by execution count (success_count + failure_count)
sorted_procs = sorted(
procedures,
key=lambda p: p.success_count + p.failure_count,
reverse=True,
)[: args.max_items]
return {
"analysis_type": "common_procedures",
"total_procedures": len(procedures),
"top_procedures": [
{
"name": p.name,
"trigger": p.trigger_pattern,
"execution_count": p.success_count + p.failure_count,
"success_rate": p.success_rate,
}
for p in sorted_procs
],
}
async def _analyze_learning_progress(
self,
semantic: SemanticMemory,
project_id: UUID,
args: ReflectArgs,
) -> dict[str, Any]:
"""Analyze learning progress over time."""
# Get facts by searching with empty query (gets recent facts ordered by confidence)
facts = await semantic.search_facts(
query="",
project_id=project_id,
limit=args.max_items * 10,
)
# Group by subject for topic analysis
subjects: dict[str, int] = {}
for fact in facts:
subjects[fact.subject] = subjects.get(fact.subject, 0) + 1
top_subjects = sorted(subjects.items(), key=lambda x: x[1], reverse=True)[
: args.max_items
]
return {
"analysis_type": "learning_progress",
"total_facts_learned": len(facts),
"top_knowledge_areas": top_subjects,
"knowledge_breadth": len(subjects),
}
def _generate_pattern_insights(
self,
top_tasks: list[tuple[str, int]],
outcome_dist: dict[str, int],
) -> list[str]:
"""Generate insights from patterns."""
insights = []
if top_tasks:
insights.append(
f"Most common task type: {top_tasks[0][0]} ({top_tasks[0][1]} occurrences)"
)
total = sum(outcome_dist.values())
if total > 0:
success_rate = outcome_dist.get("success", 0) / total
if success_rate > 0.8:
insights.append(
"High success rate observed - current approach is working well"
)
elif success_rate < 0.5:
insights.append(
"Success rate below 50% - consider reviewing procedures"
)
return insights
def _generate_success_recommendations(
self,
top_success: list[tuple[str, float]],
success_count: int,
total: int,
) -> list[str]:
"""Generate recommendations based on success analysis."""
recommendations = []
if top_success:
best_task, rate = top_success[0]
recommendations.append(
f"'{best_task}' has highest success rate ({rate:.0%}) - analyze why"
)
if total > 0:
overall = success_count / total
if overall < 0.7:
recommendations.append(
"Consider breaking complex tasks into smaller steps"
)
return recommendations
def _generate_failure_prevention(
self,
top_failures: list[tuple[str, int]],
) -> list[str]:
"""Generate tips for preventing failures."""
tips = []
if top_failures:
worst_task, count = top_failures[0]
tips.append(
f"'{worst_task}' has most failures ({count}) - needs procedure review"
)
tips.append(
"Review lessons_learned from past failures before attempting similar tasks"
)
return tips
async def _execute_get_memory_stats(
self,
args: GetMemoryStatsArgs,
context: ToolContext,
) -> dict[str, Any]:
"""Execute the 'get_memory_stats' tool."""
stats: dict[str, Any] = {}
if args.include_breakdown:
# Get counts from each memory type
episodic = await self._get_episodic()
semantic = await self._get_semantic()
procedural = await self._get_procedural()
# Count episodes (get recent with high limit)
episodes = await episodic.get_recent(
project_id=context.project_id,
                limit=10000,  # Counting only; totals are capped at this fetch limit
)
# Count facts (search with empty query)
facts = await semantic.search_facts(
query="",
project_id=context.project_id,
limit=10000,
)
# Count procedures (find_matching with empty context)
procedures = await procedural.find_matching(
context="",
project_id=context.project_id,
limit=10000,
)
# Working memory (session-specific)
working_count = 0
if context.session_id:
working = await self._get_working(
context.session_id,
context.project_id,
context.agent_instance_id,
)
keys = await working.list_keys()
working_count = len(keys)
stats["breakdown"] = {
"working": working_count,
"episodic": len(episodes),
"semantic": len(facts),
"procedural": len(procedures),
"total": working_count + len(episodes) + len(facts) + len(procedures),
}
if args.include_recent_activity:
since = datetime.now(UTC) - timedelta(days=args.time_range_days)
episodic = await self._get_episodic()
recent_episodes = await episodic.get_recent(
project_id=context.project_id,
limit=1000,
since=since,
)
# Calculate activity metrics
outcomes = {"success": 0, "failure": 0, "partial": 0, "abandoned": 0}
for ep in recent_episodes:
if ep.outcome:
key = (
ep.outcome.value
if hasattr(ep.outcome, "value")
else str(ep.outcome)
)
if key in outcomes:
outcomes[key] += 1
stats["recent_activity"] = {
"time_range_days": args.time_range_days,
"episodes_created": len(recent_episodes),
"outcomes": outcomes,
}
return stats
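    # Illustrative shape of the returned stats (both flags enabled):
    #   {"breakdown": {"working": 2, "episodic": 10, "semantic": 5,
    #                  "procedural": 1, "total": 18},
    #    "recent_activity": {"time_range_days": 7, "episodes_created": 4,
    #                        "outcomes": {"success": 3, "failure": 1,
    #                                     "partial": 0, "abandoned": 0}}}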
async def _execute_search_procedures(
self,
args: SearchProceduresArgs,
context: ToolContext,
) -> dict[str, Any]:
"""Execute the 'search_procedures' tool."""
procedural = await self._get_procedural()
# Use find_matching with trigger as context
all_procedures = await procedural.find_matching(
context=args.trigger,
project_id=context.project_id,
agent_type_id=context.agent_type_id,
limit=args.limit * 3, # Get more to filter by success rate
)
# Filter by minimum success rate if specified
procedures = [
p
for p in all_procedures
if args.min_success_rate is None or p.success_rate >= args.min_success_rate
][: args.limit]
results = []
for proc in procedures:
proc_data: dict[str, Any] = {
"id": str(proc.id),
"name": proc.name,
"trigger": proc.trigger_pattern,
"success_rate": proc.success_rate,
"execution_count": proc.success_count + proc.failure_count,
}
if args.include_steps:
proc_data["steps"] = proc.steps
results.append(proc_data)
return {
"trigger": args.trigger,
"procedures_found": len(results),
"procedures": results,
}
async def _execute_record_outcome(
self,
args: RecordOutcomeArgs,
context: ToolContext,
) -> dict[str, Any]:
"""Execute the 'record_outcome' tool."""
        # args.outcome is already the unified Outcome enum (the old OutcomeType
        # alias is no longer needed here), so it is stored directly.
        outcome = args.outcome
# Record in episodic memory
episodic = await self._get_episodic()
episode_data = EpisodeCreate(
project_id=context.project_id,
agent_instance_id=context.agent_instance_id,
session_id=context.session_id or "unknown",
task_type=args.task_type,
task_description=f"Task outcome: {args.outcome.value}",
outcome=outcome,
outcome_details=args.error_details or "No additional details",
actions=[
{
"type": "outcome_record",
"outcome": args.outcome.value,
"context": args.context,
}
],
context_summary=str(args.context) if args.context else "",
lessons_learned=[args.lessons_learned] if args.lessons_learned else [],
duration_seconds=args.duration_seconds or 0.0,
tokens_used=0,
)
episode = await episodic.record_episode(episode_data)
# Update procedure success rate if procedure_id provided
procedure_updated = False
if args.procedure_id:
procedural = await self._get_procedural()
try:
await procedural.record_outcome(
procedure_id=UUID(args.procedure_id),
                    success=args.outcome == Outcome.SUCCESS,
)
procedure_updated = True
except Exception as e:
logger.warning(f"Failed to update procedure outcome: {e}")
return {
"recorded": True,
"episode_id": str(episode.id),
"outcome": args.outcome.value,
"procedure_updated": procedure_updated,
}
# Factory function for dependency injection
async def get_memory_tool_service(
session: AsyncSession,
embedding_generator: Any | None = None,
) -> MemoryToolService:
"""Create a memory tool service instance."""
return MemoryToolService(
session=session,
embedding_generator=embedding_generator,
)
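
# ---------------------------------------------------------------------------
# Usage sketch (illustrative only): `session` and `project_id` are hypothetical
# stand-ins for an AsyncSession and a project UUID from the caller's context;
# the exact argument payload accepted by "remember" is defined in tools.py.
#
#     service = await get_memory_tool_service(session)
#     ctx = ToolContext(project_id=project_id, session_id="sess-1")
#     result = await service.execute_tool(
#         "remember",
#         {"memory_type": "working", "key": "plan", "content": "draft step 1"},
#         ctx,
#     )
#     if result.success:
#         ...  # result.data == {"stored": True, "memory_type": "working", ...}
# ---------------------------------------------------------------------------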