diff --git a/backend/app/services/memory/__init__.py b/backend/app/services/memory/__init__.py index 2627efe..e0f6f47 100644 --- a/backend/app/services/memory/__init__.py +++ b/backend/app/services/memory/__init__.py @@ -133,4 +133,6 @@ __all__ = [ "get_default_settings", "get_memory_settings", "reset_memory_settings", + # MCP Tools - lazy import to avoid circular dependencies + # Import directly: from app.services.memory.mcp import MemoryToolService ] diff --git a/backend/app/services/memory/mcp/__init__.py b/backend/app/services/memory/mcp/__init__.py new file mode 100644 index 0000000..973088f --- /dev/null +++ b/backend/app/services/memory/mcp/__init__.py @@ -0,0 +1,40 @@ +# app/services/memory/mcp/__init__.py +""" +MCP Tools for Agent Memory System. + +Exposes memory operations as MCP-compatible tools that agents can invoke: +- remember: Store data in memory +- recall: Retrieve from memory +- forget: Remove from memory +- reflect: Analyze patterns +- get_memory_stats: Usage statistics +- search_procedures: Find relevant procedures +- record_outcome: Record task success/failure +""" + +from .service import MemoryToolService, get_memory_tool_service +from .tools import ( + MEMORY_TOOL_DEFINITIONS, + ForgetArgs, + GetMemoryStatsArgs, + MemoryToolDefinition, + RecallArgs, + RecordOutcomeArgs, + ReflectArgs, + RememberArgs, + SearchProceduresArgs, +) + +__all__ = [ + "MEMORY_TOOL_DEFINITIONS", + "ForgetArgs", + "GetMemoryStatsArgs", + "MemoryToolDefinition", + "MemoryToolService", + "RecallArgs", + "RecordOutcomeArgs", + "ReflectArgs", + "RememberArgs", + "SearchProceduresArgs", + "get_memory_tool_service", +] diff --git a/backend/app/services/memory/mcp/service.py b/backend/app/services/memory/mcp/service.py new file mode 100644 index 0000000..03d8d29 --- /dev/null +++ b/backend/app/services/memory/mcp/service.py @@ -0,0 +1,1042 @@ +# app/services/memory/mcp/service.py +""" +MCP Tool Service for Agent Memory System. + +Executes memory tool calls from agents, routing to appropriate memory services. +All tools are scoped to project/agent context for proper isolation. 
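+
+Typical call flow (illustrative sketch; assumes an open AsyncSession
+`session` and a real project UUID `project_id`):
+
+    service = MemoryToolService(session)
+    ctx = ToolContext(project_id=project_id, session_id="session-1")
+    result = await service.execute_tool(
+        "remember",
+        {"memory_type": "working", "key": "current_task", "content": "demo"},
+        ctx,
+    )
+    assert result.success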
+""" + +import logging +from dataclasses import dataclass +from datetime import UTC, datetime, timedelta +from typing import Any +from uuid import UUID + +from pydantic import ValidationError +from sqlalchemy.ext.asyncio import AsyncSession + +from app.services.memory.episodic import EpisodicMemory +from app.services.memory.procedural import ProceduralMemory +from app.services.memory.semantic import SemanticMemory +from app.services.memory.types import ( + EpisodeCreate, + FactCreate, + Outcome, + ProcedureCreate, +) +from app.services.memory.working import WorkingMemory + +from .tools import ( + AnalysisType, + ForgetArgs, + GetMemoryStatsArgs, + MemoryType, + OutcomeType, + RecallArgs, + RecordOutcomeArgs, + ReflectArgs, + RememberArgs, + SearchProceduresArgs, + get_tool_definition, +) + +logger = logging.getLogger(__name__) + + +@dataclass +class ToolContext: + """Context for tool execution - provides scoping information.""" + + project_id: UUID + agent_instance_id: UUID | None = None + agent_type_id: UUID | None = None + session_id: str | None = None + user_id: UUID | None = None + + +@dataclass +class ToolResult: + """Result from a tool execution.""" + + success: bool + data: Any = None + error: str | None = None + error_code: str | None = None + execution_time_ms: float = 0.0 + + def to_dict(self) -> dict[str, Any]: + """Convert to dictionary.""" + return { + "success": self.success, + "data": self.data, + "error": self.error, + "error_code": self.error_code, + "execution_time_ms": self.execution_time_ms, + } + + +class MemoryToolService: + """ + Service for executing memory-related MCP tool calls. + + All operations are scoped to the project/agent context provided. + This service coordinates between different memory types. + """ + + def __init__( + self, + session: AsyncSession, + embedding_generator: Any | None = None, + ) -> None: + """ + Initialize the memory tool service. 
+ + Args: + session: Database session for memory operations + embedding_generator: Optional embedding generator for semantic search + """ + self._session = session + self._embedding_generator = embedding_generator + + # Lazy-initialized memory services + self._working: dict[str, WorkingMemory] = {} # keyed by session_id + self._episodic: EpisodicMemory | None = None + self._semantic: SemanticMemory | None = None + self._procedural: ProceduralMemory | None = None + + async def _get_working( + self, + session_id: str, + project_id: UUID | None = None, + agent_instance_id: UUID | None = None, + ) -> WorkingMemory: + """Get or create working memory for a session.""" + if session_id not in self._working: + self._working[session_id] = await WorkingMemory.for_session( + session_id=session_id, + project_id=str(project_id) if project_id else None, + agent_instance_id=str(agent_instance_id) if agent_instance_id else None, + ) + return self._working[session_id] + + async def _get_episodic(self) -> EpisodicMemory: + """Get or create episodic memory service.""" + if self._episodic is None: + self._episodic = await EpisodicMemory.create( + self._session, + self._embedding_generator, + ) + return self._episodic + + async def _get_semantic(self) -> SemanticMemory: + """Get or create semantic memory service.""" + if self._semantic is None: + self._semantic = await SemanticMemory.create( + self._session, + self._embedding_generator, + ) + return self._semantic + + async def _get_procedural(self) -> ProceduralMemory: + """Get or create procedural memory service.""" + if self._procedural is None: + self._procedural = await ProceduralMemory.create( + self._session, + self._embedding_generator, + ) + return self._procedural + + async def execute_tool( + self, + tool_name: str, + arguments: dict[str, Any], + context: ToolContext, + ) -> ToolResult: + """ + Execute a memory tool. 
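+
+        Example (sketch; any name present in MEMORY_TOOL_DEFINITIONS
+        is accepted):
+
+            result = await service.execute_tool(
+                "recall",
+                {"query": "authentication errors", "memory_types": ["episodic"]},
+                context,
+            )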
+ + Args: + tool_name: Name of the tool to execute + arguments: Tool arguments + context: Execution context with project/agent scoping + + Returns: + Result of the tool execution + """ + start_time = datetime.now(UTC) + + # Get tool definition + tool_def = get_tool_definition(tool_name) + if tool_def is None: + return ToolResult( + success=False, + error=f"Unknown tool: {tool_name}", + error_code="UNKNOWN_TOOL", + ) + + # Validate arguments + try: + validated_args = tool_def.validate_args(arguments) + except ValidationError as e: + return ToolResult( + success=False, + error=f"Invalid arguments: {e}", + error_code="VALIDATION_ERROR", + ) + + # Execute the tool + try: + result = await self._dispatch_tool(tool_name, validated_args, context) + execution_time = (datetime.now(UTC) - start_time).total_seconds() * 1000 + + return ToolResult( + success=True, + data=result, + execution_time_ms=execution_time, + ) + + except PermissionError as e: + return ToolResult( + success=False, + error=str(e), + error_code="PERMISSION_DENIED", + ) + except ValueError as e: + return ToolResult( + success=False, + error=str(e), + error_code="INVALID_VALUE", + ) + except Exception as e: + logger.exception(f"Tool execution failed: {tool_name}") + execution_time = (datetime.now(UTC) - start_time).total_seconds() * 1000 + return ToolResult( + success=False, + error=f"Execution failed: {type(e).__name__}", + error_code="EXECUTION_ERROR", + execution_time_ms=execution_time, + ) + + async def _dispatch_tool( + self, + tool_name: str, + args: Any, + context: ToolContext, + ) -> Any: + """Dispatch to the appropriate tool handler.""" + handlers: dict[str, Any] = { + "remember": self._execute_remember, + "recall": self._execute_recall, + "forget": self._execute_forget, + "reflect": self._execute_reflect, + "get_memory_stats": self._execute_get_memory_stats, + "search_procedures": self._execute_search_procedures, + "record_outcome": self._execute_record_outcome, + } + + handler = handlers.get(tool_name) + if handler is None: + raise ValueError(f"No handler for tool: {tool_name}") + + return await handler(args, context) + + # ========================================================================= + # Tool Handlers + # ========================================================================= + + async def _execute_remember( + self, + args: RememberArgs, + context: ToolContext, + ) -> dict[str, Any]: + """Execute the 'remember' tool.""" + memory_type = args.memory_type + + if memory_type == MemoryType.WORKING: + return await self._remember_working(args, context) + elif memory_type == MemoryType.EPISODIC: + return await self._remember_episodic(args, context) + elif memory_type == MemoryType.SEMANTIC: + return await self._remember_semantic(args, context) + elif memory_type == MemoryType.PROCEDURAL: + return await self._remember_procedural(args, context) + else: + raise ValueError(f"Unknown memory type: {memory_type}") + + async def _remember_working( + self, + args: RememberArgs, + context: ToolContext, + ) -> dict[str, Any]: + """Store in working memory.""" + if not args.key: + raise ValueError("Key is required for working memory") + if not context.session_id: + raise ValueError("Session ID is required for working memory") + + working = await self._get_working( + context.session_id, + context.project_id, + context.agent_instance_id, + ) + await working.set( + key=args.key, + value=args.content, + ttl_seconds=args.ttl_seconds, + ) + + return { + "stored": True, + "memory_type": "working", + "key": args.key, + "ttl_seconds": 
args.ttl_seconds, + } + + async def _remember_episodic( + self, + args: RememberArgs, + context: ToolContext, + ) -> dict[str, Any]: + """Store in episodic memory.""" + episodic = await self._get_episodic() + + episode_data = EpisodeCreate( + project_id=context.project_id, + agent_instance_id=context.agent_instance_id, + session_id=context.session_id or "unknown", + task_type=args.metadata.get("task_type", "manual_entry"), + task_description=args.content[:500], + outcome=Outcome.SUCCESS, # Default, can be updated later + outcome_details="Manual memory entry", + actions=[ + { + "type": "manual_entry", + "content": args.content, + "importance": args.importance, + "metadata": args.metadata, + } + ], + context_summary=str(args.metadata) if args.metadata else "", + duration_seconds=0.0, + tokens_used=0, + importance_score=args.importance, + ) + + episode = await episodic.record_episode(episode_data) + + return { + "stored": True, + "memory_type": "episodic", + "episode_id": str(episode.id), + "importance": args.importance, + } + + async def _remember_semantic( + self, + args: RememberArgs, + context: ToolContext, + ) -> dict[str, Any]: + """Store in semantic memory.""" + if not args.subject or not args.predicate or not args.object_value: + raise ValueError( + "Subject, predicate, and object_value are required for semantic memory" + ) + + semantic = await self._get_semantic() + + fact_data = FactCreate( + project_id=context.project_id, + subject=args.subject, + predicate=args.predicate, + object=args.object_value, + confidence=args.importance, + ) + + fact = await semantic.store_fact(fact_data) + + return { + "stored": True, + "memory_type": "semantic", + "fact_id": str(fact.id), + "triple": f"{args.subject} {args.predicate} {args.object_value}", + } + + async def _remember_procedural( + self, + args: RememberArgs, + context: ToolContext, + ) -> dict[str, Any]: + """Store in procedural memory.""" + if not args.trigger: + raise ValueError("Trigger is required for procedural memory") + if not args.steps: + raise ValueError("Steps are required for procedural memory") + + procedural = await self._get_procedural() + + procedure_data = ProcedureCreate( + project_id=context.project_id, + agent_type_id=context.agent_type_id, + name=args.content[:100], # Use content as name + trigger_pattern=args.trigger, + steps=args.steps, + ) + + procedure = await procedural.record_procedure(procedure_data) + + return { + "stored": True, + "memory_type": "procedural", + "procedure_id": str(procedure.id), + "trigger": args.trigger, + "steps_count": len(args.steps), + } + + async def _execute_recall( + self, + args: RecallArgs, + context: ToolContext, + ) -> dict[str, Any]: + """Execute the 'recall' tool - retrieve memories.""" + results: list[dict[str, Any]] = [] + + for memory_type in args.memory_types: + if memory_type == MemoryType.WORKING: + if context.session_id: + working = await self._get_working( + context.session_id, + context.project_id, + context.agent_instance_id, + ) + # Get all keys and filter by query + all_keys = await working.list_keys() + for key in all_keys: + if args.query.lower() in key.lower(): + value = await working.get(key) + if value is not None: + results.append({ + "type": "working", + "key": key, + "content": str(value), + "relevance": 1.0, + }) + + elif memory_type == MemoryType.EPISODIC: + episodic = await self._get_episodic() + episodes = await episodic.search_similar( + project_id=context.project_id, + query=args.query, + limit=args.limit, + agent_instance_id=context.agent_instance_id, 
+ ) + for episode in episodes: + results.append({ + "type": "episodic", + "id": str(episode.id), + "summary": episode.task_description, + "outcome": episode.outcome.value if episode.outcome else None, + "occurred_at": episode.occurred_at.isoformat(), + "relevance": episode.importance_score, + }) + + elif memory_type == MemoryType.SEMANTIC: + semantic = await self._get_semantic() + facts = await semantic.search_facts( + query=args.query, + project_id=context.project_id, + limit=args.limit, + min_confidence=args.min_relevance, + ) + for fact in facts: + results.append({ + "type": "semantic", + "id": str(fact.id), + "subject": fact.subject, + "predicate": fact.predicate, + "object": fact.object, + "confidence": fact.confidence, + "relevance": fact.confidence, + }) + + elif memory_type == MemoryType.PROCEDURAL: + procedural = await self._get_procedural() + procedures = await procedural.find_matching( + context=args.query, + project_id=context.project_id, + agent_type_id=context.agent_type_id, + limit=args.limit, + ) + for proc in procedures: + results.append({ + "type": "procedural", + "id": str(proc.id), + "name": proc.name, + "trigger": proc.trigger_pattern, + "success_rate": proc.success_rate, + "steps_count": len(proc.steps) if proc.steps else 0, + "relevance": proc.success_rate, + }) + + # Sort by relevance and limit + results.sort(key=lambda x: x.get("relevance", 0), reverse=True) + results = results[: args.limit] + + return { + "query": args.query, + "total_results": len(results), + "results": results, + } + + async def _execute_forget( + self, + args: ForgetArgs, + context: ToolContext, + ) -> dict[str, Any]: + """Execute the 'forget' tool - remove memories.""" + deleted_count = 0 + memory_type = args.memory_type + + if memory_type == MemoryType.WORKING: + if not context.session_id: + raise ValueError("Session ID required for working memory") + + working = await self._get_working( + context.session_id, + context.project_id, + context.agent_instance_id, + ) + + if args.key: + deleted = await working.delete(args.key) + deleted_count = 1 if deleted else 0 + elif args.pattern: + if not args.confirm_bulk: + raise ValueError("confirm_bulk must be True for pattern deletion") + # Get all keys matching pattern + all_keys = await working.list_keys() + import fnmatch + + for key in all_keys: + if fnmatch.fnmatch(key, args.pattern): + await working.delete(key) + deleted_count += 1 + + elif memory_type == MemoryType.EPISODIC: + if args.memory_id: + episodic = await self._get_episodic() + deleted = await episodic.delete(UUID(args.memory_id)) + deleted_count = 1 if deleted else 0 + else: + raise ValueError("memory_id required for episodic deletion") + + elif memory_type == MemoryType.SEMANTIC: + if args.memory_id: + semantic = await self._get_semantic() + deleted = await semantic.delete(UUID(args.memory_id)) + deleted_count = 1 if deleted else 0 + else: + raise ValueError("memory_id required for semantic deletion") + + elif memory_type == MemoryType.PROCEDURAL: + if args.memory_id: + procedural = await self._get_procedural() + deleted = await procedural.delete(UUID(args.memory_id)) + deleted_count = 1 if deleted else 0 + else: + raise ValueError("memory_id required for procedural deletion") + + return { + "deleted": deleted_count > 0, + "deleted_count": deleted_count, + "memory_type": memory_type.value, + } + + async def _execute_reflect( + self, + args: ReflectArgs, + context: ToolContext, + ) -> dict[str, Any]: + """Execute the 'reflect' tool - analyze patterns.""" + analysis_type = 
args.analysis_type + episodic = await self._get_episodic() + + # Get recent episodes for analysis + since = datetime.now(UTC) - timedelta(days=30) + recent_episodes = await episodic.get_recent( + project_id=context.project_id, + limit=args.max_items * 5, # Get more for analysis + since=since, + ) + + if analysis_type == AnalysisType.RECENT_PATTERNS: + return await self._analyze_recent_patterns(recent_episodes, args) + elif analysis_type == AnalysisType.SUCCESS_FACTORS: + return await self._analyze_success_factors(recent_episodes, args) + elif analysis_type == AnalysisType.FAILURE_PATTERNS: + return await self._analyze_failure_patterns(recent_episodes, args) + elif analysis_type == AnalysisType.COMMON_PROCEDURES: + procedural = await self._get_procedural() + return await self._analyze_common_procedures( + procedural, context.project_id, args + ) + elif analysis_type == AnalysisType.LEARNING_PROGRESS: + semantic = await self._get_semantic() + return await self._analyze_learning_progress( + semantic, context.project_id, args + ) + else: + raise ValueError(f"Unknown analysis type: {analysis_type}") + + async def _analyze_recent_patterns( + self, + episodes: list[Any], + args: ReflectArgs, + ) -> dict[str, Any]: + """Analyze patterns in recent episodes.""" + # Group by task type + task_types: dict[str, int] = {} + outcomes: dict[str, int] = {} + + for ep in episodes: + if ep.task_type: + task_types[ep.task_type] = task_types.get(ep.task_type, 0) + 1 + if ep.outcome: + outcome_val = ep.outcome.value if hasattr(ep.outcome, "value") else str(ep.outcome) + outcomes[outcome_val] = outcomes.get(outcome_val, 0) + 1 + + # Sort by frequency + top_tasks = sorted(task_types.items(), key=lambda x: x[1], reverse=True)[ + : args.max_items + ] + outcome_dist = dict(outcomes) + + examples = [] + if args.include_examples: + for ep in episodes[: min(3, args.max_items)]: + examples.append({ + "summary": ep.task_description, + "task_type": ep.task_type, + "outcome": ep.outcome.value if ep.outcome else None, + }) + + return { + "analysis_type": "recent_patterns", + "total_episodes": len(episodes), + "top_task_types": top_tasks, + "outcome_distribution": outcome_dist, + "examples": examples, + "insights": self._generate_pattern_insights(top_tasks, outcome_dist), + } + + async def _analyze_success_factors( + self, + episodes: list[Any], + args: ReflectArgs, + ) -> dict[str, Any]: + """Analyze what leads to success.""" + successful = [ep for ep in episodes if ep.outcome == Outcome.SUCCESS] + all_count = len(episodes) + success_count = len(successful) + + # Find common patterns in successful episodes + task_success: dict[str, dict[str, int]] = {} + for ep in episodes: + if ep.task_type: + if ep.task_type not in task_success: + task_success[ep.task_type] = {"success": 0, "total": 0} + task_success[ep.task_type]["total"] += 1 + if ep.outcome == Outcome.SUCCESS: + task_success[ep.task_type]["success"] += 1 + + # Calculate success rates + success_rates = { + task: data["success"] / data["total"] + for task, data in task_success.items() + if data["total"] >= 2 + } + top_success = sorted(success_rates.items(), key=lambda x: x[1], reverse=True)[ + : args.max_items + ] + + examples = [] + if args.include_examples: + for ep in successful[: min(3, args.max_items)]: + examples.append({ + "summary": ep.task_description, + "task_type": ep.task_type, + "lessons": ep.lessons_learned, + }) + + return { + "analysis_type": "success_factors", + "overall_success_rate": success_count / all_count if all_count > 0 else 0, + 
"tasks_by_success_rate": top_success, + "successful_examples": examples, + "recommendations": self._generate_success_recommendations( + top_success, success_count, all_count + ), + } + + async def _analyze_failure_patterns( + self, + episodes: list[Any], + args: ReflectArgs, + ) -> dict[str, Any]: + """Analyze what causes failures.""" + failed = [ep for ep in episodes if ep.outcome == Outcome.FAILURE] + + # Group failures by task type + failure_by_task: dict[str, list[Any]] = {} + for ep in failed: + task = ep.task_type or "unknown" + if task not in failure_by_task: + failure_by_task[task] = [] + failure_by_task[task].append(ep) + + # Most common failure types + failure_counts = { + task: len(eps) for task, eps in failure_by_task.items() + } + top_failures = sorted(failure_counts.items(), key=lambda x: x[1], reverse=True)[ + : args.max_items + ] + + examples = [] + if args.include_examples: + for ep in failed[: min(3, args.max_items)]: + examples.append({ + "summary": ep.task_description, + "task_type": ep.task_type, + "lessons": ep.lessons_learned, + "error": ep.outcome_details, + }) + + return { + "analysis_type": "failure_patterns", + "total_failures": len(failed), + "failures_by_task": top_failures, + "failure_examples": examples, + "prevention_tips": self._generate_failure_prevention(top_failures), + } + + async def _analyze_common_procedures( + self, + procedural: ProceduralMemory, + project_id: UUID, + args: ReflectArgs, + ) -> dict[str, Any]: + """Analyze most commonly used procedures.""" + # Get procedures by matching against empty context for broad search + # This gets procedures sorted by success rate (see find_matching implementation) + procedures = await procedural.find_matching( + context="", # Empty context gets all procedures + project_id=project_id, + limit=args.max_items * 3, + ) + + # Sort by execution count (success_count + failure_count) + sorted_procs = sorted( + procedures, + key=lambda p: p.success_count + p.failure_count, + reverse=True, + )[: args.max_items] + + return { + "analysis_type": "common_procedures", + "total_procedures": len(procedures), + "top_procedures": [ + { + "name": p.name, + "trigger": p.trigger_pattern, + "execution_count": p.success_count + p.failure_count, + "success_rate": p.success_rate, + } + for p in sorted_procs + ], + } + + async def _analyze_learning_progress( + self, + semantic: SemanticMemory, + project_id: UUID, + args: ReflectArgs, + ) -> dict[str, Any]: + """Analyze learning progress over time.""" + # Get facts by searching with empty query (gets recent facts ordered by confidence) + facts = await semantic.search_facts( + query="", + project_id=project_id, + limit=args.max_items * 10, + ) + + # Group by subject for topic analysis + subjects: dict[str, int] = {} + for fact in facts: + subjects[fact.subject] = subjects.get(fact.subject, 0) + 1 + + top_subjects = sorted(subjects.items(), key=lambda x: x[1], reverse=True)[ + : args.max_items + ] + + return { + "analysis_type": "learning_progress", + "total_facts_learned": len(facts), + "top_knowledge_areas": top_subjects, + "knowledge_breadth": len(subjects), + } + + def _generate_pattern_insights( + self, + top_tasks: list[tuple[str, int]], + outcome_dist: dict[str, int], + ) -> list[str]: + """Generate insights from patterns.""" + insights = [] + + if top_tasks: + insights.append(f"Most common task type: {top_tasks[0][0]} ({top_tasks[0][1]} occurrences)") + + total = sum(outcome_dist.values()) + if total > 0: + success_rate = outcome_dist.get("success", 0) / total + if success_rate 
> 0.8: + insights.append("High success rate observed - current approach is working well") + elif success_rate < 0.5: + insights.append("Success rate below 50% - consider reviewing procedures") + + return insights + + def _generate_success_recommendations( + self, + top_success: list[tuple[str, float]], + success_count: int, + total: int, + ) -> list[str]: + """Generate recommendations based on success analysis.""" + recommendations = [] + + if top_success: + best_task, rate = top_success[0] + recommendations.append( + f"'{best_task}' has highest success rate ({rate:.0%}) - analyze why" + ) + + if total > 0: + overall = success_count / total + if overall < 0.7: + recommendations.append( + "Consider breaking complex tasks into smaller steps" + ) + + return recommendations + + def _generate_failure_prevention( + self, + top_failures: list[tuple[str, int]], + ) -> list[str]: + """Generate tips for preventing failures.""" + tips = [] + + if top_failures: + worst_task, count = top_failures[0] + tips.append(f"'{worst_task}' has most failures ({count}) - needs procedure review") + + tips.append("Review lessons_learned from past failures before attempting similar tasks") + + return tips + + async def _execute_get_memory_stats( + self, + args: GetMemoryStatsArgs, + context: ToolContext, + ) -> dict[str, Any]: + """Execute the 'get_memory_stats' tool.""" + stats: dict[str, Any] = {} + + if args.include_breakdown: + # Get counts from each memory type + episodic = await self._get_episodic() + semantic = await self._get_semantic() + procedural = await self._get_procedural() + + # Count episodes (get recent with high limit) + episodes = await episodic.get_recent( + project_id=context.project_id, + limit=10000, # Just counting + ) + + # Count facts (search with empty query) + facts = await semantic.search_facts( + query="", + project_id=context.project_id, + limit=10000, + ) + + # Count procedures (find_matching with empty context) + procedures = await procedural.find_matching( + context="", + project_id=context.project_id, + limit=10000, + ) + + # Working memory (session-specific) + working_count = 0 + if context.session_id: + working = await self._get_working( + context.session_id, + context.project_id, + context.agent_instance_id, + ) + keys = await working.list_keys() + working_count = len(keys) + + stats["breakdown"] = { + "working": working_count, + "episodic": len(episodes), + "semantic": len(facts), + "procedural": len(procedures), + "total": working_count + len(episodes) + len(facts) + len(procedures), + } + + if args.include_recent_activity: + since = datetime.now(UTC) - timedelta(days=args.time_range_days) + episodic = await self._get_episodic() + + recent_episodes = await episodic.get_recent( + project_id=context.project_id, + limit=1000, + since=since, + ) + + # Calculate activity metrics + outcomes = {"success": 0, "failure": 0, "partial": 0, "abandoned": 0} + for ep in recent_episodes: + if ep.outcome: + key = ep.outcome.value if hasattr(ep.outcome, "value") else str(ep.outcome) + if key in outcomes: + outcomes[key] += 1 + + stats["recent_activity"] = { + "time_range_days": args.time_range_days, + "episodes_created": len(recent_episodes), + "outcomes": outcomes, + } + + return stats + + async def _execute_search_procedures( + self, + args: SearchProceduresArgs, + context: ToolContext, + ) -> dict[str, Any]: + """Execute the 'search_procedures' tool.""" + procedural = await self._get_procedural() + + # Use find_matching with trigger as context + all_procedures = await 
procedural.find_matching( + context=args.trigger, + project_id=context.project_id, + agent_type_id=context.agent_type_id, + limit=args.limit * 3, # Get more to filter by success rate + ) + + # Filter by minimum success rate if specified + procedures = [ + p for p in all_procedures + if args.min_success_rate is None or p.success_rate >= args.min_success_rate + ][: args.limit] + + results = [] + for proc in procedures: + proc_data: dict[str, Any] = { + "id": str(proc.id), + "name": proc.name, + "trigger": proc.trigger_pattern, + "success_rate": proc.success_rate, + "execution_count": proc.success_count + proc.failure_count, + } + + if args.include_steps: + proc_data["steps"] = proc.steps + + results.append(proc_data) + + return { + "trigger": args.trigger, + "procedures_found": len(results), + "procedures": results, + } + + async def _execute_record_outcome( + self, + args: RecordOutcomeArgs, + context: ToolContext, + ) -> dict[str, Any]: + """Execute the 'record_outcome' tool.""" + # Map outcome type to memory Outcome + # Note: ABANDONED maps to FAILURE since core Outcome doesn't have ABANDONED + outcome_map = { + OutcomeType.SUCCESS: Outcome.SUCCESS, + OutcomeType.PARTIAL: Outcome.PARTIAL, + OutcomeType.FAILURE: Outcome.FAILURE, + OutcomeType.ABANDONED: Outcome.FAILURE, # No ABANDONED in core enum + } + outcome = outcome_map.get(args.outcome, Outcome.FAILURE) + + # Record in episodic memory + episodic = await self._get_episodic() + + episode_data = EpisodeCreate( + project_id=context.project_id, + agent_instance_id=context.agent_instance_id, + session_id=context.session_id or "unknown", + task_type=args.task_type, + task_description=f"Task outcome: {args.outcome.value}", + outcome=outcome, + outcome_details=args.error_details or "No additional details", + actions=[ + { + "type": "outcome_record", + "outcome": args.outcome.value, + "context": args.context, + } + ], + context_summary=str(args.context) if args.context else "", + lessons_learned=[args.lessons_learned] if args.lessons_learned else [], + duration_seconds=args.duration_seconds or 0.0, + tokens_used=0, + ) + + episode = await episodic.record_episode(episode_data) + + # Update procedure success rate if procedure_id provided + procedure_updated = False + if args.procedure_id: + procedural = await self._get_procedural() + try: + await procedural.record_outcome( + procedure_id=UUID(args.procedure_id), + success=args.outcome == OutcomeType.SUCCESS, + ) + procedure_updated = True + except Exception as e: + logger.warning(f"Failed to update procedure outcome: {e}") + + return { + "recorded": True, + "episode_id": str(episode.id), + "outcome": args.outcome.value, + "procedure_updated": procedure_updated, + } + + +# Factory function for dependency injection +async def get_memory_tool_service( + session: AsyncSession, + embedding_generator: Any | None = None, +) -> MemoryToolService: + """Create a memory tool service instance.""" + return MemoryToolService( + session=session, + embedding_generator=embedding_generator, + ) diff --git a/backend/app/services/memory/mcp/tools.py b/backend/app/services/memory/mcp/tools.py new file mode 100644 index 0000000..4db6a5e --- /dev/null +++ b/backend/app/services/memory/mcp/tools.py @@ -0,0 +1,491 @@ +# app/services/memory/mcp/tools.py +""" +MCP Tool Definitions for Agent Memory System. + +Defines the schema and metadata for memory-related MCP tools. +These tools are invoked by AI agents to interact with the memory system. 
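+
+A quick sketch of listing the schemas defined below:
+
+    for schema in get_all_tool_schemas():
+        print(schema["name"])  # "remember", "recall", "forget", ...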
+""" + +from dataclasses import dataclass, field +from enum import Enum +from typing import Any + +from pydantic import BaseModel, Field + + +class MemoryType(str, Enum): + """Types of memory for storage operations.""" + + WORKING = "working" + EPISODIC = "episodic" + SEMANTIC = "semantic" + PROCEDURAL = "procedural" + + +class AnalysisType(str, Enum): + """Types of pattern analysis for the reflect tool.""" + + RECENT_PATTERNS = "recent_patterns" + SUCCESS_FACTORS = "success_factors" + FAILURE_PATTERNS = "failure_patterns" + COMMON_PROCEDURES = "common_procedures" + LEARNING_PROGRESS = "learning_progress" + + +class OutcomeType(str, Enum): + """Outcome types for record_outcome tool.""" + + SUCCESS = "success" + PARTIAL = "partial" + FAILURE = "failure" + ABANDONED = "abandoned" + + +# ============================================================================ +# Tool Argument Schemas (Pydantic models for validation) +# ============================================================================ + + +class RememberArgs(BaseModel): + """Arguments for the 'remember' tool.""" + + memory_type: MemoryType = Field( + ..., + description="Type of memory to store in: working, episodic, semantic, or procedural", + ) + content: str = Field( + ..., + description="The content to remember. Can be text, facts, or procedure steps.", + min_length=1, + max_length=10000, + ) + key: str | None = Field( + None, + description="Optional key for working memory entries. Required for working memory type.", + max_length=256, + ) + importance: float = Field( + 0.5, + description="Importance score from 0.0 (low) to 1.0 (critical)", + ge=0.0, + le=1.0, + ) + ttl_seconds: int | None = Field( + None, + description="Time-to-live in seconds for working memory. None for permanent storage.", + ge=1, + le=86400 * 30, # Max 30 days + ) + metadata: dict[str, Any] = Field( + default_factory=dict, + description="Additional metadata to store with the memory", + ) + # For semantic memory (facts) + subject: str | None = Field( + None, + description="Subject of the fact (for semantic memory)", + max_length=256, + ) + predicate: str | None = Field( + None, + description="Predicate/relationship (for semantic memory)", + max_length=256, + ) + object_value: str | None = Field( + None, + description="Object of the fact (for semantic memory)", + max_length=1000, + ) + # For procedural memory + trigger: str | None = Field( + None, + description="Trigger condition for the procedure (for procedural memory)", + max_length=500, + ) + steps: list[dict[str, Any]] | None = Field( + None, + description="Procedure steps as a list of action dictionaries", + ) + + +class RecallArgs(BaseModel): + """Arguments for the 'recall' tool.""" + + query: str = Field( + ..., + description="Search query to find relevant memories", + min_length=1, + max_length=1000, + ) + memory_types: list[MemoryType] = Field( + default_factory=lambda: [MemoryType.EPISODIC, MemoryType.SEMANTIC], + description="Types of memory to search in", + ) + limit: int = Field( + 10, + description="Maximum number of results to return", + ge=1, + le=100, + ) + min_relevance: float = Field( + 0.0, + description="Minimum relevance score (0.0-1.0) for results", + ge=0.0, + le=1.0, + ) + filters: dict[str, Any] = Field( + default_factory=dict, + description="Additional filters (e.g., outcome, task_type, date range)", + ) + include_context: bool = Field( + True, + description="Whether to include surrounding context in results", + ) + + +class ForgetArgs(BaseModel): + """Arguments for the 'forget' 
tool.""" + + memory_type: MemoryType = Field( + ..., + description="Type of memory to remove from", + ) + key: str | None = Field( + None, + description="Key to remove (for working memory)", + max_length=256, + ) + memory_id: str | None = Field( + None, + description="Specific memory ID to remove (for episodic/semantic/procedural)", + ) + pattern: str | None = Field( + None, + description="Pattern to match for bulk removal (use with caution)", + max_length=500, + ) + confirm_bulk: bool = Field( + False, + description="Must be True to confirm bulk deletion when using pattern", + ) + + +class ReflectArgs(BaseModel): + """Arguments for the 'reflect' tool.""" + + analysis_type: AnalysisType = Field( + ..., + description="Type of pattern analysis to perform", + ) + scope: str | None = Field( + None, + description="Optional scope to limit analysis (e.g., task_type, time range)", + max_length=500, + ) + depth: int = Field( + 3, + description="Depth of analysis (1=surface, 5=deep)", + ge=1, + le=5, + ) + include_examples: bool = Field( + True, + description="Whether to include example memories in the analysis", + ) + max_items: int = Field( + 10, + description="Maximum number of patterns/examples to analyze", + ge=1, + le=50, + ) + + +class GetMemoryStatsArgs(BaseModel): + """Arguments for the 'get_memory_stats' tool.""" + + include_breakdown: bool = Field( + True, + description="Include breakdown by memory type", + ) + include_recent_activity: bool = Field( + True, + description="Include recent memory activity summary", + ) + time_range_days: int = Field( + 7, + description="Time range for activity analysis in days", + ge=1, + le=90, + ) + + +class SearchProceduresArgs(BaseModel): + """Arguments for the 'search_procedures' tool.""" + + trigger: str = Field( + ..., + description="Trigger or situation to find procedures for", + min_length=1, + max_length=500, + ) + task_type: str | None = Field( + None, + description="Optional task type to filter procedures", + max_length=100, + ) + min_success_rate: float = Field( + 0.5, + description="Minimum success rate (0.0-1.0) for returned procedures", + ge=0.0, + le=1.0, + ) + limit: int = Field( + 5, + description="Maximum number of procedures to return", + ge=1, + le=20, + ) + include_steps: bool = Field( + True, + description="Whether to include detailed steps in the response", + ) + + +class RecordOutcomeArgs(BaseModel): + """Arguments for the 'record_outcome' tool.""" + + task_type: str = Field( + ..., + description="Type of task that was executed", + min_length=1, + max_length=100, + ) + outcome: OutcomeType = Field( + ..., + description="Outcome of the task execution", + ) + procedure_id: str | None = Field( + None, + description="ID of the procedure that was followed (if any)", + ) + context: dict[str, Any] = Field( + default_factory=dict, + description="Context in which the task was executed", + ) + lessons_learned: str | None = Field( + None, + description="What was learned from this execution", + max_length=2000, + ) + duration_seconds: float | None = Field( + None, + description="How long the task took to execute", + ge=0.0, + ) + error_details: str | None = Field( + None, + description="Details about any errors encountered (for failures)", + max_length=2000, + ) + + +# ============================================================================ +# Tool Definition Structure +# ============================================================================ + + +@dataclass +class MemoryToolDefinition: + """Definition of an MCP tool for the memory 
system.""" + + name: str + description: str + args_schema: type[BaseModel] + input_schema: dict[str, Any] = field(default_factory=dict) + + def __post_init__(self) -> None: + """Generate input schema from Pydantic model.""" + if not self.input_schema: + self.input_schema = self.args_schema.model_json_schema() + + def to_mcp_format(self) -> dict[str, Any]: + """Convert to MCP tool format.""" + return { + "name": self.name, + "description": self.description, + "inputSchema": self.input_schema, + } + + def validate_args(self, args: dict[str, Any]) -> BaseModel: + """Validate and parse arguments.""" + return self.args_schema.model_validate(args) + + +# ============================================================================ +# Tool Definitions +# ============================================================================ + + +REMEMBER_TOOL = MemoryToolDefinition( + name="remember", + description="""Store information in the agent's memory system. + +Use this tool to: +- Store temporary data in working memory (key-value with optional TTL) +- Record important events in episodic memory (automatically done on session end) +- Store facts/knowledge in semantic memory (subject-predicate-object triples) +- Save procedures in procedural memory (trigger conditions and steps) + +Examples: +- Working memory: {"memory_type": "working", "key": "current_task", "content": "Implementing auth", "ttl_seconds": 3600} +- Semantic fact: {"memory_type": "semantic", "subject": "User", "predicate": "prefers", "object_value": "dark mode", "content": "User preference noted"} +- Procedure: {"memory_type": "procedural", "trigger": "When creating a new file", "steps": [{"action": "check_exists"}, {"action": "create"}], "content": "File creation procedure"} +""", + args_schema=RememberArgs, +) + + +RECALL_TOOL = MemoryToolDefinition( + name="recall", + description="""Retrieve information from the agent's memory system. + +Use this tool to: +- Search for relevant past experiences (episodic) +- Look up known facts and knowledge (semantic) +- Find applicable procedures for current task (procedural) +- Get current session state (working) + +The query supports semantic search - describe what you're looking for in natural language. + +Examples: +- {"query": "How did I handle authentication errors before?", "memory_types": ["episodic"]} +- {"query": "What are the user's preferences?", "memory_types": ["semantic"], "limit": 5} +- {"query": "database connection", "memory_types": ["episodic", "semantic", "procedural"], "filters": {"outcome": "success"}} +""", + args_schema=RecallArgs, +) + + +FORGET_TOOL = MemoryToolDefinition( + name="forget", + description="""Remove information from the agent's memory system. + +Use this tool to: +- Clear temporary working memory entries +- Remove specific memories by ID +- Bulk remove memories matching a pattern (requires confirmation) + +WARNING: Deletion is permanent. Use with caution. + +Examples: +- Working memory: {"memory_type": "working", "key": "temp_calculation"} +- Specific memory: {"memory_type": "episodic", "memory_id": "ep-123"} +- Bulk (requires confirm): {"memory_type": "working", "pattern": "cache_*", "confirm_bulk": true} +""", + args_schema=ForgetArgs, +) + + +REFLECT_TOOL = MemoryToolDefinition( + name="reflect", + description="""Analyze patterns in the agent's memory to gain insights. 
+ +Use this tool to: +- Identify patterns in recent work +- Understand what leads to success/failure +- Learn from past experiences +- Track learning progress over time + +Analysis types: +- recent_patterns: What patterns appear in recent work +- success_factors: What conditions lead to success +- failure_patterns: What causes failures and how to avoid them +- common_procedures: Most frequently used procedures +- learning_progress: How knowledge has grown over time + +Examples: +- {"analysis_type": "success_factors", "scope": "code_review", "depth": 3} +- {"analysis_type": "failure_patterns", "include_examples": true, "max_items": 5} +""", + args_schema=ReflectArgs, +) + + +GET_MEMORY_STATS_TOOL = MemoryToolDefinition( + name="get_memory_stats", + description="""Get statistics about the agent's memory usage. + +Returns information about: +- Total memories stored by type +- Storage utilization +- Recent activity summary +- Memory health indicators + +Use this to understand memory capacity and usage patterns. + +Examples: +- {"include_breakdown": true, "include_recent_activity": true} +- {"time_range_days": 30, "include_breakdown": true} +""", + args_schema=GetMemoryStatsArgs, +) + + +SEARCH_PROCEDURES_TOOL = MemoryToolDefinition( + name="search_procedures", + description="""Find relevant procedures for a given situation. + +Use this tool when you need to: +- Find the best way to handle a situation +- Look up proven approaches to problems +- Get step-by-step guidance for tasks + +Returns procedures ranked by relevance and success rate. + +Examples: +- {"trigger": "Deploying to production", "min_success_rate": 0.8} +- {"trigger": "Handling merge conflicts", "task_type": "git_operations", "limit": 3} +""", + args_schema=SearchProceduresArgs, +) + + +RECORD_OUTCOME_TOOL = MemoryToolDefinition( + name="record_outcome", + description="""Record the outcome of a task execution. + +Use this tool after completing a task to: +- Update procedure success/failure rates +- Store lessons learned for future reference +- Improve procedure recommendations + +This helps the memory system learn from experience. 
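+Pair it with search_procedures: pass the procedure_id of the procedure you
+followed so its success rate is updated along with the recorded episode.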
+ +Examples: +- {"task_type": "code_review", "outcome": "success", "lessons_learned": "Breaking changes caught early"} +- {"task_type": "deployment", "outcome": "failure", "error_details": "Database migration timeout", "lessons_learned": "Need to test migrations locally first"} +""", + args_schema=RecordOutcomeArgs, +) + + +# All tool definitions in a dictionary for easy lookup +MEMORY_TOOL_DEFINITIONS: dict[str, MemoryToolDefinition] = { + "remember": REMEMBER_TOOL, + "recall": RECALL_TOOL, + "forget": FORGET_TOOL, + "reflect": REFLECT_TOOL, + "get_memory_stats": GET_MEMORY_STATS_TOOL, + "search_procedures": SEARCH_PROCEDURES_TOOL, + "record_outcome": RECORD_OUTCOME_TOOL, +} + + +def get_all_tool_schemas() -> list[dict[str, Any]]: + """Get MCP-formatted schemas for all memory tools.""" + return [tool.to_mcp_format() for tool in MEMORY_TOOL_DEFINITIONS.values()] + + +def get_tool_definition(name: str) -> MemoryToolDefinition | None: + """Get a specific tool definition by name.""" + return MEMORY_TOOL_DEFINITIONS.get(name) diff --git a/backend/tests/unit/services/memory/mcp/__init__.py b/backend/tests/unit/services/memory/mcp/__init__.py new file mode 100644 index 0000000..0d3059e --- /dev/null +++ b/backend/tests/unit/services/memory/mcp/__init__.py @@ -0,0 +1,2 @@ +# tests/unit/services/memory/mcp/__init__.py +"""Tests for memory MCP tools.""" diff --git a/backend/tests/unit/services/memory/mcp/test_service.py b/backend/tests/unit/services/memory/mcp/test_service.py new file mode 100644 index 0000000..f403262 --- /dev/null +++ b/backend/tests/unit/services/memory/mcp/test_service.py @@ -0,0 +1,651 @@ +# tests/unit/services/memory/mcp/test_service.py +"""Tests for MemoryToolService.""" + +from datetime import UTC, datetime +from typing import Any +from unittest.mock import AsyncMock, MagicMock, patch +from uuid import UUID, uuid4 + +import pytest + +from app.services.memory.mcp.service import ( + MemoryToolService, + ToolContext, + ToolResult, + get_memory_tool_service, +) +from app.services.memory.mcp.tools import ( + AnalysisType, + MemoryType, + OutcomeType, +) +from app.services.memory.types import Outcome + +pytestmark = pytest.mark.asyncio(loop_scope="function") + + +def make_context( + project_id: UUID | None = None, + agent_instance_id: UUID | None = None, + session_id: str | None = None, +) -> ToolContext: + """Create a test context.""" + return ToolContext( + project_id=project_id or uuid4(), + agent_instance_id=agent_instance_id or uuid4(), + session_id=session_id or "test-session", + ) + + +def make_mock_session() -> AsyncMock: + """Create a mock database session.""" + session = AsyncMock() + session.execute = AsyncMock() + session.commit = AsyncMock() + session.flush = AsyncMock() + return session + + +class TestToolContext: + """Tests for ToolContext dataclass.""" + + def test_context_creation(self) -> None: + """Context should be creatable with required fields.""" + project_id = uuid4() + ctx = ToolContext(project_id=project_id) + assert ctx.project_id == project_id + assert ctx.agent_instance_id is None + assert ctx.session_id is None + + def test_context_with_all_fields(self) -> None: + """Context should accept all optional fields.""" + project_id = uuid4() + agent_id = uuid4() + ctx = ToolContext( + project_id=project_id, + agent_instance_id=agent_id, + agent_type_id=uuid4(), + session_id="session-123", + user_id=uuid4(), + ) + assert ctx.project_id == project_id + assert ctx.agent_instance_id == agent_id + assert ctx.session_id == "session-123" + + +class TestToolResult: 
+ """Tests for ToolResult dataclass.""" + + def test_success_result(self) -> None: + """Success result should have correct fields.""" + result = ToolResult( + success=True, + data={"key": "value"}, + execution_time_ms=10.5, + ) + assert result.success is True + assert result.data == {"key": "value"} + assert result.error is None + + def test_error_result(self) -> None: + """Error result should have correct fields.""" + result = ToolResult( + success=False, + error="Something went wrong", + error_code="VALIDATION_ERROR", + ) + assert result.success is False + assert result.error == "Something went wrong" + assert result.error_code == "VALIDATION_ERROR" + + def test_to_dict(self) -> None: + """Result should convert to dict correctly.""" + result = ToolResult( + success=True, + data={"test": 1}, + execution_time_ms=5.0, + ) + result_dict = result.to_dict() + assert result_dict["success"] is True + assert result_dict["data"] == {"test": 1} + assert result_dict["execution_time_ms"] == 5.0 + + +class TestMemoryToolService: + """Tests for MemoryToolService.""" + + @pytest.fixture + def mock_session(self) -> AsyncMock: + """Create a mock session.""" + return make_mock_session() + + @pytest.fixture + def service(self, mock_session: AsyncMock) -> MemoryToolService: + """Create a service with mock session.""" + return MemoryToolService(session=mock_session) + + @pytest.fixture + def context(self) -> ToolContext: + """Create a test context.""" + return make_context() + + async def test_execute_unknown_tool( + self, + service: MemoryToolService, + context: ToolContext, + ) -> None: + """Unknown tool should return error.""" + result = await service.execute_tool( + tool_name="unknown_tool", + arguments={}, + context=context, + ) + assert result.success is False + assert result.error_code == "UNKNOWN_TOOL" + + async def test_execute_with_invalid_args( + self, + service: MemoryToolService, + context: ToolContext, + ) -> None: + """Invalid arguments should return validation error.""" + result = await service.execute_tool( + tool_name="remember", + arguments={"memory_type": "invalid_type"}, + context=context, + ) + assert result.success is False + assert result.error_code == "VALIDATION_ERROR" + + @patch("app.services.memory.mcp.service.WorkingMemory") + async def test_remember_working_memory( + self, + mock_working_cls: MagicMock, + service: MemoryToolService, + context: ToolContext, + ) -> None: + """Remember should store in working memory.""" + # Setup mock + mock_working = AsyncMock() + mock_working.set = AsyncMock() + mock_working_cls.for_session = AsyncMock(return_value=mock_working) + + result = await service.execute_tool( + tool_name="remember", + arguments={ + "memory_type": "working", + "content": "Test content", + "key": "test_key", + "ttl_seconds": 3600, + }, + context=context, + ) + + assert result.success is True + assert result.data["stored"] is True + assert result.data["memory_type"] == "working" + assert result.data["key"] == "test_key" + + async def test_remember_episodic_memory( + self, + service: MemoryToolService, + context: ToolContext, + ) -> None: + """Remember should store in episodic memory.""" + with patch("app.services.memory.mcp.service.EpisodicMemory") as mock_episodic_cls: + # Setup mock + mock_episode = MagicMock() + mock_episode.id = uuid4() + + mock_episodic = AsyncMock() + mock_episodic.record_episode = AsyncMock(return_value=mock_episode) + mock_episodic_cls.create = AsyncMock(return_value=mock_episodic) + + result = await service.execute_tool( + tool_name="remember", + 
arguments={ + "memory_type": "episodic", + "content": "Important event happened", + "importance": 0.8, + }, + context=context, + ) + + assert result.success is True + assert result.data["stored"] is True + assert result.data["memory_type"] == "episodic" + assert "episode_id" in result.data + + async def test_remember_working_without_key( + self, + service: MemoryToolService, + context: ToolContext, + ) -> None: + """Working memory without key should fail.""" + result = await service.execute_tool( + tool_name="remember", + arguments={ + "memory_type": "working", + "content": "Test content", + }, + context=context, + ) + + assert result.success is False + assert "key is required" in result.error.lower() + + async def test_remember_working_without_session( + self, + service: MemoryToolService, + ) -> None: + """Working memory without session should fail.""" + context = ToolContext(project_id=uuid4(), session_id=None) + + result = await service.execute_tool( + tool_name="remember", + arguments={ + "memory_type": "working", + "content": "Test content", + "key": "test_key", + }, + context=context, + ) + + assert result.success is False + assert "session id is required" in result.error.lower() + + async def test_remember_semantic_memory( + self, + service: MemoryToolService, + context: ToolContext, + ) -> None: + """Remember should store facts in semantic memory.""" + with patch("app.services.memory.mcp.service.SemanticMemory") as mock_semantic_cls: + mock_fact = MagicMock() + mock_fact.id = uuid4() + + mock_semantic = AsyncMock() + mock_semantic.store_fact = AsyncMock(return_value=mock_fact) + mock_semantic_cls.create = AsyncMock(return_value=mock_semantic) + + result = await service.execute_tool( + tool_name="remember", + arguments={ + "memory_type": "semantic", + "content": "User prefers dark mode", + "subject": "User", + "predicate": "prefers", + "object_value": "dark mode", + }, + context=context, + ) + + assert result.success is True + assert result.data["memory_type"] == "semantic" + assert "fact_id" in result.data + assert "triple" in result.data + + async def test_remember_semantic_without_fields( + self, + service: MemoryToolService, + context: ToolContext, + ) -> None: + """Semantic memory without subject/predicate/object should fail.""" + result = await service.execute_tool( + tool_name="remember", + arguments={ + "memory_type": "semantic", + "content": "Some content", + "subject": "User", + # Missing predicate and object_value + }, + context=context, + ) + + assert result.success is False + assert "required" in result.error.lower() + + async def test_remember_procedural_memory( + self, + service: MemoryToolService, + context: ToolContext, + ) -> None: + """Remember should store procedures in procedural memory.""" + with patch("app.services.memory.mcp.service.ProceduralMemory") as mock_procedural_cls: + mock_procedure = MagicMock() + mock_procedure.id = uuid4() + + mock_procedural = AsyncMock() + mock_procedural.record_procedure = AsyncMock(return_value=mock_procedure) + mock_procedural_cls.create = AsyncMock(return_value=mock_procedural) + + result = await service.execute_tool( + tool_name="remember", + arguments={ + "memory_type": "procedural", + "content": "File creation procedure", + "trigger": "When creating a new file", + "steps": [ + {"action": "check_exists"}, + {"action": "create"}, + ], + }, + context=context, + ) + + assert result.success is True + assert result.data["memory_type"] == "procedural" + assert "procedure_id" in result.data + assert result.data["steps_count"] == 
2 + + @patch("app.services.memory.mcp.service.EpisodicMemory") + @patch("app.services.memory.mcp.service.SemanticMemory") + async def test_recall_from_multiple_types( + self, + mock_semantic_cls: MagicMock, + mock_episodic_cls: MagicMock, + service: MemoryToolService, + context: ToolContext, + ) -> None: + """Recall should search across multiple memory types.""" + # Mock episodic + mock_episode = MagicMock() + mock_episode.id = uuid4() + mock_episode.task_description = "Test episode" + mock_episode.outcome = Outcome.SUCCESS + mock_episode.occurred_at = datetime.now(UTC) + mock_episode.importance_score = 0.9 + + mock_episodic = AsyncMock() + mock_episodic.search_similar = AsyncMock(return_value=[mock_episode]) + mock_episodic_cls.create = AsyncMock(return_value=mock_episodic) + + # Mock semantic + mock_fact = MagicMock() + mock_fact.id = uuid4() + mock_fact.subject = "User" + mock_fact.predicate = "prefers" + mock_fact.object = "dark mode" + mock_fact.confidence = 0.8 + + mock_semantic = AsyncMock() + mock_semantic.search_facts = AsyncMock(return_value=[mock_fact]) + mock_semantic_cls.create = AsyncMock(return_value=mock_semantic) + + result = await service.execute_tool( + tool_name="recall", + arguments={ + "query": "user preferences", + "memory_types": ["episodic", "semantic"], + "limit": 10, + }, + context=context, + ) + + assert result.success is True + assert result.data["total_results"] == 2 + assert len(result.data["results"]) == 2 + + @patch("app.services.memory.mcp.service.WorkingMemory") + async def test_forget_working_memory( + self, + mock_working_cls: MagicMock, + service: MemoryToolService, + context: ToolContext, + ) -> None: + """Forget should delete from working memory.""" + mock_working = AsyncMock() + mock_working.delete = AsyncMock(return_value=True) + mock_working_cls.for_session = AsyncMock(return_value=mock_working) + + result = await service.execute_tool( + tool_name="forget", + arguments={ + "memory_type": "working", + "key": "temp_key", + }, + context=context, + ) + + assert result.success is True + assert result.data["deleted"] is True + assert result.data["deleted_count"] == 1 + + async def test_forget_pattern_requires_confirm( + self, + service: MemoryToolService, + context: ToolContext, + ) -> None: + """Pattern deletion should require confirmation.""" + with patch("app.services.memory.mcp.service.WorkingMemory") as mock_working_cls: + mock_working = AsyncMock() + mock_working.list_keys = AsyncMock(return_value=["cache_1", "cache_2"]) + mock_working_cls.for_session = AsyncMock(return_value=mock_working) + + result = await service.execute_tool( + tool_name="forget", + arguments={ + "memory_type": "working", + "pattern": "cache_*", + "confirm_bulk": False, + }, + context=context, + ) + + assert result.success is False + assert "confirm_bulk" in result.error.lower() + + @patch("app.services.memory.mcp.service.EpisodicMemory") + async def test_reflect_recent_patterns( + self, + mock_episodic_cls: MagicMock, + service: MemoryToolService, + context: ToolContext, + ) -> None: + """Reflect should analyze recent patterns.""" + # Create mock episodes + mock_episodes = [] + for i in range(5): + ep = MagicMock() + ep.id = uuid4() + ep.task_type = "code_review" if i % 2 == 0 else "deployment" + ep.outcome = Outcome.SUCCESS if i < 3 else Outcome.FAILURE + ep.task_description = f"Episode {i}" + ep.lessons_learned = None + ep.occurred_at = datetime.now(UTC) + mock_episodes.append(ep) + + mock_episodic = AsyncMock() + mock_episodic.get_recent = 
+
+    @patch("app.services.memory.mcp.service.EpisodicMemory")
+    async def test_reflect_recent_patterns(
+        self,
+        mock_episodic_cls: MagicMock,
+        service: MemoryToolService,
+        context: ToolContext,
+    ) -> None:
+        """Reflect should analyze recent patterns."""
+        # Create mock episodes
+        mock_episodes = []
+        for i in range(5):
+            ep = MagicMock()
+            ep.id = uuid4()
+            ep.task_type = "code_review" if i % 2 == 0 else "deployment"
+            ep.outcome = Outcome.SUCCESS if i < 3 else Outcome.FAILURE
+            ep.task_description = f"Episode {i}"
+            ep.lessons_learned = None
+            ep.occurred_at = datetime.now(UTC)
+            mock_episodes.append(ep)
+
+        mock_episodic = AsyncMock()
+        mock_episodic.get_recent = AsyncMock(return_value=mock_episodes)
+        mock_episodic_cls.create = AsyncMock(return_value=mock_episodic)
+
+        result = await service.execute_tool(
+            tool_name="reflect",
+            arguments={
+                "analysis_type": "recent_patterns",
+                "depth": 3,
+            },
+            context=context,
+        )
+
+        assert result.success is True
+        assert result.data["analysis_type"] == "recent_patterns"
+        assert result.data["total_episodes"] == 5
+        assert "top_task_types" in result.data
+        assert "outcome_distribution" in result.data
+
+    @patch("app.services.memory.mcp.service.EpisodicMemory")
+    async def test_reflect_success_factors(
+        self,
+        mock_episodic_cls: MagicMock,
+        service: MemoryToolService,
+        context: ToolContext,
+    ) -> None:
+        """Reflect should analyze success factors."""
+        # 8 of 10 mock episodes succeed, so the expected rate is 8/10 = 0.8.
+        mock_episodes = []
+        for i in range(10):
+            ep = MagicMock()
+            ep.id = uuid4()
+            ep.task_type = "code_review"
+            ep.outcome = Outcome.SUCCESS if i < 8 else Outcome.FAILURE
+            ep.task_description = f"Episode {i}"
+            ep.lessons_learned = "Learned something" if i < 3 else None
+            ep.occurred_at = datetime.now(UTC)
+            mock_episodes.append(ep)
+
+        mock_episodic = AsyncMock()
+        mock_episodic.get_recent = AsyncMock(return_value=mock_episodes)
+        mock_episodic_cls.create = AsyncMock(return_value=mock_episodic)
+
+        result = await service.execute_tool(
+            tool_name="reflect",
+            arguments={
+                "analysis_type": "success_factors",
+                "include_examples": True,
+            },
+            context=context,
+        )
+
+        assert result.success is True
+        assert result.data["analysis_type"] == "success_factors"
+        assert result.data["overall_success_rate"] == 0.8
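Outside the tests, each reflection is driven by one call per analysis type. A minimal sketch, assuming `service` and `context` are wired up as above; the string values come from the AnalysisType enum:

    async def run_reflections(service, context) -> None:
        for analysis in ("recent_patterns", "success_factors", "failure_patterns"):
            result = await service.execute_tool(
                tool_name="reflect",
                arguments={"analysis_type": analysis, "depth": 3},
                context=context,
            )
            if result.success:
                # Each report is a plain dict that echoes its analysis_type
                # alongside aggregate fields such as total_episodes.
                print(analysis, result.data.get("total_episodes"))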
+
+    @patch("app.services.memory.mcp.service.EpisodicMemory")
+    @patch("app.services.memory.mcp.service.SemanticMemory")
+    @patch("app.services.memory.mcp.service.ProceduralMemory")
+    @patch("app.services.memory.mcp.service.WorkingMemory")
+    async def test_get_memory_stats(
+        self,
+        mock_working_cls: MagicMock,
+        mock_procedural_cls: MagicMock,
+        mock_semantic_cls: MagicMock,
+        mock_episodic_cls: MagicMock,
+        service: MemoryToolService,
+        context: ToolContext,
+    ) -> None:
+        """Get memory stats should return statistics."""
+        # Setup mocks
+        mock_working = AsyncMock()
+        mock_working.list_keys = AsyncMock(return_value=["key1", "key2"])
+        mock_working_cls.for_session = AsyncMock(return_value=mock_working)
+
+        mock_episodic = AsyncMock()
+        mock_episodic.get_recent = AsyncMock(return_value=[MagicMock() for _ in range(10)])
+        mock_episodic_cls.create = AsyncMock(return_value=mock_episodic)
+
+        mock_semantic = AsyncMock()
+        mock_semantic.search_facts = AsyncMock(return_value=[MagicMock() for _ in range(5)])
+        mock_semantic_cls.create = AsyncMock(return_value=mock_semantic)
+
+        mock_procedural = AsyncMock()
+        mock_procedural.find_matching = AsyncMock(return_value=[MagicMock() for _ in range(3)])
+        mock_procedural_cls.create = AsyncMock(return_value=mock_procedural)
+
+        result = await service.execute_tool(
+            tool_name="get_memory_stats",
+            arguments={
+                "include_breakdown": True,
+                "include_recent_activity": False,
+            },
+            context=context,
+        )
+
+        assert result.success is True
+        assert "breakdown" in result.data
+        breakdown = result.data["breakdown"]
+        assert breakdown["working"] == 2
+        assert breakdown["episodic"] == 10
+        assert breakdown["semantic"] == 5
+        assert breakdown["procedural"] == 3
+        assert breakdown["total"] == 20
+
+    @patch("app.services.memory.mcp.service.ProceduralMemory")
+    async def test_search_procedures(
+        self,
+        mock_procedural_cls: MagicMock,
+        service: MemoryToolService,
+        context: ToolContext,
+    ) -> None:
+        """Search procedures should return matching procedures."""
+        mock_proc = MagicMock()
+        mock_proc.id = uuid4()
+        mock_proc.name = "Deployment procedure"
+        mock_proc.description = "How to deploy"
+        mock_proc.trigger = "When deploying"
+        mock_proc.success_rate = 0.9
+        mock_proc.execution_count = 10
+        mock_proc.steps = [{"action": "deploy"}]
+
+        mock_procedural = AsyncMock()
+        mock_procedural.find_matching = AsyncMock(return_value=[mock_proc])
+        mock_procedural_cls.create = AsyncMock(return_value=mock_procedural)
+
+        result = await service.execute_tool(
+            tool_name="search_procedures",
+            arguments={
+                "trigger": "Deploying to production",
+                "min_success_rate": 0.8,
+                "include_steps": True,
+            },
+            context=context,
+        )
+
+        assert result.success is True
+        assert result.data["procedures_found"] == 1
+        proc = result.data["procedures"][0]
+        assert proc["name"] == "Deployment procedure"
+        assert "steps" in proc
+
+    async def test_record_outcome(
+        self,
+        service: MemoryToolService,
+        context: ToolContext,
+    ) -> None:
+        """Record outcome should store outcome and update procedure."""
+        with (
+            patch("app.services.memory.mcp.service.EpisodicMemory") as mock_episodic_cls,
+            patch("app.services.memory.mcp.service.ProceduralMemory") as mock_procedural_cls,
+        ):
+            mock_episode = MagicMock()
+            mock_episode.id = uuid4()
+
+            mock_episodic = AsyncMock()
+            mock_episodic.record_episode = AsyncMock(return_value=mock_episode)
+            mock_episodic_cls.create = AsyncMock(return_value=mock_episodic)
+
+            mock_procedural = AsyncMock()
+            mock_procedural.record_outcome = AsyncMock()
+            mock_procedural_cls.create = AsyncMock(return_value=mock_procedural)
+
+            result = await service.execute_tool(
+                tool_name="record_outcome",
+                arguments={
+                    "task_type": "code_review",
+                    "outcome": "success",
+                    "lessons_learned": "Breaking changes caught early",
+                    "duration_seconds": 120.5,
+                },
+                context=context,
+            )
+
+            assert result.success is True
+            assert result.data["recorded"] is True
+            assert result.data["outcome"] == "success"
+            assert "episode_id" in result.data
+
+
+class TestGetMemoryToolService:
+    """Tests for get_memory_tool_service factory."""
+
+    async def test_creates_service(self) -> None:
+        """Factory should create a service."""
+        mock_session = make_mock_session()
+        service = await get_memory_tool_service(mock_session)
+        assert isinstance(service, MemoryToolService)
+
+    async def test_accepts_embedding_generator(self) -> None:
+        """Factory should accept embedding generator."""
+        mock_session = make_mock_session()
+        mock_generator = MagicMock()
+        service = await get_memory_tool_service(mock_session, mock_generator)
+        assert service._embedding_generator is mock_generator
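The factory tests above imply a simple wiring pattern for application code. A minimal sketch, assuming a request handler that already has an AsyncSession and a ToolContext in hand; the function name is a placeholder:

    from typing import Any

    from sqlalchemy.ext.asyncio import AsyncSession

    from app.services.memory.mcp import get_memory_tool_service
    from app.services.memory.mcp.service import ToolContext

    async def handle_tool_call(
        session: AsyncSession,
        name: str,
        args: dict[str, Any],
        context: ToolContext,
    ) -> dict[str, Any]:
        # The factory mirrors MemoryToolService(session, embedding_generator);
        # pass a generator as the second argument to enable vector search.
        service = await get_memory_tool_service(session)
        result = await service.execute_tool(tool_name=name, arguments=args, context=context)
        return result.to_dict()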
"working" + assert MemoryType.EPISODIC == "episodic" + assert MemoryType.SEMANTIC == "semantic" + assert MemoryType.PROCEDURAL == "procedural" + + def test_enum_values(self) -> None: + """Enum values should match strings.""" + assert MemoryType.WORKING.value == "working" + assert MemoryType("episodic") == MemoryType.EPISODIC + + +class TestAnalysisType: + """Tests for AnalysisType enum.""" + + def test_all_types_defined(self) -> None: + """All analysis types should be defined.""" + assert AnalysisType.RECENT_PATTERNS == "recent_patterns" + assert AnalysisType.SUCCESS_FACTORS == "success_factors" + assert AnalysisType.FAILURE_PATTERNS == "failure_patterns" + assert AnalysisType.COMMON_PROCEDURES == "common_procedures" + assert AnalysisType.LEARNING_PROGRESS == "learning_progress" + + +class TestOutcomeType: + """Tests for OutcomeType enum.""" + + def test_all_outcomes_defined(self) -> None: + """All outcome types should be defined.""" + assert OutcomeType.SUCCESS == "success" + assert OutcomeType.PARTIAL == "partial" + assert OutcomeType.FAILURE == "failure" + assert OutcomeType.ABANDONED == "abandoned" + + +class TestRememberArgs: + """Tests for RememberArgs validation.""" + + def test_valid_working_memory_args(self) -> None: + """Valid working memory args should parse.""" + args = RememberArgs( + memory_type=MemoryType.WORKING, + content="Test content", + key="test_key", + ttl_seconds=3600, + ) + assert args.memory_type == MemoryType.WORKING + assert args.key == "test_key" + assert args.ttl_seconds == 3600 + + def test_valid_semantic_args(self) -> None: + """Valid semantic memory args should parse.""" + args = RememberArgs( + memory_type=MemoryType.SEMANTIC, + content="User prefers dark mode", + subject="User", + predicate="prefers", + object_value="dark mode", + ) + assert args.subject == "User" + assert args.predicate == "prefers" + assert args.object_value == "dark mode" + + def test_valid_procedural_args(self) -> None: + """Valid procedural memory args should parse.""" + args = RememberArgs( + memory_type=MemoryType.PROCEDURAL, + content="File creation procedure", + trigger="When creating a new file", + steps=[{"action": "check_exists"}, {"action": "create"}], + ) + assert args.trigger == "When creating a new file" + assert len(args.steps) == 2 + + def test_importance_validation(self) -> None: + """Importance must be between 0 and 1.""" + args = RememberArgs( + memory_type=MemoryType.WORKING, + content="Test", + importance=0.8, + ) + assert args.importance == 0.8 + + with pytest.raises(ValidationError): + RememberArgs( + memory_type=MemoryType.WORKING, + content="Test", + importance=1.5, # Invalid + ) + + with pytest.raises(ValidationError): + RememberArgs( + memory_type=MemoryType.WORKING, + content="Test", + importance=-0.1, # Invalid + ) + + def test_content_required(self) -> None: + """Content is required.""" + with pytest.raises(ValidationError): + RememberArgs( + memory_type=MemoryType.WORKING, + content="", # Empty not allowed + ) + + def test_ttl_validation(self) -> None: + """TTL must be within bounds.""" + with pytest.raises(ValidationError): + RememberArgs( + memory_type=MemoryType.WORKING, + content="Test", + ttl_seconds=0, # Too low + ) + + with pytest.raises(ValidationError): + RememberArgs( + memory_type=MemoryType.WORKING, + content="Test", + ttl_seconds=86400 * 31, # Over 30 days + ) + + def test_default_values(self) -> None: + """Default values should be set correctly.""" + args = RememberArgs( + memory_type=MemoryType.WORKING, + content="Test", + ) + assert 
+
+
+class TestRecallArgs:
+    """Tests for RecallArgs validation."""
+
+    def test_valid_args(self) -> None:
+        """Valid recall args should parse."""
+        args = RecallArgs(
+            query="authentication errors",
+            memory_types=[MemoryType.EPISODIC, MemoryType.SEMANTIC],
+            limit=10,
+        )
+        assert args.query == "authentication errors"
+        assert len(args.memory_types) == 2
+        assert args.limit == 10
+
+    def test_default_memory_types(self) -> None:
+        """Default memory types should be episodic and semantic."""
+        args = RecallArgs(query="test query")
+        assert MemoryType.EPISODIC in args.memory_types
+        assert MemoryType.SEMANTIC in args.memory_types
+
+    def test_limit_validation(self) -> None:
+        """Limit must be between 1 and 100."""
+        with pytest.raises(ValidationError):
+            RecallArgs(query="test", limit=0)
+
+        with pytest.raises(ValidationError):
+            RecallArgs(query="test", limit=101)
+
+    def test_min_relevance_validation(self) -> None:
+        """Min relevance must be between 0 and 1."""
+        args = RecallArgs(query="test", min_relevance=0.5)
+        assert args.min_relevance == 0.5
+
+        with pytest.raises(ValidationError):
+            RecallArgs(query="test", min_relevance=1.5)
+
+
+class TestForgetArgs:
+    """Tests for ForgetArgs validation."""
+
+    def test_valid_key_deletion(self) -> None:
+        """Valid key deletion args should parse."""
+        args = ForgetArgs(
+            memory_type=MemoryType.WORKING,
+            key="temp_key",
+        )
+        assert args.memory_type == MemoryType.WORKING
+        assert args.key == "temp_key"
+
+    def test_valid_id_deletion(self) -> None:
+        """Valid ID deletion args should parse."""
+        args = ForgetArgs(
+            memory_type=MemoryType.EPISODIC,
+            memory_id="12345678-1234-1234-1234-123456789012",
+        )
+        assert args.memory_id is not None
+
+    def test_pattern_deletion_requires_confirm(self) -> None:
+        """Pattern deletion args should parse; confirmation is enforced by the service."""
+        args = ForgetArgs(
+            memory_type=MemoryType.WORKING,
+            pattern="cache_*",
+            confirm_bulk=False,
+        )
+        assert args.pattern == "cache_*"
+        assert args.confirm_bulk is False
+
+
+class TestReflectArgs:
+    """Tests for ReflectArgs validation."""
+
+    def test_valid_args(self) -> None:
+        """Valid reflect args should parse."""
+        args = ReflectArgs(
+            analysis_type=AnalysisType.SUCCESS_FACTORS,
+            depth=3,
+        )
+        assert args.analysis_type == AnalysisType.SUCCESS_FACTORS
+        assert args.depth == 3
+
+    def test_depth_validation(self) -> None:
+        """Depth must be between 1 and 5."""
+        with pytest.raises(ValidationError):
+            ReflectArgs(analysis_type=AnalysisType.SUCCESS_FACTORS, depth=0)
+
+        with pytest.raises(ValidationError):
+            ReflectArgs(analysis_type=AnalysisType.SUCCESS_FACTORS, depth=6)
+
+    def test_default_values(self) -> None:
+        """Default values should be set correctly."""
+        args = ReflectArgs(analysis_type=AnalysisType.RECENT_PATTERNS)
+        assert args.depth == 3
+        assert args.include_examples is True
+        assert args.max_items == 10
+
+
+class TestGetMemoryStatsArgs:
+    """Tests for GetMemoryStatsArgs validation."""
+
+    def test_valid_args(self) -> None:
+        """Valid args should parse."""
+        args = GetMemoryStatsArgs(
+            include_breakdown=True,
+            include_recent_activity=True,
+            time_range_days=30,
+        )
+        assert args.include_breakdown is True
+        assert args.time_range_days == 30
+
+    def test_time_range_validation(self) -> None:
+        """Time range must be between 1 and 90 days."""
+        with pytest.raises(ValidationError):
+            GetMemoryStatsArgs(time_range_days=0)
+
+        with pytest.raises(ValidationError):
+            GetMemoryStatsArgs(time_range_days=91)
+
+
+class TestSearchProceduresArgs:
+    """Tests for SearchProceduresArgs validation."""
+
+    def test_valid_args(self) -> None:
+        """Valid args should parse."""
+        args = SearchProceduresArgs(
+            trigger="Deploying to production",
+            min_success_rate=0.8,
+            limit=5,
+        )
+        assert args.trigger == "Deploying to production"
+        assert args.min_success_rate == 0.8
+
+    def test_trigger_required(self) -> None:
+        """Trigger is required."""
+        with pytest.raises(ValidationError):
+            SearchProceduresArgs(trigger="")
+
+    def test_success_rate_validation(self) -> None:
+        """Success rate must be between 0 and 1."""
+        with pytest.raises(ValidationError):
+            SearchProceduresArgs(trigger="test", min_success_rate=1.5)
+
+
+class TestRecordOutcomeArgs:
+    """Tests for RecordOutcomeArgs validation."""
+
+    def test_valid_success_args(self) -> None:
+        """Valid success args should parse."""
+        args = RecordOutcomeArgs(
+            task_type="code_review",
+            outcome=OutcomeType.SUCCESS,
+            lessons_learned="Breaking changes caught early",
+        )
+        assert args.task_type == "code_review"
+        assert args.outcome == OutcomeType.SUCCESS
+
+    def test_valid_failure_args(self) -> None:
+        """Valid failure args should parse."""
+        args = RecordOutcomeArgs(
+            task_type="deployment",
+            outcome=OutcomeType.FAILURE,
+            error_details="Database migration timeout",
+            duration_seconds=120.5,
+        )
+        assert args.outcome == OutcomeType.FAILURE
+        assert args.error_details is not None
+
+    def test_task_type_required(self) -> None:
+        """Task type is required."""
+        with pytest.raises(ValidationError):
+            RecordOutcomeArgs(task_type="", outcome=OutcomeType.SUCCESS)
+
+
+class TestMemoryToolDefinition:
+    """Tests for MemoryToolDefinition class."""
+
+    def test_to_mcp_format(self) -> None:
+        """Tool should convert to MCP format."""
+        tool = MemoryToolDefinition(
+            name="test_tool",
+            description="A test tool",
+            args_schema=RememberArgs,
+        )
+
+        mcp_format = tool.to_mcp_format()
+
+        assert mcp_format["name"] == "test_tool"
+        assert mcp_format["description"] == "A test tool"
+        assert "inputSchema" in mcp_format
+        assert "properties" in mcp_format["inputSchema"]
+
+    def test_validate_args(self) -> None:
+        """Tool should validate args using schema."""
+        tool = MemoryToolDefinition(
+            name="remember",
+            description="Store in memory",
+            args_schema=RememberArgs,
+        )
+
+        # Valid args
+        validated = tool.validate_args({
+            "memory_type": "working",
+            "content": "Test content",
+        })
+        assert isinstance(validated, RememberArgs)
+
+        # Invalid args
+        with pytest.raises(ValidationError):
+            tool.validate_args({"memory_type": "invalid"})
+
+
+class TestToolDefinitions:
+    """Tests for the tool definitions dictionary."""
+
+    def test_all_tools_defined(self) -> None:
+        """All expected tools should be defined."""
+        expected_tools = [
+            "remember",
+            "recall",
+            "forget",
+            "reflect",
+            "get_memory_stats",
+            "search_procedures",
+            "record_outcome",
+        ]
+
+        for tool_name in expected_tools:
+            assert tool_name in MEMORY_TOOL_DEFINITIONS
+            assert isinstance(MEMORY_TOOL_DEFINITIONS[tool_name], MemoryToolDefinition)
+
+    def test_get_tool_definition(self) -> None:
+        """get_tool_definition should return correct tool."""
+        tool = get_tool_definition("remember")
+        assert tool is not None
+        assert tool.name == "remember"
+
+        unknown = get_tool_definition("unknown_tool")
+        assert unknown is None
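The MCP wire format asserted in these tests has roughly the shape below. A sketch of the structure only, since the exact JSON Schema Pydantic emits depends on each field's metadata:

    from app.services.memory.mcp.tools import get_tool_definition

    tool = get_tool_definition("remember")
    assert tool is not None
    schema = tool.to_mcp_format()
    # Expected shape (abridged):
    # {"name": "remember", "description": "...",
    #  "inputSchema": {"type": "object", "properties": {...}}}
    print(sorted(schema["inputSchema"]["properties"]))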
+
+    def test_get_all_tool_schemas(self) -> None:
+        """get_all_tool_schemas should return MCP-formatted schemas."""
+        schemas = get_all_tool_schemas()
+
+        assert len(schemas) == 7
+        for schema in schemas:
+            assert "name" in schema
+            assert "description" in schema
+            assert "inputSchema" in schema
+
+    def test_tool_descriptions_not_empty(self) -> None:
+        """All tools should have descriptions."""
+        for name, tool in MEMORY_TOOL_DEFINITIONS.items():
+            assert tool.description, f"Tool {name} has empty description"
+            assert len(tool.description) > 50, f"Tool {name} description too short"
+
+    def test_input_schemas_have_properties(self) -> None:
+        """All tool schemas should have properties defined."""
+        for name, tool in MEMORY_TOOL_DEFINITIONS.items():
+            schema = tool.to_mcp_format()
+            assert "properties" in schema["inputSchema"], f"Tool {name} missing properties"
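Taken together, these schemas are what an MCP host would advertise for the memory tools. A minimal sketch of a registration loop, assuming a generic `server` object with a hypothetical register_tool(schema, handler) method — the real MCP SDK call will differ:

    from app.services.memory.mcp import MEMORY_TOOL_DEFINITIONS
    from app.services.memory.mcp.tools import get_all_tool_schemas

    def register_memory_tools(server) -> None:
        # `server.register_tool` is a hypothetical stand-in for the host's API.
        for schema in get_all_tool_schemas():
            tool = MEMORY_TOOL_DEFINITIONS[schema["name"]]
            server.register_tool(schema, handler=tool.validate_args)
    # handler=tool.validate_args is illustrative only: a real handler would
    # validate the incoming args and then dispatch to
    # MemoryToolService.execute_tool with a scoped ToolContext.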