feat(memory): implement MCP tools for agent memory operations (#96)

Add MCP-compatible tools that expose memory operations to agents:

Tools implemented:
- remember: Store data in working, episodic, semantic, or procedural memory
- recall: Retrieve memories by query across multiple memory types
- forget: Delete specific keys or bulk delete by pattern
- reflect: Analyze patterns in recent episodes (success/failure factors)
- get_memory_stats: Return usage statistics and breakdowns
- search_procedures: Find procedures matching trigger patterns
- record_outcome: Record task outcomes and update procedure success rates

Key components:
- tools.py: Pydantic schemas for tool argument validation with comprehensive
  field constraints (importance 0-1, TTL limits, limit ranges)
- service.py: MemoryToolService coordinating memory type operations with
  proper scoping via ToolContext (project_id, agent_instance_id, session_id)
- Lazy initialization of memory services (WorkingMemory, EpisodicMemory,
  SemanticMemory, ProceduralMemory)

Test coverage:
- 60 tests covering tool definitions, argument validation, and service
  execution paths
- Mock-based tests for all memory type interactions

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
2026-01-05 03:32:10 +01:00
parent 1670e05e0d
commit 0b24d4c6cc
7 changed files with 2648 additions and 0 deletions

View File

@@ -133,4 +133,6 @@ __all__ = [
"get_default_settings",
"get_memory_settings",
"reset_memory_settings",
# MCP Tools - lazy import to avoid circular dependencies
# Import directly: from app.services.memory.mcp import MemoryToolService
]

View File

@@ -0,0 +1,40 @@
# app/services/memory/mcp/__init__.py
"""
MCP Tools for Agent Memory System.
Exposes memory operations as MCP-compatible tools that agents can invoke:
- remember: Store data in memory
- recall: Retrieve from memory
- forget: Remove from memory
- reflect: Analyze patterns
- get_memory_stats: Usage statistics
- search_procedures: Find relevant procedures
- record_outcome: Record task success/failure
"""
from .service import MemoryToolService, get_memory_tool_service
from .tools import (
MEMORY_TOOL_DEFINITIONS,
ForgetArgs,
GetMemoryStatsArgs,
MemoryToolDefinition,
RecallArgs,
RecordOutcomeArgs,
ReflectArgs,
RememberArgs,
SearchProceduresArgs,
)
# Public names re-exported by ``app.services.memory.mcp`` (kept sorted:
# constants/classes first, then factory functions).
__all__ = [
    "MEMORY_TOOL_DEFINITIONS",
    "ForgetArgs",
    "GetMemoryStatsArgs",
    "MemoryToolDefinition",
    "MemoryToolService",
    "RecallArgs",
    "RecordOutcomeArgs",
    "ReflectArgs",
    "RememberArgs",
    "SearchProceduresArgs",
    "get_memory_tool_service",
]

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,491 @@
# app/services/memory/mcp/tools.py
"""
MCP Tool Definitions for Agent Memory System.
Defines the schema and metadata for memory-related MCP tools.
These tools are invoked by AI agents to interact with the memory system.
"""
from dataclasses import dataclass, field
from enum import Enum
from typing import Any
from pydantic import BaseModel, Field
class MemoryType(str, Enum):
    """Enumerates the four memory stores a tool invocation may target."""

    WORKING = "working"        # short-lived key/value session state
    EPISODIC = "episodic"      # records of past task executions
    SEMANTIC = "semantic"      # subject-predicate-object facts
    PROCEDURAL = "procedural"  # reusable trigger/steps procedures
class AnalysisType(str, Enum):
    """Kinds of pattern analysis the 'reflect' tool can perform."""

    RECENT_PATTERNS = "recent_patterns"      # themes in recent work
    SUCCESS_FACTORS = "success_factors"      # what correlates with success
    FAILURE_PATTERNS = "failure_patterns"    # what correlates with failure
    COMMON_PROCEDURES = "common_procedures"  # most frequently used procedures
    LEARNING_PROGRESS = "learning_progress"  # knowledge growth over time
class OutcomeType(str, Enum):
    """Possible results reported via the 'record_outcome' tool."""

    SUCCESS = "success"      # task fully completed
    PARTIAL = "partial"      # task partially completed
    FAILURE = "failure"      # task failed
    ABANDONED = "abandoned"  # task was given up on
# ============================================================================
# Tool Argument Schemas (Pydantic models for validation)
# ============================================================================


class RememberArgs(BaseModel):
    """Arguments for the 'remember' tool.

    One schema serves all four memory types; which optional fields apply
    depends on ``memory_type`` (the service enforces this, not the schema):
    working memory needs ``key``, semantic memory needs the
    subject/predicate/object triple, procedural memory needs ``trigger``
    and ``steps``.
    """

    memory_type: MemoryType = Field(
        ...,
        description="Type of memory to store in: working, episodic, semantic, or procedural",
    )
    content: str = Field(
        ...,
        description="The content to remember. Can be text, facts, or procedure steps.",
        min_length=1,
        max_length=10000,
    )
    key: str | None = Field(
        None,
        description="Optional key for working memory entries. Required for working memory type.",
        max_length=256,
    )
    importance: float = Field(
        0.5,
        description="Importance score from 0.0 (low) to 1.0 (critical)",
        ge=0.0,
        le=1.0,
    )
    ttl_seconds: int | None = Field(
        None,
        description="Time-to-live in seconds for working memory. None for permanent storage.",
        ge=1,
        le=86400 * 30,  # Max 30 days
    )
    metadata: dict[str, Any] = Field(
        default_factory=dict,
        description="Additional metadata to store with the memory",
    )
    # For semantic memory (facts)
    subject: str | None = Field(
        None,
        description="Subject of the fact (for semantic memory)",
        max_length=256,
    )
    predicate: str | None = Field(
        None,
        description="Predicate/relationship (for semantic memory)",
        max_length=256,
    )
    object_value: str | None = Field(
        None,
        description="Object of the fact (for semantic memory)",
        max_length=1000,
    )
    # For procedural memory
    trigger: str | None = Field(
        None,
        description="Trigger condition for the procedure (for procedural memory)",
        max_length=500,
    )
    steps: list[dict[str, Any]] | None = Field(
        None,
        description="Procedure steps as a list of action dictionaries",
    )
class RecallArgs(BaseModel):
    """Arguments for the 'recall' tool.

    Defaults to searching episodic and semantic memory; the query is a
    natural-language description of what to find.
    """

    query: str = Field(
        ...,
        description="Search query to find relevant memories",
        min_length=1,
        max_length=1000,
    )
    # default_factory (not a literal list) so the default is never shared/mutated
    memory_types: list[MemoryType] = Field(
        default_factory=lambda: [MemoryType.EPISODIC, MemoryType.SEMANTIC],
        description="Types of memory to search in",
    )
    limit: int = Field(
        10,
        description="Maximum number of results to return",
        ge=1,
        le=100,
    )
    min_relevance: float = Field(
        0.0,
        description="Minimum relevance score (0.0-1.0) for results",
        ge=0.0,
        le=1.0,
    )
    # Free-form filter dict; interpretation is up to the service backends.
    filters: dict[str, Any] = Field(
        default_factory=dict,
        description="Additional filters (e.g., outcome, task_type, date range)",
    )
    include_context: bool = Field(
        True,
        description="Whether to include surrounding context in results",
    )
class ForgetArgs(BaseModel):
    """Arguments for the 'forget' tool.

    Callers supply one of ``key`` (working memory), ``memory_id`` (other
    memory types), or ``pattern`` (bulk delete). Bulk deletion additionally
    requires ``confirm_bulk=True``; the service enforces these rules.
    """

    memory_type: MemoryType = Field(
        ...,
        description="Type of memory to remove from",
    )
    key: str | None = Field(
        None,
        description="Key to remove (for working memory)",
        max_length=256,
    )
    memory_id: str | None = Field(
        None,
        description="Specific memory ID to remove (for episodic/semantic/procedural)",
    )
    pattern: str | None = Field(
        None,
        description="Pattern to match for bulk removal (use with caution)",
        max_length=500,
    )
    # Safety latch: bulk pattern deletes are rejected unless this is True.
    confirm_bulk: bool = Field(
        False,
        description="Must be True to confirm bulk deletion when using pattern",
    )
class ReflectArgs(BaseModel):
    """Arguments for the 'reflect' tool (pattern analysis over episodes)."""

    analysis_type: AnalysisType = Field(
        ...,
        description="Type of pattern analysis to perform",
    )
    scope: str | None = Field(
        None,
        description="Optional scope to limit analysis (e.g., task_type, time range)",
        max_length=500,
    )
    depth: int = Field(
        3,
        description="Depth of analysis (1=surface, 5=deep)",
        ge=1,
        le=5,
    )
    include_examples: bool = Field(
        True,
        description="Whether to include example memories in the analysis",
    )
    max_items: int = Field(
        10,
        description="Maximum number of patterns/examples to analyze",
        ge=1,
        le=50,
    )
class GetMemoryStatsArgs(BaseModel):
    """Arguments for the 'get_memory_stats' tool."""

    include_breakdown: bool = Field(
        True,
        description="Include breakdown by memory type",
    )
    include_recent_activity: bool = Field(
        True,
        description="Include recent memory activity summary",
    )
    # Only used when include_recent_activity is requested; capped at 90 days.
    time_range_days: int = Field(
        7,
        description="Time range for activity analysis in days",
        ge=1,
        le=90,
    )
class SearchProceduresArgs(BaseModel):
    """Arguments for the 'search_procedures' tool."""

    trigger: str = Field(
        ...,
        description="Trigger or situation to find procedures for",
        min_length=1,
        max_length=500,
    )
    task_type: str | None = Field(
        None,
        description="Optional task type to filter procedures",
        max_length=100,
    )
    # Procedures below this observed success rate are filtered out.
    min_success_rate: float = Field(
        0.5,
        description="Minimum success rate (0.0-1.0) for returned procedures",
        ge=0.0,
        le=1.0,
    )
    limit: int = Field(
        5,
        description="Maximum number of procedures to return",
        ge=1,
        le=20,
    )
    include_steps: bool = Field(
        True,
        description="Whether to include detailed steps in the response",
    )
class RecordOutcomeArgs(BaseModel):
    """Arguments for the 'record_outcome' tool.

    Used after task completion to feed results back into episodic memory
    and (optionally, via ``procedure_id``) update procedure success rates.
    """

    task_type: str = Field(
        ...,
        description="Type of task that was executed",
        min_length=1,
        max_length=100,
    )
    outcome: OutcomeType = Field(
        ...,
        description="Outcome of the task execution",
    )
    procedure_id: str | None = Field(
        None,
        description="ID of the procedure that was followed (if any)",
    )
    context: dict[str, Any] = Field(
        default_factory=dict,
        description="Context in which the task was executed",
    )
    lessons_learned: str | None = Field(
        None,
        description="What was learned from this execution",
        max_length=2000,
    )
    duration_seconds: float | None = Field(
        None,
        description="How long the task took to execute",
        ge=0.0,
    )
    error_details: str | None = Field(
        None,
        description="Details about any errors encountered (for failures)",
        max_length=2000,
    )
# ============================================================================
# Tool Definition Structure
# ============================================================================


@dataclass
class MemoryToolDefinition:
    """Schema and metadata for a single memory-system MCP tool.

    Pairs a tool name/description with the Pydantic model that validates
    its arguments; the JSON input schema is derived from that model unless
    one is supplied explicitly.
    """

    name: str
    description: str
    args_schema: type[BaseModel]
    input_schema: dict[str, Any] = field(default_factory=dict)

    def __post_init__(self) -> None:
        """Fill in the JSON schema from the argument model when not provided."""
        self.input_schema = self.input_schema or self.args_schema.model_json_schema()

    def to_mcp_format(self) -> dict[str, Any]:
        """Render this definition as an MCP ``tools/list`` entry."""
        return {
            "name": self.name,
            "description": self.description,
            "inputSchema": self.input_schema,
        }

    def validate_args(self, args: dict[str, Any]) -> BaseModel:
        """Parse *args* into the argument model (raises on invalid input)."""
        return self.args_schema.model_validate(args)
# ============================================================================
# Tool Definitions
# ============================================================================

# One MemoryToolDefinition per tool. Each description doubles as the prompt
# text an agent sees, so it carries usage guidance plus JSON call examples.

REMEMBER_TOOL = MemoryToolDefinition(
    name="remember",
    description="""Store information in the agent's memory system.
Use this tool to:
- Store temporary data in working memory (key-value with optional TTL)
- Record important events in episodic memory (automatically done on session end)
- Store facts/knowledge in semantic memory (subject-predicate-object triples)
- Save procedures in procedural memory (trigger conditions and steps)
Examples:
- Working memory: {"memory_type": "working", "key": "current_task", "content": "Implementing auth", "ttl_seconds": 3600}
- Semantic fact: {"memory_type": "semantic", "subject": "User", "predicate": "prefers", "object_value": "dark mode", "content": "User preference noted"}
- Procedure: {"memory_type": "procedural", "trigger": "When creating a new file", "steps": [{"action": "check_exists"}, {"action": "create"}], "content": "File creation procedure"}
""",
    args_schema=RememberArgs,
)

RECALL_TOOL = MemoryToolDefinition(
    name="recall",
    description="""Retrieve information from the agent's memory system.
Use this tool to:
- Search for relevant past experiences (episodic)
- Look up known facts and knowledge (semantic)
- Find applicable procedures for current task (procedural)
- Get current session state (working)
The query supports semantic search - describe what you're looking for in natural language.
Examples:
- {"query": "How did I handle authentication errors before?", "memory_types": ["episodic"]}
- {"query": "What are the user's preferences?", "memory_types": ["semantic"], "limit": 5}
- {"query": "database connection", "memory_types": ["episodic", "semantic", "procedural"], "filters": {"outcome": "success"}}
""",
    args_schema=RecallArgs,
)

FORGET_TOOL = MemoryToolDefinition(
    name="forget",
    description="""Remove information from the agent's memory system.
Use this tool to:
- Clear temporary working memory entries
- Remove specific memories by ID
- Bulk remove memories matching a pattern (requires confirmation)
WARNING: Deletion is permanent. Use with caution.
Examples:
- Working memory: {"memory_type": "working", "key": "temp_calculation"}
- Specific memory: {"memory_type": "episodic", "memory_id": "ep-123"}
- Bulk (requires confirm): {"memory_type": "working", "pattern": "cache_*", "confirm_bulk": true}
""",
    args_schema=ForgetArgs,
)

REFLECT_TOOL = MemoryToolDefinition(
    name="reflect",
    description="""Analyze patterns in the agent's memory to gain insights.
Use this tool to:
- Identify patterns in recent work
- Understand what leads to success/failure
- Learn from past experiences
- Track learning progress over time
Analysis types:
- recent_patterns: What patterns appear in recent work
- success_factors: What conditions lead to success
- failure_patterns: What causes failures and how to avoid them
- common_procedures: Most frequently used procedures
- learning_progress: How knowledge has grown over time
Examples:
- {"analysis_type": "success_factors", "scope": "code_review", "depth": 3}
- {"analysis_type": "failure_patterns", "include_examples": true, "max_items": 5}
""",
    args_schema=ReflectArgs,
)

GET_MEMORY_STATS_TOOL = MemoryToolDefinition(
    name="get_memory_stats",
    description="""Get statistics about the agent's memory usage.
Returns information about:
- Total memories stored by type
- Storage utilization
- Recent activity summary
- Memory health indicators
Use this to understand memory capacity and usage patterns.
Examples:
- {"include_breakdown": true, "include_recent_activity": true}
- {"time_range_days": 30, "include_breakdown": true}
""",
    args_schema=GetMemoryStatsArgs,
)

SEARCH_PROCEDURES_TOOL = MemoryToolDefinition(
    name="search_procedures",
    description="""Find relevant procedures for a given situation.
Use this tool when you need to:
- Find the best way to handle a situation
- Look up proven approaches to problems
- Get step-by-step guidance for tasks
Returns procedures ranked by relevance and success rate.
Examples:
- {"trigger": "Deploying to production", "min_success_rate": 0.8}
- {"trigger": "Handling merge conflicts", "task_type": "git_operations", "limit": 3}
""",
    args_schema=SearchProceduresArgs,
)

RECORD_OUTCOME_TOOL = MemoryToolDefinition(
    name="record_outcome",
    description="""Record the outcome of a task execution.
Use this tool after completing a task to:
- Update procedure success/failure rates
- Store lessons learned for future reference
- Improve procedure recommendations
This helps the memory system learn from experience.
Examples:
- {"task_type": "code_review", "outcome": "success", "lessons_learned": "Breaking changes caught early"}
- {"task_type": "deployment", "outcome": "failure", "error_details": "Database migration timeout", "lessons_learned": "Need to test migrations locally first"}
""",
    args_schema=RecordOutcomeArgs,
)
# All tool definitions keyed by tool name for easy lookup.
# Keys are derived from each definition's own ``name`` attribute so the
# registry can never drift out of sync with the tool metadata (the original
# hand-written string keys duplicated every name).
MEMORY_TOOL_DEFINITIONS: dict[str, MemoryToolDefinition] = {
    tool.name: tool
    for tool in (
        REMEMBER_TOOL,
        RECALL_TOOL,
        FORGET_TOOL,
        REFLECT_TOOL,
        GET_MEMORY_STATS_TOOL,
        SEARCH_PROCEDURES_TOOL,
        RECORD_OUTCOME_TOOL,
    )
}
def get_all_tool_schemas() -> list[dict[str, Any]]:
    """Return every memory tool rendered in MCP ``tools/list`` format."""
    schemas: list[dict[str, Any]] = []
    for definition in MEMORY_TOOL_DEFINITIONS.values():
        schemas.append(definition.to_mcp_format())
    return schemas
def get_tool_definition(name: str) -> MemoryToolDefinition | None:
    """Look up a single tool definition; None when *name* is unknown."""
    try:
        return MEMORY_TOOL_DEFINITIONS[name]
    except KeyError:
        return None

View File

@@ -0,0 +1,2 @@
# tests/unit/services/memory/mcp/__init__.py
"""Tests for memory MCP tools."""

View File

@@ -0,0 +1,651 @@
# tests/unit/services/memory/mcp/test_service.py
"""Tests for MemoryToolService."""
from datetime import UTC, datetime
from typing import Any
from unittest.mock import AsyncMock, MagicMock, patch
from uuid import UUID, uuid4
import pytest
from app.services.memory.mcp.service import (
MemoryToolService,
ToolContext,
ToolResult,
get_memory_tool_service,
)
from app.services.memory.mcp.tools import (
AnalysisType,
MemoryType,
OutcomeType,
)
from app.services.memory.types import Outcome
# Run every async test in this module on a function-scoped asyncio event loop.
pytestmark = pytest.mark.asyncio(loop_scope="function")
def make_context(
    project_id: UUID | None = None,
    agent_instance_id: UUID | None = None,
    session_id: str | None = None,
) -> ToolContext:
    """Build a ToolContext for tests, substituting fresh defaults for unset fields."""
    resolved_project = project_id or uuid4()
    resolved_agent = agent_instance_id or uuid4()
    resolved_session = session_id or "test-session"
    return ToolContext(
        project_id=resolved_project,
        agent_instance_id=resolved_agent,
        session_id=resolved_session,
    )
def make_mock_session() -> AsyncMock:
    """Build an AsyncMock standing in for an async database session."""
    mock = AsyncMock()
    # Pre-declare the coroutine methods the service under test is expected
    # to await, so call assertions can target them directly.
    for method_name in ("execute", "commit", "flush"):
        setattr(mock, method_name, AsyncMock())
    return mock
class TestToolContext:
    """Tests for ToolContext dataclass."""

    def test_context_creation(self) -> None:
        """Context should be creatable with required fields."""
        project_id = uuid4()
        ctx = ToolContext(project_id=project_id)
        assert ctx.project_id == project_id
        # Optional scoping fields default to None.
        assert ctx.agent_instance_id is None
        assert ctx.session_id is None

    def test_context_with_all_fields(self) -> None:
        """Context should accept all optional fields."""
        project_id = uuid4()
        agent_id = uuid4()
        ctx = ToolContext(
            project_id=project_id,
            agent_instance_id=agent_id,
            agent_type_id=uuid4(),
            session_id="session-123",
            user_id=uuid4(),
        )
        assert ctx.project_id == project_id
        assert ctx.agent_instance_id == agent_id
        assert ctx.session_id == "session-123"
class TestToolResult:
    """Tests for ToolResult dataclass."""

    def test_success_result(self) -> None:
        """Success result should have correct fields."""
        result = ToolResult(
            success=True,
            data={"key": "value"},
            execution_time_ms=10.5,
        )
        assert result.success is True
        assert result.data == {"key": "value"}
        # A successful result carries no error.
        assert result.error is None

    def test_error_result(self) -> None:
        """Error result should have correct fields."""
        result = ToolResult(
            success=False,
            error="Something went wrong",
            error_code="VALIDATION_ERROR",
        )
        assert result.success is False
        assert result.error == "Something went wrong"
        assert result.error_code == "VALIDATION_ERROR"

    def test_to_dict(self) -> None:
        """Result should convert to dict correctly."""
        result = ToolResult(
            success=True,
            data={"test": 1},
            execution_time_ms=5.0,
        )
        result_dict = result.to_dict()
        assert result_dict["success"] is True
        assert result_dict["data"] == {"test": 1}
        assert result_dict["execution_time_ms"] == 5.0
class TestMemoryToolService:
"""Tests for MemoryToolService."""
@pytest.fixture
def mock_session(self) -> AsyncMock:
"""Create a mock session."""
return make_mock_session()
@pytest.fixture
def service(self, mock_session: AsyncMock) -> MemoryToolService:
"""Create a service with mock session."""
return MemoryToolService(session=mock_session)
@pytest.fixture
def context(self) -> ToolContext:
"""Create a test context."""
return make_context()
async def test_execute_unknown_tool(
self,
service: MemoryToolService,
context: ToolContext,
) -> None:
"""Unknown tool should return error."""
result = await service.execute_tool(
tool_name="unknown_tool",
arguments={},
context=context,
)
assert result.success is False
assert result.error_code == "UNKNOWN_TOOL"
async def test_execute_with_invalid_args(
self,
service: MemoryToolService,
context: ToolContext,
) -> None:
"""Invalid arguments should return validation error."""
result = await service.execute_tool(
tool_name="remember",
arguments={"memory_type": "invalid_type"},
context=context,
)
assert result.success is False
assert result.error_code == "VALIDATION_ERROR"
@patch("app.services.memory.mcp.service.WorkingMemory")
async def test_remember_working_memory(
self,
mock_working_cls: MagicMock,
service: MemoryToolService,
context: ToolContext,
) -> None:
"""Remember should store in working memory."""
# Setup mock
mock_working = AsyncMock()
mock_working.set = AsyncMock()
mock_working_cls.for_session = AsyncMock(return_value=mock_working)
result = await service.execute_tool(
tool_name="remember",
arguments={
"memory_type": "working",
"content": "Test content",
"key": "test_key",
"ttl_seconds": 3600,
},
context=context,
)
assert result.success is True
assert result.data["stored"] is True
assert result.data["memory_type"] == "working"
assert result.data["key"] == "test_key"
async def test_remember_episodic_memory(
self,
service: MemoryToolService,
context: ToolContext,
) -> None:
"""Remember should store in episodic memory."""
with patch("app.services.memory.mcp.service.EpisodicMemory") as mock_episodic_cls:
# Setup mock
mock_episode = MagicMock()
mock_episode.id = uuid4()
mock_episodic = AsyncMock()
mock_episodic.record_episode = AsyncMock(return_value=mock_episode)
mock_episodic_cls.create = AsyncMock(return_value=mock_episodic)
result = await service.execute_tool(
tool_name="remember",
arguments={
"memory_type": "episodic",
"content": "Important event happened",
"importance": 0.8,
},
context=context,
)
assert result.success is True
assert result.data["stored"] is True
assert result.data["memory_type"] == "episodic"
assert "episode_id" in result.data
async def test_remember_working_without_key(
self,
service: MemoryToolService,
context: ToolContext,
) -> None:
"""Working memory without key should fail."""
result = await service.execute_tool(
tool_name="remember",
arguments={
"memory_type": "working",
"content": "Test content",
},
context=context,
)
assert result.success is False
assert "key is required" in result.error.lower()
async def test_remember_working_without_session(
self,
service: MemoryToolService,
) -> None:
"""Working memory without session should fail."""
context = ToolContext(project_id=uuid4(), session_id=None)
result = await service.execute_tool(
tool_name="remember",
arguments={
"memory_type": "working",
"content": "Test content",
"key": "test_key",
},
context=context,
)
assert result.success is False
assert "session id is required" in result.error.lower()
async def test_remember_semantic_memory(
self,
service: MemoryToolService,
context: ToolContext,
) -> None:
"""Remember should store facts in semantic memory."""
with patch("app.services.memory.mcp.service.SemanticMemory") as mock_semantic_cls:
mock_fact = MagicMock()
mock_fact.id = uuid4()
mock_semantic = AsyncMock()
mock_semantic.store_fact = AsyncMock(return_value=mock_fact)
mock_semantic_cls.create = AsyncMock(return_value=mock_semantic)
result = await service.execute_tool(
tool_name="remember",
arguments={
"memory_type": "semantic",
"content": "User prefers dark mode",
"subject": "User",
"predicate": "prefers",
"object_value": "dark mode",
},
context=context,
)
assert result.success is True
assert result.data["memory_type"] == "semantic"
assert "fact_id" in result.data
assert "triple" in result.data
async def test_remember_semantic_without_fields(
self,
service: MemoryToolService,
context: ToolContext,
) -> None:
"""Semantic memory without subject/predicate/object should fail."""
result = await service.execute_tool(
tool_name="remember",
arguments={
"memory_type": "semantic",
"content": "Some content",
"subject": "User",
# Missing predicate and object_value
},
context=context,
)
assert result.success is False
assert "required" in result.error.lower()
async def test_remember_procedural_memory(
self,
service: MemoryToolService,
context: ToolContext,
) -> None:
"""Remember should store procedures in procedural memory."""
with patch("app.services.memory.mcp.service.ProceduralMemory") as mock_procedural_cls:
mock_procedure = MagicMock()
mock_procedure.id = uuid4()
mock_procedural = AsyncMock()
mock_procedural.record_procedure = AsyncMock(return_value=mock_procedure)
mock_procedural_cls.create = AsyncMock(return_value=mock_procedural)
result = await service.execute_tool(
tool_name="remember",
arguments={
"memory_type": "procedural",
"content": "File creation procedure",
"trigger": "When creating a new file",
"steps": [
{"action": "check_exists"},
{"action": "create"},
],
},
context=context,
)
assert result.success is True
assert result.data["memory_type"] == "procedural"
assert "procedure_id" in result.data
assert result.data["steps_count"] == 2
@patch("app.services.memory.mcp.service.EpisodicMemory")
@patch("app.services.memory.mcp.service.SemanticMemory")
async def test_recall_from_multiple_types(
self,
mock_semantic_cls: MagicMock,
mock_episodic_cls: MagicMock,
service: MemoryToolService,
context: ToolContext,
) -> None:
"""Recall should search across multiple memory types."""
# Mock episodic
mock_episode = MagicMock()
mock_episode.id = uuid4()
mock_episode.task_description = "Test episode"
mock_episode.outcome = Outcome.SUCCESS
mock_episode.occurred_at = datetime.now(UTC)
mock_episode.importance_score = 0.9
mock_episodic = AsyncMock()
mock_episodic.search_similar = AsyncMock(return_value=[mock_episode])
mock_episodic_cls.create = AsyncMock(return_value=mock_episodic)
# Mock semantic
mock_fact = MagicMock()
mock_fact.id = uuid4()
mock_fact.subject = "User"
mock_fact.predicate = "prefers"
mock_fact.object = "dark mode"
mock_fact.confidence = 0.8
mock_semantic = AsyncMock()
mock_semantic.search_facts = AsyncMock(return_value=[mock_fact])
mock_semantic_cls.create = AsyncMock(return_value=mock_semantic)
result = await service.execute_tool(
tool_name="recall",
arguments={
"query": "user preferences",
"memory_types": ["episodic", "semantic"],
"limit": 10,
},
context=context,
)
assert result.success is True
assert result.data["total_results"] == 2
assert len(result.data["results"]) == 2
@patch("app.services.memory.mcp.service.WorkingMemory")
async def test_forget_working_memory(
self,
mock_working_cls: MagicMock,
service: MemoryToolService,
context: ToolContext,
) -> None:
"""Forget should delete from working memory."""
mock_working = AsyncMock()
mock_working.delete = AsyncMock(return_value=True)
mock_working_cls.for_session = AsyncMock(return_value=mock_working)
result = await service.execute_tool(
tool_name="forget",
arguments={
"memory_type": "working",
"key": "temp_key",
},
context=context,
)
assert result.success is True
assert result.data["deleted"] is True
assert result.data["deleted_count"] == 1
async def test_forget_pattern_requires_confirm(
self,
service: MemoryToolService,
context: ToolContext,
) -> None:
"""Pattern deletion should require confirmation."""
with patch("app.services.memory.mcp.service.WorkingMemory") as mock_working_cls:
mock_working = AsyncMock()
mock_working.list_keys = AsyncMock(return_value=["cache_1", "cache_2"])
mock_working_cls.for_session = AsyncMock(return_value=mock_working)
result = await service.execute_tool(
tool_name="forget",
arguments={
"memory_type": "working",
"pattern": "cache_*",
"confirm_bulk": False,
},
context=context,
)
assert result.success is False
assert "confirm_bulk" in result.error.lower()
@patch("app.services.memory.mcp.service.EpisodicMemory")
async def test_reflect_recent_patterns(
self,
mock_episodic_cls: MagicMock,
service: MemoryToolService,
context: ToolContext,
) -> None:
"""Reflect should analyze recent patterns."""
# Create mock episodes
mock_episodes = []
for i in range(5):
ep = MagicMock()
ep.id = uuid4()
ep.task_type = "code_review" if i % 2 == 0 else "deployment"
ep.outcome = Outcome.SUCCESS if i < 3 else Outcome.FAILURE
ep.task_description = f"Episode {i}"
ep.lessons_learned = None
ep.occurred_at = datetime.now(UTC)
mock_episodes.append(ep)
mock_episodic = AsyncMock()
mock_episodic.get_recent = AsyncMock(return_value=mock_episodes)
mock_episodic_cls.create = AsyncMock(return_value=mock_episodic)
result = await service.execute_tool(
tool_name="reflect",
arguments={
"analysis_type": "recent_patterns",
"depth": 3,
},
context=context,
)
assert result.success is True
assert result.data["analysis_type"] == "recent_patterns"
assert result.data["total_episodes"] == 5
assert "top_task_types" in result.data
assert "outcome_distribution" in result.data
@patch("app.services.memory.mcp.service.EpisodicMemory")
async def test_reflect_success_factors(
self,
mock_episodic_cls: MagicMock,
service: MemoryToolService,
context: ToolContext,
) -> None:
"""Reflect should analyze success factors."""
mock_episodes = []
for i in range(10):
ep = MagicMock()
ep.id = uuid4()
ep.task_type = "code_review"
ep.outcome = Outcome.SUCCESS if i < 8 else Outcome.FAILURE
ep.task_description = f"Episode {i}"
ep.lessons_learned = "Learned something" if i < 3 else None
ep.occurred_at = datetime.now(UTC)
mock_episodes.append(ep)
mock_episodic = AsyncMock()
mock_episodic.get_recent = AsyncMock(return_value=mock_episodes)
mock_episodic_cls.create = AsyncMock(return_value=mock_episodic)
result = await service.execute_tool(
tool_name="reflect",
arguments={
"analysis_type": "success_factors",
"include_examples": True,
},
context=context,
)
assert result.success is True
assert result.data["analysis_type"] == "success_factors"
assert result.data["overall_success_rate"] == 0.8
@patch("app.services.memory.mcp.service.EpisodicMemory")
@patch("app.services.memory.mcp.service.SemanticMemory")
@patch("app.services.memory.mcp.service.ProceduralMemory")
@patch("app.services.memory.mcp.service.WorkingMemory")
async def test_get_memory_stats(
self,
mock_working_cls: MagicMock,
mock_procedural_cls: MagicMock,
mock_semantic_cls: MagicMock,
mock_episodic_cls: MagicMock,
service: MemoryToolService,
context: ToolContext,
) -> None:
"""Get memory stats should return statistics."""
# Setup mocks
mock_working = AsyncMock()
mock_working.list_keys = AsyncMock(return_value=["key1", "key2"])
mock_working_cls.for_session = AsyncMock(return_value=mock_working)
mock_episodic = AsyncMock()
mock_episodic.get_recent = AsyncMock(return_value=[MagicMock() for _ in range(10)])
mock_episodic_cls.create = AsyncMock(return_value=mock_episodic)
mock_semantic = AsyncMock()
mock_semantic.search_facts = AsyncMock(return_value=[MagicMock() for _ in range(5)])
mock_semantic_cls.create = AsyncMock(return_value=mock_semantic)
mock_procedural = AsyncMock()
mock_procedural.find_matching = AsyncMock(return_value=[MagicMock() for _ in range(3)])
mock_procedural_cls.create = AsyncMock(return_value=mock_procedural)
result = await service.execute_tool(
tool_name="get_memory_stats",
arguments={
"include_breakdown": True,
"include_recent_activity": False,
},
context=context,
)
assert result.success is True
assert "breakdown" in result.data
breakdown = result.data["breakdown"]
assert breakdown["working"] == 2
assert breakdown["episodic"] == 10
assert breakdown["semantic"] == 5
assert breakdown["procedural"] == 3
assert breakdown["total"] == 20
@patch("app.services.memory.mcp.service.ProceduralMemory")
async def test_search_procedures(
self,
mock_procedural_cls: MagicMock,
service: MemoryToolService,
context: ToolContext,
) -> None:
"""Search procedures should return matching procedures."""
mock_proc = MagicMock()
mock_proc.id = uuid4()
mock_proc.name = "Deployment procedure"
mock_proc.description = "How to deploy"
mock_proc.trigger = "When deploying"
mock_proc.success_rate = 0.9
mock_proc.execution_count = 10
mock_proc.steps = [{"action": "deploy"}]
mock_procedural = AsyncMock()
mock_procedural.find_matching = AsyncMock(return_value=[mock_proc])
mock_procedural_cls.create = AsyncMock(return_value=mock_procedural)
result = await service.execute_tool(
tool_name="search_procedures",
arguments={
"trigger": "Deploying to production",
"min_success_rate": 0.8,
"include_steps": True,
},
context=context,
)
assert result.success is True
assert result.data["procedures_found"] == 1
proc = result.data["procedures"][0]
assert proc["name"] == "Deployment procedure"
assert "steps" in proc
async def test_record_outcome(
    self,
    service: MemoryToolService,
    context: ToolContext,
) -> None:
    """record_outcome should persist an episode and echo the outcome back."""
    with (
        patch("app.services.memory.mcp.service.EpisodicMemory") as episodic_cls,
        patch("app.services.memory.mcp.service.ProceduralMemory") as procedural_cls,
    ):
        # Episodic backend hands back a stored episode with an id.
        fake_episode = MagicMock()
        fake_episode.id = uuid4()
        episodic_memory = AsyncMock()
        episodic_memory.record_episode = AsyncMock(return_value=fake_episode)
        episodic_cls.create = AsyncMock(return_value=episodic_memory)

        # Procedural backend just accepts the outcome update.
        procedural_memory = AsyncMock()
        procedural_memory.record_outcome = AsyncMock()
        procedural_cls.create = AsyncMock(return_value=procedural_memory)

        result = await service.execute_tool(
            tool_name="record_outcome",
            arguments={
                "task_type": "code_review",
                "outcome": "success",
                "lessons_learned": "Breaking changes caught early",
                "duration_seconds": 120.5,
            },
            context=context,
        )

        assert result.success is True
        assert result.data["recorded"] is True
        assert result.data["outcome"] == "success"
        assert "episode_id" in result.data
class TestGetMemoryToolService:
    """Tests for the get_memory_tool_service factory function."""

    async def test_creates_service(self) -> None:
        """The factory should build a MemoryToolService from a session."""
        session = make_mock_session()
        built = await get_memory_tool_service(session)
        assert isinstance(built, MemoryToolService)

    async def test_accepts_embedding_generator(self) -> None:
        """A supplied embedding generator should be wired onto the service."""
        session = make_mock_session()
        generator = MagicMock()
        built = await get_memory_tool_service(session, generator)
        assert built._embedding_generator is generator

View File

@@ -0,0 +1,420 @@
# tests/unit/services/memory/mcp/test_tools.py
"""Tests for MCP tool definitions."""
import pytest
from pydantic import ValidationError
from app.services.memory.mcp.tools import (
MEMORY_TOOL_DEFINITIONS,
AnalysisType,
ForgetArgs,
GetMemoryStatsArgs,
MemoryToolDefinition,
MemoryType,
OutcomeType,
RecallArgs,
RecordOutcomeArgs,
ReflectArgs,
RememberArgs,
SearchProceduresArgs,
get_all_tool_schemas,
get_tool_definition,
)
class TestMemoryType:
    """Tests for the MemoryType enum."""

    def test_all_types_defined(self) -> None:
        """Each of the four memory types should exist with its string value."""
        expected_values = {
            MemoryType.WORKING: "working",
            MemoryType.EPISODIC: "episodic",
            MemoryType.SEMANTIC: "semantic",
            MemoryType.PROCEDURAL: "procedural",
        }
        for member, value in expected_values.items():
            assert member == value

    def test_enum_values(self) -> None:
        """Members should round-trip through their string values."""
        assert MemoryType.WORKING.value == "working"
        assert MemoryType("episodic") == MemoryType.EPISODIC
class TestAnalysisType:
    """Tests for the AnalysisType enum."""

    def test_all_types_defined(self) -> None:
        """Every analysis type should exist with the expected string value."""
        pairs = [
            (AnalysisType.RECENT_PATTERNS, "recent_patterns"),
            (AnalysisType.SUCCESS_FACTORS, "success_factors"),
            (AnalysisType.FAILURE_PATTERNS, "failure_patterns"),
            (AnalysisType.COMMON_PROCEDURES, "common_procedures"),
            (AnalysisType.LEARNING_PROGRESS, "learning_progress"),
        ]
        for member, expected in pairs:
            assert member == expected
class TestOutcomeType:
    """Tests for the OutcomeType enum."""

    def test_all_outcomes_defined(self) -> None:
        """Every outcome type should exist with the expected string value."""
        pairs = [
            (OutcomeType.SUCCESS, "success"),
            (OutcomeType.PARTIAL, "partial"),
            (OutcomeType.FAILURE, "failure"),
            (OutcomeType.ABANDONED, "abandoned"),
        ]
        for member, expected in pairs:
            assert member == expected
class TestRememberArgs:
    """Validation tests for the RememberArgs schema."""

    def test_valid_working_memory_args(self) -> None:
        """A working-memory payload with key and TTL should validate."""
        parsed = RememberArgs(
            memory_type=MemoryType.WORKING,
            content="Test content",
            key="test_key",
            ttl_seconds=3600,
        )
        assert parsed.memory_type == MemoryType.WORKING
        assert parsed.key == "test_key"
        assert parsed.ttl_seconds == 3600

    def test_valid_semantic_args(self) -> None:
        """A semantic subject/predicate/object triple should validate."""
        parsed = RememberArgs(
            memory_type=MemoryType.SEMANTIC,
            content="User prefers dark mode",
            subject="User",
            predicate="prefers",
            object_value="dark mode",
        )
        assert parsed.subject == "User"
        assert parsed.predicate == "prefers"
        assert parsed.object_value == "dark mode"

    def test_valid_procedural_args(self) -> None:
        """A procedure with trigger and step list should validate."""
        parsed = RememberArgs(
            memory_type=MemoryType.PROCEDURAL,
            content="File creation procedure",
            trigger="When creating a new file",
            steps=[{"action": "check_exists"}, {"action": "create"}],
        )
        assert parsed.trigger == "When creating a new file"
        assert len(parsed.steps) == 2

    def test_importance_validation(self) -> None:
        """Importance is accepted inside [0, 1] and rejected outside it."""
        parsed = RememberArgs(
            memory_type=MemoryType.WORKING,
            content="Test",
            importance=0.8,
        )
        assert parsed.importance == 0.8
        # Both an above-range and a below-range value must fail validation.
        for out_of_range in (1.5, -0.1):
            with pytest.raises(ValidationError):
                RememberArgs(
                    memory_type=MemoryType.WORKING,
                    content="Test",
                    importance=out_of_range,
                )

    def test_content_required(self) -> None:
        """An empty content string must be rejected."""
        with pytest.raises(ValidationError):
            RememberArgs(memory_type=MemoryType.WORKING, content="")

    def test_ttl_validation(self) -> None:
        """TTL values outside the allowed window must be rejected."""
        # Zero is below the minimum; 31 days exceeds the 30-day cap.
        for bad_ttl in (0, 86400 * 31):
            with pytest.raises(ValidationError):
                RememberArgs(
                    memory_type=MemoryType.WORKING,
                    content="Test",
                    ttl_seconds=bad_ttl,
                )

    def test_default_values(self) -> None:
        """Omitted optional fields should take their documented defaults."""
        parsed = RememberArgs(memory_type=MemoryType.WORKING, content="Test")
        assert parsed.importance == 0.5
        assert parsed.ttl_seconds is None
        assert parsed.metadata == {}
        assert parsed.key is None
class TestRecallArgs:
    """Validation tests for the RecallArgs schema."""

    def test_valid_args(self) -> None:
        """A query with explicit memory types and limit should validate."""
        parsed = RecallArgs(
            query="authentication errors",
            memory_types=[MemoryType.EPISODIC, MemoryType.SEMANTIC],
            limit=10,
        )
        assert parsed.query == "authentication errors"
        assert len(parsed.memory_types) == 2
        assert parsed.limit == 10

    def test_default_memory_types(self) -> None:
        """Episodic and semantic memory should be searched by default."""
        parsed = RecallArgs(query="test query")
        assert MemoryType.EPISODIC in parsed.memory_types
        assert MemoryType.SEMANTIC in parsed.memory_types

    def test_limit_validation(self) -> None:
        """Limits of 0 and 101 fall outside the allowed 1-100 range."""
        for bad_limit in (0, 101):
            with pytest.raises(ValidationError):
                RecallArgs(query="test", limit=bad_limit)

    def test_min_relevance_validation(self) -> None:
        """The relevance threshold must lie within [0, 1]."""
        parsed = RecallArgs(query="test", min_relevance=0.5)
        assert parsed.min_relevance == 0.5
        with pytest.raises(ValidationError):
            RecallArgs(query="test", min_relevance=1.5)
class TestForgetArgs:
    """Validation tests for the ForgetArgs schema."""

    def test_valid_key_deletion(self) -> None:
        """Deleting a single entry by key should validate."""
        parsed = ForgetArgs(memory_type=MemoryType.WORKING, key="temp_key")
        assert parsed.memory_type == MemoryType.WORKING
        assert parsed.key == "temp_key"

    def test_valid_id_deletion(self) -> None:
        """Deleting a single entry by memory id should validate."""
        parsed = ForgetArgs(
            memory_type=MemoryType.EPISODIC,
            memory_id="12345678-1234-1234-1234-123456789012",
        )
        assert parsed.memory_id is not None

    def test_pattern_deletion_requires_confirm(self) -> None:
        """A pattern without confirm_bulk still parses; enforcement is the service's job."""
        parsed = ForgetArgs(
            memory_type=MemoryType.WORKING,
            pattern="cache_*",
            confirm_bulk=False,
        )
        assert parsed.pattern == "cache_*"
        assert parsed.confirm_bulk is False
class TestReflectArgs:
    """Validation tests for the ReflectArgs schema."""

    def test_valid_args(self) -> None:
        """An analysis type with explicit depth should validate."""
        parsed = ReflectArgs(
            analysis_type=AnalysisType.SUCCESS_FACTORS,
            depth=3,
        )
        assert parsed.analysis_type == AnalysisType.SUCCESS_FACTORS
        assert parsed.depth == 3

    def test_depth_validation(self) -> None:
        """Depths of 0 and 6 fall outside the allowed 1-5 range."""
        for bad_depth in (0, 6):
            with pytest.raises(ValidationError):
                ReflectArgs(
                    analysis_type=AnalysisType.SUCCESS_FACTORS,
                    depth=bad_depth,
                )

    def test_default_values(self) -> None:
        """Omitted options should take their documented defaults."""
        parsed = ReflectArgs(analysis_type=AnalysisType.RECENT_PATTERNS)
        assert parsed.depth == 3
        assert parsed.include_examples is True
        assert parsed.max_items == 10
class TestGetMemoryStatsArgs:
    """Validation tests for the GetMemoryStatsArgs schema."""

    def test_valid_args(self) -> None:
        """Explicit flags and a 30-day window should validate."""
        parsed = GetMemoryStatsArgs(
            include_breakdown=True,
            include_recent_activity=True,
            time_range_days=30,
        )
        assert parsed.include_breakdown is True
        assert parsed.time_range_days == 30

    def test_time_range_validation(self) -> None:
        """Windows of 0 and 91 days fall outside the allowed 1-90 range."""
        for bad_days in (0, 91):
            with pytest.raises(ValidationError):
                GetMemoryStatsArgs(time_range_days=bad_days)
class TestSearchProceduresArgs:
    """Validation tests for the SearchProceduresArgs schema."""

    def test_valid_args(self) -> None:
        """A trigger with a success-rate floor and limit should validate."""
        parsed = SearchProceduresArgs(
            trigger="Deploying to production",
            min_success_rate=0.8,
            limit=5,
        )
        assert parsed.trigger == "Deploying to production"
        assert parsed.min_success_rate == 0.8

    def test_trigger_required(self) -> None:
        """An empty trigger must be rejected."""
        with pytest.raises(ValidationError):
            SearchProceduresArgs(trigger="")

    def test_success_rate_validation(self) -> None:
        """A success-rate floor above 1 must be rejected."""
        with pytest.raises(ValidationError):
            SearchProceduresArgs(trigger="test", min_success_rate=1.5)
class TestRecordOutcomeArgs:
    """Validation tests for the RecordOutcomeArgs schema."""

    def test_valid_success_args(self) -> None:
        """A success outcome with lessons learned should validate."""
        parsed = RecordOutcomeArgs(
            task_type="code_review",
            outcome=OutcomeType.SUCCESS,
            lessons_learned="Breaking changes caught early",
        )
        assert parsed.task_type == "code_review"
        assert parsed.outcome == OutcomeType.SUCCESS

    def test_valid_failure_args(self) -> None:
        """A failure outcome with error details and duration should validate."""
        parsed = RecordOutcomeArgs(
            task_type="deployment",
            outcome=OutcomeType.FAILURE,
            error_details="Database migration timeout",
            duration_seconds=120.5,
        )
        assert parsed.outcome == OutcomeType.FAILURE
        assert parsed.error_details is not None

    def test_task_type_required(self) -> None:
        """An empty task type must be rejected."""
        with pytest.raises(ValidationError):
            RecordOutcomeArgs(task_type="", outcome=OutcomeType.SUCCESS)
class TestMemoryToolDefinition:
    """Tests for the MemoryToolDefinition wrapper class."""

    def test_to_mcp_format(self) -> None:
        """to_mcp_format should expose name, description and an input schema."""
        definition = MemoryToolDefinition(
            name="test_tool",
            description="A test tool",
            args_schema=RememberArgs,
        )
        payload = definition.to_mcp_format()
        assert payload["name"] == "test_tool"
        assert payload["description"] == "A test tool"
        assert "inputSchema" in payload
        assert "properties" in payload["inputSchema"]

    def test_validate_args(self) -> None:
        """validate_args should parse good input and raise on bad input."""
        definition = MemoryToolDefinition(
            name="remember",
            description="Store in memory",
            args_schema=RememberArgs,
        )
        # Well-formed arguments come back as a typed RememberArgs instance.
        parsed = definition.validate_args(
            {"memory_type": "working", "content": "Test content"}
        )
        assert isinstance(parsed, RememberArgs)
        # Malformed arguments are rejected by the schema.
        with pytest.raises(ValidationError):
            definition.validate_args({"memory_type": "invalid"})
class TestToolDefinitions:
"""Tests for the tool definitions dictionary."""
def test_all_tools_defined(self) -> None:
    """Every advertised memory tool should be registered as a definition."""
    for tool_name in (
        "remember",
        "recall",
        "forget",
        "reflect",
        "get_memory_stats",
        "search_procedures",
        "record_outcome",
    ):
        assert tool_name in MEMORY_TOOL_DEFINITIONS
        assert isinstance(MEMORY_TOOL_DEFINITIONS[tool_name], MemoryToolDefinition)
def test_get_tool_definition(self) -> None:
    """Known names resolve to their tool; unknown names resolve to None."""
    remember_tool = get_tool_definition("remember")
    assert remember_tool is not None
    assert remember_tool.name == "remember"
    assert get_tool_definition("unknown_tool") is None
def test_get_all_tool_schemas(self) -> None:
    """All seven tools should be exported as MCP-formatted schemas."""
    schemas = get_all_tool_schemas()
    assert len(schemas) == 7
    for entry in schemas:
        # Each schema must carry the three top-level MCP fields.
        for required_field in ("name", "description", "inputSchema"):
            assert required_field in entry
def test_tool_descriptions_not_empty(self) -> None:
    """Each tool needs a substantial description for the agent to read."""
    for name, tool in MEMORY_TOOL_DEFINITIONS.items():
        description = tool.description
        assert description, f"Tool {name} has empty description"
        assert len(description) > 50, f"Tool {name} description too short"
def test_input_schemas_have_properties(self) -> None:
    """All tool schemas should have properties defined."""
    # Each MCP schema must expose a JSON-Schema "properties" map so
    # clients can discover the tool's argument names and types.
    for name, tool in MEMORY_TOOL_DEFINITIONS.items():
        schema = tool.to_mcp_format()
        assert "properties" in schema["inputSchema"], f"Tool {name} missing properties"