forked from cardosofelipe/fast-next-template
feat(memory): implement MCP tools for agent memory operations (#96)
Add MCP-compatible tools that expose memory operations to agents: Tools implemented: - remember: Store data in working, episodic, semantic, or procedural memory - recall: Retrieve memories by query across multiple memory types - forget: Delete specific keys or bulk delete by pattern - reflect: Analyze patterns in recent episodes (success/failure factors) - get_memory_stats: Return usage statistics and breakdowns - search_procedures: Find procedures matching trigger patterns - record_outcome: Record task outcomes and update procedure success rates Key components: - tools.py: Pydantic schemas for tool argument validation with comprehensive field constraints (importance 0-1, TTL limits, limit ranges) - service.py: MemoryToolService coordinating memory type operations with proper scoping via ToolContext (project_id, agent_instance_id, session_id) - Lazy initialization of memory services (WorkingMemory, EpisodicMemory, SemanticMemory, ProceduralMemory) Test coverage: - 60 tests covering tool definitions, argument validation, and service execution paths - Mock-based tests for all memory type interactions 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
@@ -133,4 +133,6 @@ __all__ = [
|
||||
"get_default_settings",
|
||||
"get_memory_settings",
|
||||
"reset_memory_settings",
|
||||
# MCP Tools - lazy import to avoid circular dependencies
|
||||
# Import directly: from app.services.memory.mcp import MemoryToolService
|
||||
]
|
||||
|
||||
40
backend/app/services/memory/mcp/__init__.py
Normal file
40
backend/app/services/memory/mcp/__init__.py
Normal file
@@ -0,0 +1,40 @@
|
||||
# app/services/memory/mcp/__init__.py
|
||||
"""
|
||||
MCP Tools for Agent Memory System.
|
||||
|
||||
Exposes memory operations as MCP-compatible tools that agents can invoke:
|
||||
- remember: Store data in memory
|
||||
- recall: Retrieve from memory
|
||||
- forget: Remove from memory
|
||||
- reflect: Analyze patterns
|
||||
- get_memory_stats: Usage statistics
|
||||
- search_procedures: Find relevant procedures
|
||||
- record_outcome: Record task success/failure
|
||||
"""
|
||||
|
||||
from .service import MemoryToolService, get_memory_tool_service
|
||||
from .tools import (
|
||||
MEMORY_TOOL_DEFINITIONS,
|
||||
ForgetArgs,
|
||||
GetMemoryStatsArgs,
|
||||
MemoryToolDefinition,
|
||||
RecallArgs,
|
||||
RecordOutcomeArgs,
|
||||
ReflectArgs,
|
||||
RememberArgs,
|
||||
SearchProceduresArgs,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"MEMORY_TOOL_DEFINITIONS",
|
||||
"ForgetArgs",
|
||||
"GetMemoryStatsArgs",
|
||||
"MemoryToolDefinition",
|
||||
"MemoryToolService",
|
||||
"RecallArgs",
|
||||
"RecordOutcomeArgs",
|
||||
"ReflectArgs",
|
||||
"RememberArgs",
|
||||
"SearchProceduresArgs",
|
||||
"get_memory_tool_service",
|
||||
]
|
||||
1042
backend/app/services/memory/mcp/service.py
Normal file
1042
backend/app/services/memory/mcp/service.py
Normal file
File diff suppressed because it is too large
Load Diff
491
backend/app/services/memory/mcp/tools.py
Normal file
491
backend/app/services/memory/mcp/tools.py
Normal file
@@ -0,0 +1,491 @@
|
||||
# app/services/memory/mcp/tools.py
|
||||
"""
|
||||
MCP Tool Definitions for Agent Memory System.
|
||||
|
||||
Defines the schema and metadata for memory-related MCP tools.
|
||||
These tools are invoked by AI agents to interact with the memory system.
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class MemoryType(str, Enum):
|
||||
"""Types of memory for storage operations."""
|
||||
|
||||
WORKING = "working"
|
||||
EPISODIC = "episodic"
|
||||
SEMANTIC = "semantic"
|
||||
PROCEDURAL = "procedural"
|
||||
|
||||
|
||||
class AnalysisType(str, Enum):
|
||||
"""Types of pattern analysis for the reflect tool."""
|
||||
|
||||
RECENT_PATTERNS = "recent_patterns"
|
||||
SUCCESS_FACTORS = "success_factors"
|
||||
FAILURE_PATTERNS = "failure_patterns"
|
||||
COMMON_PROCEDURES = "common_procedures"
|
||||
LEARNING_PROGRESS = "learning_progress"
|
||||
|
||||
|
||||
class OutcomeType(str, Enum):
|
||||
"""Outcome types for record_outcome tool."""
|
||||
|
||||
SUCCESS = "success"
|
||||
PARTIAL = "partial"
|
||||
FAILURE = "failure"
|
||||
ABANDONED = "abandoned"
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Tool Argument Schemas (Pydantic models for validation)
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class RememberArgs(BaseModel):
|
||||
"""Arguments for the 'remember' tool."""
|
||||
|
||||
memory_type: MemoryType = Field(
|
||||
...,
|
||||
description="Type of memory to store in: working, episodic, semantic, or procedural",
|
||||
)
|
||||
content: str = Field(
|
||||
...,
|
||||
description="The content to remember. Can be text, facts, or procedure steps.",
|
||||
min_length=1,
|
||||
max_length=10000,
|
||||
)
|
||||
key: str | None = Field(
|
||||
None,
|
||||
description="Optional key for working memory entries. Required for working memory type.",
|
||||
max_length=256,
|
||||
)
|
||||
importance: float = Field(
|
||||
0.5,
|
||||
description="Importance score from 0.0 (low) to 1.0 (critical)",
|
||||
ge=0.0,
|
||||
le=1.0,
|
||||
)
|
||||
ttl_seconds: int | None = Field(
|
||||
None,
|
||||
description="Time-to-live in seconds for working memory. None for permanent storage.",
|
||||
ge=1,
|
||||
le=86400 * 30, # Max 30 days
|
||||
)
|
||||
metadata: dict[str, Any] = Field(
|
||||
default_factory=dict,
|
||||
description="Additional metadata to store with the memory",
|
||||
)
|
||||
# For semantic memory (facts)
|
||||
subject: str | None = Field(
|
||||
None,
|
||||
description="Subject of the fact (for semantic memory)",
|
||||
max_length=256,
|
||||
)
|
||||
predicate: str | None = Field(
|
||||
None,
|
||||
description="Predicate/relationship (for semantic memory)",
|
||||
max_length=256,
|
||||
)
|
||||
object_value: str | None = Field(
|
||||
None,
|
||||
description="Object of the fact (for semantic memory)",
|
||||
max_length=1000,
|
||||
)
|
||||
# For procedural memory
|
||||
trigger: str | None = Field(
|
||||
None,
|
||||
description="Trigger condition for the procedure (for procedural memory)",
|
||||
max_length=500,
|
||||
)
|
||||
steps: list[dict[str, Any]] | None = Field(
|
||||
None,
|
||||
description="Procedure steps as a list of action dictionaries",
|
||||
)
|
||||
|
||||
|
||||
class RecallArgs(BaseModel):
|
||||
"""Arguments for the 'recall' tool."""
|
||||
|
||||
query: str = Field(
|
||||
...,
|
||||
description="Search query to find relevant memories",
|
||||
min_length=1,
|
||||
max_length=1000,
|
||||
)
|
||||
memory_types: list[MemoryType] = Field(
|
||||
default_factory=lambda: [MemoryType.EPISODIC, MemoryType.SEMANTIC],
|
||||
description="Types of memory to search in",
|
||||
)
|
||||
limit: int = Field(
|
||||
10,
|
||||
description="Maximum number of results to return",
|
||||
ge=1,
|
||||
le=100,
|
||||
)
|
||||
min_relevance: float = Field(
|
||||
0.0,
|
||||
description="Minimum relevance score (0.0-1.0) for results",
|
||||
ge=0.0,
|
||||
le=1.0,
|
||||
)
|
||||
filters: dict[str, Any] = Field(
|
||||
default_factory=dict,
|
||||
description="Additional filters (e.g., outcome, task_type, date range)",
|
||||
)
|
||||
include_context: bool = Field(
|
||||
True,
|
||||
description="Whether to include surrounding context in results",
|
||||
)
|
||||
|
||||
|
||||
class ForgetArgs(BaseModel):
|
||||
"""Arguments for the 'forget' tool."""
|
||||
|
||||
memory_type: MemoryType = Field(
|
||||
...,
|
||||
description="Type of memory to remove from",
|
||||
)
|
||||
key: str | None = Field(
|
||||
None,
|
||||
description="Key to remove (for working memory)",
|
||||
max_length=256,
|
||||
)
|
||||
memory_id: str | None = Field(
|
||||
None,
|
||||
description="Specific memory ID to remove (for episodic/semantic/procedural)",
|
||||
)
|
||||
pattern: str | None = Field(
|
||||
None,
|
||||
description="Pattern to match for bulk removal (use with caution)",
|
||||
max_length=500,
|
||||
)
|
||||
confirm_bulk: bool = Field(
|
||||
False,
|
||||
description="Must be True to confirm bulk deletion when using pattern",
|
||||
)
|
||||
|
||||
|
||||
class ReflectArgs(BaseModel):
|
||||
"""Arguments for the 'reflect' tool."""
|
||||
|
||||
analysis_type: AnalysisType = Field(
|
||||
...,
|
||||
description="Type of pattern analysis to perform",
|
||||
)
|
||||
scope: str | None = Field(
|
||||
None,
|
||||
description="Optional scope to limit analysis (e.g., task_type, time range)",
|
||||
max_length=500,
|
||||
)
|
||||
depth: int = Field(
|
||||
3,
|
||||
description="Depth of analysis (1=surface, 5=deep)",
|
||||
ge=1,
|
||||
le=5,
|
||||
)
|
||||
include_examples: bool = Field(
|
||||
True,
|
||||
description="Whether to include example memories in the analysis",
|
||||
)
|
||||
max_items: int = Field(
|
||||
10,
|
||||
description="Maximum number of patterns/examples to analyze",
|
||||
ge=1,
|
||||
le=50,
|
||||
)
|
||||
|
||||
|
||||
class GetMemoryStatsArgs(BaseModel):
|
||||
"""Arguments for the 'get_memory_stats' tool."""
|
||||
|
||||
include_breakdown: bool = Field(
|
||||
True,
|
||||
description="Include breakdown by memory type",
|
||||
)
|
||||
include_recent_activity: bool = Field(
|
||||
True,
|
||||
description="Include recent memory activity summary",
|
||||
)
|
||||
time_range_days: int = Field(
|
||||
7,
|
||||
description="Time range for activity analysis in days",
|
||||
ge=1,
|
||||
le=90,
|
||||
)
|
||||
|
||||
|
||||
class SearchProceduresArgs(BaseModel):
|
||||
"""Arguments for the 'search_procedures' tool."""
|
||||
|
||||
trigger: str = Field(
|
||||
...,
|
||||
description="Trigger or situation to find procedures for",
|
||||
min_length=1,
|
||||
max_length=500,
|
||||
)
|
||||
task_type: str | None = Field(
|
||||
None,
|
||||
description="Optional task type to filter procedures",
|
||||
max_length=100,
|
||||
)
|
||||
min_success_rate: float = Field(
|
||||
0.5,
|
||||
description="Minimum success rate (0.0-1.0) for returned procedures",
|
||||
ge=0.0,
|
||||
le=1.0,
|
||||
)
|
||||
limit: int = Field(
|
||||
5,
|
||||
description="Maximum number of procedures to return",
|
||||
ge=1,
|
||||
le=20,
|
||||
)
|
||||
include_steps: bool = Field(
|
||||
True,
|
||||
description="Whether to include detailed steps in the response",
|
||||
)
|
||||
|
||||
|
||||
class RecordOutcomeArgs(BaseModel):
|
||||
"""Arguments for the 'record_outcome' tool."""
|
||||
|
||||
task_type: str = Field(
|
||||
...,
|
||||
description="Type of task that was executed",
|
||||
min_length=1,
|
||||
max_length=100,
|
||||
)
|
||||
outcome: OutcomeType = Field(
|
||||
...,
|
||||
description="Outcome of the task execution",
|
||||
)
|
||||
procedure_id: str | None = Field(
|
||||
None,
|
||||
description="ID of the procedure that was followed (if any)",
|
||||
)
|
||||
context: dict[str, Any] = Field(
|
||||
default_factory=dict,
|
||||
description="Context in which the task was executed",
|
||||
)
|
||||
lessons_learned: str | None = Field(
|
||||
None,
|
||||
description="What was learned from this execution",
|
||||
max_length=2000,
|
||||
)
|
||||
duration_seconds: float | None = Field(
|
||||
None,
|
||||
description="How long the task took to execute",
|
||||
ge=0.0,
|
||||
)
|
||||
error_details: str | None = Field(
|
||||
None,
|
||||
description="Details about any errors encountered (for failures)",
|
||||
max_length=2000,
|
||||
)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Tool Definition Structure
|
||||
# ============================================================================
|
||||
|
||||
|
||||
@dataclass
|
||||
class MemoryToolDefinition:
|
||||
"""Definition of an MCP tool for the memory system."""
|
||||
|
||||
name: str
|
||||
description: str
|
||||
args_schema: type[BaseModel]
|
||||
input_schema: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
"""Generate input schema from Pydantic model."""
|
||||
if not self.input_schema:
|
||||
self.input_schema = self.args_schema.model_json_schema()
|
||||
|
||||
def to_mcp_format(self) -> dict[str, Any]:
|
||||
"""Convert to MCP tool format."""
|
||||
return {
|
||||
"name": self.name,
|
||||
"description": self.description,
|
||||
"inputSchema": self.input_schema,
|
||||
}
|
||||
|
||||
def validate_args(self, args: dict[str, Any]) -> BaseModel:
|
||||
"""Validate and parse arguments."""
|
||||
return self.args_schema.model_validate(args)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Tool Definitions
|
||||
# ============================================================================
|
||||
|
||||
|
||||
REMEMBER_TOOL = MemoryToolDefinition(
|
||||
name="remember",
|
||||
description="""Store information in the agent's memory system.
|
||||
|
||||
Use this tool to:
|
||||
- Store temporary data in working memory (key-value with optional TTL)
|
||||
- Record important events in episodic memory (automatically done on session end)
|
||||
- Store facts/knowledge in semantic memory (subject-predicate-object triples)
|
||||
- Save procedures in procedural memory (trigger conditions and steps)
|
||||
|
||||
Examples:
|
||||
- Working memory: {"memory_type": "working", "key": "current_task", "content": "Implementing auth", "ttl_seconds": 3600}
|
||||
- Semantic fact: {"memory_type": "semantic", "subject": "User", "predicate": "prefers", "object_value": "dark mode", "content": "User preference noted"}
|
||||
- Procedure: {"memory_type": "procedural", "trigger": "When creating a new file", "steps": [{"action": "check_exists"}, {"action": "create"}], "content": "File creation procedure"}
|
||||
""",
|
||||
args_schema=RememberArgs,
|
||||
)
|
||||
|
||||
|
||||
RECALL_TOOL = MemoryToolDefinition(
|
||||
name="recall",
|
||||
description="""Retrieve information from the agent's memory system.
|
||||
|
||||
Use this tool to:
|
||||
- Search for relevant past experiences (episodic)
|
||||
- Look up known facts and knowledge (semantic)
|
||||
- Find applicable procedures for current task (procedural)
|
||||
- Get current session state (working)
|
||||
|
||||
The query supports semantic search - describe what you're looking for in natural language.
|
||||
|
||||
Examples:
|
||||
- {"query": "How did I handle authentication errors before?", "memory_types": ["episodic"]}
|
||||
- {"query": "What are the user's preferences?", "memory_types": ["semantic"], "limit": 5}
|
||||
- {"query": "database connection", "memory_types": ["episodic", "semantic", "procedural"], "filters": {"outcome": "success"}}
|
||||
""",
|
||||
args_schema=RecallArgs,
|
||||
)
|
||||
|
||||
|
||||
FORGET_TOOL = MemoryToolDefinition(
|
||||
name="forget",
|
||||
description="""Remove information from the agent's memory system.
|
||||
|
||||
Use this tool to:
|
||||
- Clear temporary working memory entries
|
||||
- Remove specific memories by ID
|
||||
- Bulk remove memories matching a pattern (requires confirmation)
|
||||
|
||||
WARNING: Deletion is permanent. Use with caution.
|
||||
|
||||
Examples:
|
||||
- Working memory: {"memory_type": "working", "key": "temp_calculation"}
|
||||
- Specific memory: {"memory_type": "episodic", "memory_id": "ep-123"}
|
||||
- Bulk (requires confirm): {"memory_type": "working", "pattern": "cache_*", "confirm_bulk": true}
|
||||
""",
|
||||
args_schema=ForgetArgs,
|
||||
)
|
||||
|
||||
|
||||
REFLECT_TOOL = MemoryToolDefinition(
|
||||
name="reflect",
|
||||
description="""Analyze patterns in the agent's memory to gain insights.
|
||||
|
||||
Use this tool to:
|
||||
- Identify patterns in recent work
|
||||
- Understand what leads to success/failure
|
||||
- Learn from past experiences
|
||||
- Track learning progress over time
|
||||
|
||||
Analysis types:
|
||||
- recent_patterns: What patterns appear in recent work
|
||||
- success_factors: What conditions lead to success
|
||||
- failure_patterns: What causes failures and how to avoid them
|
||||
- common_procedures: Most frequently used procedures
|
||||
- learning_progress: How knowledge has grown over time
|
||||
|
||||
Examples:
|
||||
- {"analysis_type": "success_factors", "scope": "code_review", "depth": 3}
|
||||
- {"analysis_type": "failure_patterns", "include_examples": true, "max_items": 5}
|
||||
""",
|
||||
args_schema=ReflectArgs,
|
||||
)
|
||||
|
||||
|
||||
GET_MEMORY_STATS_TOOL = MemoryToolDefinition(
|
||||
name="get_memory_stats",
|
||||
description="""Get statistics about the agent's memory usage.
|
||||
|
||||
Returns information about:
|
||||
- Total memories stored by type
|
||||
- Storage utilization
|
||||
- Recent activity summary
|
||||
- Memory health indicators
|
||||
|
||||
Use this to understand memory capacity and usage patterns.
|
||||
|
||||
Examples:
|
||||
- {"include_breakdown": true, "include_recent_activity": true}
|
||||
- {"time_range_days": 30, "include_breakdown": true}
|
||||
""",
|
||||
args_schema=GetMemoryStatsArgs,
|
||||
)
|
||||
|
||||
|
||||
SEARCH_PROCEDURES_TOOL = MemoryToolDefinition(
|
||||
name="search_procedures",
|
||||
description="""Find relevant procedures for a given situation.
|
||||
|
||||
Use this tool when you need to:
|
||||
- Find the best way to handle a situation
|
||||
- Look up proven approaches to problems
|
||||
- Get step-by-step guidance for tasks
|
||||
|
||||
Returns procedures ranked by relevance and success rate.
|
||||
|
||||
Examples:
|
||||
- {"trigger": "Deploying to production", "min_success_rate": 0.8}
|
||||
- {"trigger": "Handling merge conflicts", "task_type": "git_operations", "limit": 3}
|
||||
""",
|
||||
args_schema=SearchProceduresArgs,
|
||||
)
|
||||
|
||||
|
||||
RECORD_OUTCOME_TOOL = MemoryToolDefinition(
|
||||
name="record_outcome",
|
||||
description="""Record the outcome of a task execution.
|
||||
|
||||
Use this tool after completing a task to:
|
||||
- Update procedure success/failure rates
|
||||
- Store lessons learned for future reference
|
||||
- Improve procedure recommendations
|
||||
|
||||
This helps the memory system learn from experience.
|
||||
|
||||
Examples:
|
||||
- {"task_type": "code_review", "outcome": "success", "lessons_learned": "Breaking changes caught early"}
|
||||
- {"task_type": "deployment", "outcome": "failure", "error_details": "Database migration timeout", "lessons_learned": "Need to test migrations locally first"}
|
||||
""",
|
||||
args_schema=RecordOutcomeArgs,
|
||||
)
|
||||
|
||||
|
||||
# All tool definitions in a dictionary for easy lookup
|
||||
MEMORY_TOOL_DEFINITIONS: dict[str, MemoryToolDefinition] = {
|
||||
"remember": REMEMBER_TOOL,
|
||||
"recall": RECALL_TOOL,
|
||||
"forget": FORGET_TOOL,
|
||||
"reflect": REFLECT_TOOL,
|
||||
"get_memory_stats": GET_MEMORY_STATS_TOOL,
|
||||
"search_procedures": SEARCH_PROCEDURES_TOOL,
|
||||
"record_outcome": RECORD_OUTCOME_TOOL,
|
||||
}
|
||||
|
||||
|
||||
def get_all_tool_schemas() -> list[dict[str, Any]]:
|
||||
"""Get MCP-formatted schemas for all memory tools."""
|
||||
return [tool.to_mcp_format() for tool in MEMORY_TOOL_DEFINITIONS.values()]
|
||||
|
||||
|
||||
def get_tool_definition(name: str) -> MemoryToolDefinition | None:
|
||||
"""Get a specific tool definition by name."""
|
||||
return MEMORY_TOOL_DEFINITIONS.get(name)
|
||||
2
backend/tests/unit/services/memory/mcp/__init__.py
Normal file
2
backend/tests/unit/services/memory/mcp/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
||||
# tests/unit/services/memory/mcp/__init__.py
|
||||
"""Tests for memory MCP tools."""
|
||||
651
backend/tests/unit/services/memory/mcp/test_service.py
Normal file
651
backend/tests/unit/services/memory/mcp/test_service.py
Normal file
@@ -0,0 +1,651 @@
|
||||
# tests/unit/services/memory/mcp/test_service.py
|
||||
"""Tests for MemoryToolService."""
|
||||
|
||||
from datetime import UTC, datetime
|
||||
from typing import Any
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
from app.services.memory.mcp.service import (
|
||||
MemoryToolService,
|
||||
ToolContext,
|
||||
ToolResult,
|
||||
get_memory_tool_service,
|
||||
)
|
||||
from app.services.memory.mcp.tools import (
|
||||
AnalysisType,
|
||||
MemoryType,
|
||||
OutcomeType,
|
||||
)
|
||||
from app.services.memory.types import Outcome
|
||||
|
||||
pytestmark = pytest.mark.asyncio(loop_scope="function")
|
||||
|
||||
|
||||
def make_context(
|
||||
project_id: UUID | None = None,
|
||||
agent_instance_id: UUID | None = None,
|
||||
session_id: str | None = None,
|
||||
) -> ToolContext:
|
||||
"""Create a test context."""
|
||||
return ToolContext(
|
||||
project_id=project_id or uuid4(),
|
||||
agent_instance_id=agent_instance_id or uuid4(),
|
||||
session_id=session_id or "test-session",
|
||||
)
|
||||
|
||||
|
||||
def make_mock_session() -> AsyncMock:
|
||||
"""Create a mock database session."""
|
||||
session = AsyncMock()
|
||||
session.execute = AsyncMock()
|
||||
session.commit = AsyncMock()
|
||||
session.flush = AsyncMock()
|
||||
return session
|
||||
|
||||
|
||||
class TestToolContext:
|
||||
"""Tests for ToolContext dataclass."""
|
||||
|
||||
def test_context_creation(self) -> None:
|
||||
"""Context should be creatable with required fields."""
|
||||
project_id = uuid4()
|
||||
ctx = ToolContext(project_id=project_id)
|
||||
assert ctx.project_id == project_id
|
||||
assert ctx.agent_instance_id is None
|
||||
assert ctx.session_id is None
|
||||
|
||||
def test_context_with_all_fields(self) -> None:
|
||||
"""Context should accept all optional fields."""
|
||||
project_id = uuid4()
|
||||
agent_id = uuid4()
|
||||
ctx = ToolContext(
|
||||
project_id=project_id,
|
||||
agent_instance_id=agent_id,
|
||||
agent_type_id=uuid4(),
|
||||
session_id="session-123",
|
||||
user_id=uuid4(),
|
||||
)
|
||||
assert ctx.project_id == project_id
|
||||
assert ctx.agent_instance_id == agent_id
|
||||
assert ctx.session_id == "session-123"
|
||||
|
||||
|
||||
class TestToolResult:
|
||||
"""Tests for ToolResult dataclass."""
|
||||
|
||||
def test_success_result(self) -> None:
|
||||
"""Success result should have correct fields."""
|
||||
result = ToolResult(
|
||||
success=True,
|
||||
data={"key": "value"},
|
||||
execution_time_ms=10.5,
|
||||
)
|
||||
assert result.success is True
|
||||
assert result.data == {"key": "value"}
|
||||
assert result.error is None
|
||||
|
||||
def test_error_result(self) -> None:
|
||||
"""Error result should have correct fields."""
|
||||
result = ToolResult(
|
||||
success=False,
|
||||
error="Something went wrong",
|
||||
error_code="VALIDATION_ERROR",
|
||||
)
|
||||
assert result.success is False
|
||||
assert result.error == "Something went wrong"
|
||||
assert result.error_code == "VALIDATION_ERROR"
|
||||
|
||||
def test_to_dict(self) -> None:
|
||||
"""Result should convert to dict correctly."""
|
||||
result = ToolResult(
|
||||
success=True,
|
||||
data={"test": 1},
|
||||
execution_time_ms=5.0,
|
||||
)
|
||||
result_dict = result.to_dict()
|
||||
assert result_dict["success"] is True
|
||||
assert result_dict["data"] == {"test": 1}
|
||||
assert result_dict["execution_time_ms"] == 5.0
|
||||
|
||||
|
||||
class TestMemoryToolService:
|
||||
"""Tests for MemoryToolService."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_session(self) -> AsyncMock:
|
||||
"""Create a mock session."""
|
||||
return make_mock_session()
|
||||
|
||||
@pytest.fixture
|
||||
def service(self, mock_session: AsyncMock) -> MemoryToolService:
|
||||
"""Create a service with mock session."""
|
||||
return MemoryToolService(session=mock_session)
|
||||
|
||||
@pytest.fixture
|
||||
def context(self) -> ToolContext:
|
||||
"""Create a test context."""
|
||||
return make_context()
|
||||
|
||||
async def test_execute_unknown_tool(
|
||||
self,
|
||||
service: MemoryToolService,
|
||||
context: ToolContext,
|
||||
) -> None:
|
||||
"""Unknown tool should return error."""
|
||||
result = await service.execute_tool(
|
||||
tool_name="unknown_tool",
|
||||
arguments={},
|
||||
context=context,
|
||||
)
|
||||
assert result.success is False
|
||||
assert result.error_code == "UNKNOWN_TOOL"
|
||||
|
||||
async def test_execute_with_invalid_args(
|
||||
self,
|
||||
service: MemoryToolService,
|
||||
context: ToolContext,
|
||||
) -> None:
|
||||
"""Invalid arguments should return validation error."""
|
||||
result = await service.execute_tool(
|
||||
tool_name="remember",
|
||||
arguments={"memory_type": "invalid_type"},
|
||||
context=context,
|
||||
)
|
||||
assert result.success is False
|
||||
assert result.error_code == "VALIDATION_ERROR"
|
||||
|
||||
@patch("app.services.memory.mcp.service.WorkingMemory")
|
||||
async def test_remember_working_memory(
|
||||
self,
|
||||
mock_working_cls: MagicMock,
|
||||
service: MemoryToolService,
|
||||
context: ToolContext,
|
||||
) -> None:
|
||||
"""Remember should store in working memory."""
|
||||
# Setup mock
|
||||
mock_working = AsyncMock()
|
||||
mock_working.set = AsyncMock()
|
||||
mock_working_cls.for_session = AsyncMock(return_value=mock_working)
|
||||
|
||||
result = await service.execute_tool(
|
||||
tool_name="remember",
|
||||
arguments={
|
||||
"memory_type": "working",
|
||||
"content": "Test content",
|
||||
"key": "test_key",
|
||||
"ttl_seconds": 3600,
|
||||
},
|
||||
context=context,
|
||||
)
|
||||
|
||||
assert result.success is True
|
||||
assert result.data["stored"] is True
|
||||
assert result.data["memory_type"] == "working"
|
||||
assert result.data["key"] == "test_key"
|
||||
|
||||
async def test_remember_episodic_memory(
|
||||
self,
|
||||
service: MemoryToolService,
|
||||
context: ToolContext,
|
||||
) -> None:
|
||||
"""Remember should store in episodic memory."""
|
||||
with patch("app.services.memory.mcp.service.EpisodicMemory") as mock_episodic_cls:
|
||||
# Setup mock
|
||||
mock_episode = MagicMock()
|
||||
mock_episode.id = uuid4()
|
||||
|
||||
mock_episodic = AsyncMock()
|
||||
mock_episodic.record_episode = AsyncMock(return_value=mock_episode)
|
||||
mock_episodic_cls.create = AsyncMock(return_value=mock_episodic)
|
||||
|
||||
result = await service.execute_tool(
|
||||
tool_name="remember",
|
||||
arguments={
|
||||
"memory_type": "episodic",
|
||||
"content": "Important event happened",
|
||||
"importance": 0.8,
|
||||
},
|
||||
context=context,
|
||||
)
|
||||
|
||||
assert result.success is True
|
||||
assert result.data["stored"] is True
|
||||
assert result.data["memory_type"] == "episodic"
|
||||
assert "episode_id" in result.data
|
||||
|
||||
async def test_remember_working_without_key(
|
||||
self,
|
||||
service: MemoryToolService,
|
||||
context: ToolContext,
|
||||
) -> None:
|
||||
"""Working memory without key should fail."""
|
||||
result = await service.execute_tool(
|
||||
tool_name="remember",
|
||||
arguments={
|
||||
"memory_type": "working",
|
||||
"content": "Test content",
|
||||
},
|
||||
context=context,
|
||||
)
|
||||
|
||||
assert result.success is False
|
||||
assert "key is required" in result.error.lower()
|
||||
|
||||
async def test_remember_working_without_session(
|
||||
self,
|
||||
service: MemoryToolService,
|
||||
) -> None:
|
||||
"""Working memory without session should fail."""
|
||||
context = ToolContext(project_id=uuid4(), session_id=None)
|
||||
|
||||
result = await service.execute_tool(
|
||||
tool_name="remember",
|
||||
arguments={
|
||||
"memory_type": "working",
|
||||
"content": "Test content",
|
||||
"key": "test_key",
|
||||
},
|
||||
context=context,
|
||||
)
|
||||
|
||||
assert result.success is False
|
||||
assert "session id is required" in result.error.lower()
|
||||
|
||||
async def test_remember_semantic_memory(
|
||||
self,
|
||||
service: MemoryToolService,
|
||||
context: ToolContext,
|
||||
) -> None:
|
||||
"""Remember should store facts in semantic memory."""
|
||||
with patch("app.services.memory.mcp.service.SemanticMemory") as mock_semantic_cls:
|
||||
mock_fact = MagicMock()
|
||||
mock_fact.id = uuid4()
|
||||
|
||||
mock_semantic = AsyncMock()
|
||||
mock_semantic.store_fact = AsyncMock(return_value=mock_fact)
|
||||
mock_semantic_cls.create = AsyncMock(return_value=mock_semantic)
|
||||
|
||||
result = await service.execute_tool(
|
||||
tool_name="remember",
|
||||
arguments={
|
||||
"memory_type": "semantic",
|
||||
"content": "User prefers dark mode",
|
||||
"subject": "User",
|
||||
"predicate": "prefers",
|
||||
"object_value": "dark mode",
|
||||
},
|
||||
context=context,
|
||||
)
|
||||
|
||||
assert result.success is True
|
||||
assert result.data["memory_type"] == "semantic"
|
||||
assert "fact_id" in result.data
|
||||
assert "triple" in result.data
|
||||
|
||||
async def test_remember_semantic_without_fields(
|
||||
self,
|
||||
service: MemoryToolService,
|
||||
context: ToolContext,
|
||||
) -> None:
|
||||
"""Semantic memory without subject/predicate/object should fail."""
|
||||
result = await service.execute_tool(
|
||||
tool_name="remember",
|
||||
arguments={
|
||||
"memory_type": "semantic",
|
||||
"content": "Some content",
|
||||
"subject": "User",
|
||||
# Missing predicate and object_value
|
||||
},
|
||||
context=context,
|
||||
)
|
||||
|
||||
assert result.success is False
|
||||
assert "required" in result.error.lower()
|
||||
|
||||
async def test_remember_procedural_memory(
|
||||
self,
|
||||
service: MemoryToolService,
|
||||
context: ToolContext,
|
||||
) -> None:
|
||||
"""Remember should store procedures in procedural memory."""
|
||||
with patch("app.services.memory.mcp.service.ProceduralMemory") as mock_procedural_cls:
|
||||
mock_procedure = MagicMock()
|
||||
mock_procedure.id = uuid4()
|
||||
|
||||
mock_procedural = AsyncMock()
|
||||
mock_procedural.record_procedure = AsyncMock(return_value=mock_procedure)
|
||||
mock_procedural_cls.create = AsyncMock(return_value=mock_procedural)
|
||||
|
||||
result = await service.execute_tool(
|
||||
tool_name="remember",
|
||||
arguments={
|
||||
"memory_type": "procedural",
|
||||
"content": "File creation procedure",
|
||||
"trigger": "When creating a new file",
|
||||
"steps": [
|
||||
{"action": "check_exists"},
|
||||
{"action": "create"},
|
||||
],
|
||||
},
|
||||
context=context,
|
||||
)
|
||||
|
||||
assert result.success is True
|
||||
assert result.data["memory_type"] == "procedural"
|
||||
assert "procedure_id" in result.data
|
||||
assert result.data["steps_count"] == 2
|
||||
|
||||
@patch("app.services.memory.mcp.service.EpisodicMemory")
|
||||
@patch("app.services.memory.mcp.service.SemanticMemory")
|
||||
async def test_recall_from_multiple_types(
|
||||
self,
|
||||
mock_semantic_cls: MagicMock,
|
||||
mock_episodic_cls: MagicMock,
|
||||
service: MemoryToolService,
|
||||
context: ToolContext,
|
||||
) -> None:
|
||||
"""Recall should search across multiple memory types."""
|
||||
# Mock episodic
|
||||
mock_episode = MagicMock()
|
||||
mock_episode.id = uuid4()
|
||||
mock_episode.task_description = "Test episode"
|
||||
mock_episode.outcome = Outcome.SUCCESS
|
||||
mock_episode.occurred_at = datetime.now(UTC)
|
||||
mock_episode.importance_score = 0.9
|
||||
|
||||
mock_episodic = AsyncMock()
|
||||
mock_episodic.search_similar = AsyncMock(return_value=[mock_episode])
|
||||
mock_episodic_cls.create = AsyncMock(return_value=mock_episodic)
|
||||
|
||||
# Mock semantic
|
||||
mock_fact = MagicMock()
|
||||
mock_fact.id = uuid4()
|
||||
mock_fact.subject = "User"
|
||||
mock_fact.predicate = "prefers"
|
||||
mock_fact.object = "dark mode"
|
||||
mock_fact.confidence = 0.8
|
||||
|
||||
mock_semantic = AsyncMock()
|
||||
mock_semantic.search_facts = AsyncMock(return_value=[mock_fact])
|
||||
mock_semantic_cls.create = AsyncMock(return_value=mock_semantic)
|
||||
|
||||
result = await service.execute_tool(
|
||||
tool_name="recall",
|
||||
arguments={
|
||||
"query": "user preferences",
|
||||
"memory_types": ["episodic", "semantic"],
|
||||
"limit": 10,
|
||||
},
|
||||
context=context,
|
||||
)
|
||||
|
||||
assert result.success is True
|
||||
assert result.data["total_results"] == 2
|
||||
assert len(result.data["results"]) == 2
|
||||
|
||||
@patch("app.services.memory.mcp.service.WorkingMemory")
|
||||
async def test_forget_working_memory(
|
||||
self,
|
||||
mock_working_cls: MagicMock,
|
||||
service: MemoryToolService,
|
||||
context: ToolContext,
|
||||
) -> None:
|
||||
"""Forget should delete from working memory."""
|
||||
mock_working = AsyncMock()
|
||||
mock_working.delete = AsyncMock(return_value=True)
|
||||
mock_working_cls.for_session = AsyncMock(return_value=mock_working)
|
||||
|
||||
result = await service.execute_tool(
|
||||
tool_name="forget",
|
||||
arguments={
|
||||
"memory_type": "working",
|
||||
"key": "temp_key",
|
||||
},
|
||||
context=context,
|
||||
)
|
||||
|
||||
assert result.success is True
|
||||
assert result.data["deleted"] is True
|
||||
assert result.data["deleted_count"] == 1
|
||||
|
||||
async def test_forget_pattern_requires_confirm(
|
||||
self,
|
||||
service: MemoryToolService,
|
||||
context: ToolContext,
|
||||
) -> None:
|
||||
"""Pattern deletion should require confirmation."""
|
||||
with patch("app.services.memory.mcp.service.WorkingMemory") as mock_working_cls:
|
||||
mock_working = AsyncMock()
|
||||
mock_working.list_keys = AsyncMock(return_value=["cache_1", "cache_2"])
|
||||
mock_working_cls.for_session = AsyncMock(return_value=mock_working)
|
||||
|
||||
result = await service.execute_tool(
|
||||
tool_name="forget",
|
||||
arguments={
|
||||
"memory_type": "working",
|
||||
"pattern": "cache_*",
|
||||
"confirm_bulk": False,
|
||||
},
|
||||
context=context,
|
||||
)
|
||||
|
||||
assert result.success is False
|
||||
assert "confirm_bulk" in result.error.lower()
|
||||
|
||||
@patch("app.services.memory.mcp.service.EpisodicMemory")
|
||||
async def test_reflect_recent_patterns(
|
||||
self,
|
||||
mock_episodic_cls: MagicMock,
|
||||
service: MemoryToolService,
|
||||
context: ToolContext,
|
||||
) -> None:
|
||||
"""Reflect should analyze recent patterns."""
|
||||
# Create mock episodes
|
||||
mock_episodes = []
|
||||
for i in range(5):
|
||||
ep = MagicMock()
|
||||
ep.id = uuid4()
|
||||
ep.task_type = "code_review" if i % 2 == 0 else "deployment"
|
||||
ep.outcome = Outcome.SUCCESS if i < 3 else Outcome.FAILURE
|
||||
ep.task_description = f"Episode {i}"
|
||||
ep.lessons_learned = None
|
||||
ep.occurred_at = datetime.now(UTC)
|
||||
mock_episodes.append(ep)
|
||||
|
||||
mock_episodic = AsyncMock()
|
||||
mock_episodic.get_recent = AsyncMock(return_value=mock_episodes)
|
||||
mock_episodic_cls.create = AsyncMock(return_value=mock_episodic)
|
||||
|
||||
result = await service.execute_tool(
|
||||
tool_name="reflect",
|
||||
arguments={
|
||||
"analysis_type": "recent_patterns",
|
||||
"depth": 3,
|
||||
},
|
||||
context=context,
|
||||
)
|
||||
|
||||
assert result.success is True
|
||||
assert result.data["analysis_type"] == "recent_patterns"
|
||||
assert result.data["total_episodes"] == 5
|
||||
assert "top_task_types" in result.data
|
||||
assert "outcome_distribution" in result.data
|
||||
|
||||
@patch("app.services.memory.mcp.service.EpisodicMemory")
|
||||
async def test_reflect_success_factors(
|
||||
self,
|
||||
mock_episodic_cls: MagicMock,
|
||||
service: MemoryToolService,
|
||||
context: ToolContext,
|
||||
) -> None:
|
||||
"""Reflect should analyze success factors."""
|
||||
mock_episodes = []
|
||||
for i in range(10):
|
||||
ep = MagicMock()
|
||||
ep.id = uuid4()
|
||||
ep.task_type = "code_review"
|
||||
ep.outcome = Outcome.SUCCESS if i < 8 else Outcome.FAILURE
|
||||
ep.task_description = f"Episode {i}"
|
||||
ep.lessons_learned = "Learned something" if i < 3 else None
|
||||
ep.occurred_at = datetime.now(UTC)
|
||||
mock_episodes.append(ep)
|
||||
|
||||
mock_episodic = AsyncMock()
|
||||
mock_episodic.get_recent = AsyncMock(return_value=mock_episodes)
|
||||
mock_episodic_cls.create = AsyncMock(return_value=mock_episodic)
|
||||
|
||||
result = await service.execute_tool(
|
||||
tool_name="reflect",
|
||||
arguments={
|
||||
"analysis_type": "success_factors",
|
||||
"include_examples": True,
|
||||
},
|
||||
context=context,
|
||||
)
|
||||
|
||||
assert result.success is True
|
||||
assert result.data["analysis_type"] == "success_factors"
|
||||
assert result.data["overall_success_rate"] == 0.8
|
||||
|
||||
@patch("app.services.memory.mcp.service.EpisodicMemory")
|
||||
@patch("app.services.memory.mcp.service.SemanticMemory")
|
||||
@patch("app.services.memory.mcp.service.ProceduralMemory")
|
||||
@patch("app.services.memory.mcp.service.WorkingMemory")
|
||||
async def test_get_memory_stats(
|
||||
self,
|
||||
mock_working_cls: MagicMock,
|
||||
mock_procedural_cls: MagicMock,
|
||||
mock_semantic_cls: MagicMock,
|
||||
mock_episodic_cls: MagicMock,
|
||||
service: MemoryToolService,
|
||||
context: ToolContext,
|
||||
) -> None:
|
||||
"""Get memory stats should return statistics."""
|
||||
# Setup mocks
|
||||
mock_working = AsyncMock()
|
||||
mock_working.list_keys = AsyncMock(return_value=["key1", "key2"])
|
||||
mock_working_cls.for_session = AsyncMock(return_value=mock_working)
|
||||
|
||||
mock_episodic = AsyncMock()
|
||||
mock_episodic.get_recent = AsyncMock(return_value=[MagicMock() for _ in range(10)])
|
||||
mock_episodic_cls.create = AsyncMock(return_value=mock_episodic)
|
||||
|
||||
mock_semantic = AsyncMock()
|
||||
mock_semantic.search_facts = AsyncMock(return_value=[MagicMock() for _ in range(5)])
|
||||
mock_semantic_cls.create = AsyncMock(return_value=mock_semantic)
|
||||
|
||||
mock_procedural = AsyncMock()
|
||||
mock_procedural.find_matching = AsyncMock(return_value=[MagicMock() for _ in range(3)])
|
||||
mock_procedural_cls.create = AsyncMock(return_value=mock_procedural)
|
||||
|
||||
result = await service.execute_tool(
|
||||
tool_name="get_memory_stats",
|
||||
arguments={
|
||||
"include_breakdown": True,
|
||||
"include_recent_activity": False,
|
||||
},
|
||||
context=context,
|
||||
)
|
||||
|
||||
assert result.success is True
|
||||
assert "breakdown" in result.data
|
||||
breakdown = result.data["breakdown"]
|
||||
assert breakdown["working"] == 2
|
||||
assert breakdown["episodic"] == 10
|
||||
assert breakdown["semantic"] == 5
|
||||
assert breakdown["procedural"] == 3
|
||||
assert breakdown["total"] == 20
|
||||
|
||||
@patch("app.services.memory.mcp.service.ProceduralMemory")
|
||||
async def test_search_procedures(
|
||||
self,
|
||||
mock_procedural_cls: MagicMock,
|
||||
service: MemoryToolService,
|
||||
context: ToolContext,
|
||||
) -> None:
|
||||
"""Search procedures should return matching procedures."""
|
||||
mock_proc = MagicMock()
|
||||
mock_proc.id = uuid4()
|
||||
mock_proc.name = "Deployment procedure"
|
||||
mock_proc.description = "How to deploy"
|
||||
mock_proc.trigger = "When deploying"
|
||||
mock_proc.success_rate = 0.9
|
||||
mock_proc.execution_count = 10
|
||||
mock_proc.steps = [{"action": "deploy"}]
|
||||
|
||||
mock_procedural = AsyncMock()
|
||||
mock_procedural.find_matching = AsyncMock(return_value=[mock_proc])
|
||||
mock_procedural_cls.create = AsyncMock(return_value=mock_procedural)
|
||||
|
||||
result = await service.execute_tool(
|
||||
tool_name="search_procedures",
|
||||
arguments={
|
||||
"trigger": "Deploying to production",
|
||||
"min_success_rate": 0.8,
|
||||
"include_steps": True,
|
||||
},
|
||||
context=context,
|
||||
)
|
||||
|
||||
assert result.success is True
|
||||
assert result.data["procedures_found"] == 1
|
||||
proc = result.data["procedures"][0]
|
||||
assert proc["name"] == "Deployment procedure"
|
||||
assert "steps" in proc
|
||||
|
||||
async def test_record_outcome(
|
||||
self,
|
||||
service: MemoryToolService,
|
||||
context: ToolContext,
|
||||
) -> None:
|
||||
"""Record outcome should store outcome and update procedure."""
|
||||
with (
|
||||
patch("app.services.memory.mcp.service.EpisodicMemory") as mock_episodic_cls,
|
||||
patch("app.services.memory.mcp.service.ProceduralMemory") as mock_procedural_cls,
|
||||
):
|
||||
mock_episode = MagicMock()
|
||||
mock_episode.id = uuid4()
|
||||
|
||||
mock_episodic = AsyncMock()
|
||||
mock_episodic.record_episode = AsyncMock(return_value=mock_episode)
|
||||
mock_episodic_cls.create = AsyncMock(return_value=mock_episodic)
|
||||
|
||||
mock_procedural = AsyncMock()
|
||||
mock_procedural.record_outcome = AsyncMock()
|
||||
mock_procedural_cls.create = AsyncMock(return_value=mock_procedural)
|
||||
|
||||
result = await service.execute_tool(
|
||||
tool_name="record_outcome",
|
||||
arguments={
|
||||
"task_type": "code_review",
|
||||
"outcome": "success",
|
||||
"lessons_learned": "Breaking changes caught early",
|
||||
"duration_seconds": 120.5,
|
||||
},
|
||||
context=context,
|
||||
)
|
||||
|
||||
assert result.success is True
|
||||
assert result.data["recorded"] is True
|
||||
assert result.data["outcome"] == "success"
|
||||
assert "episode_id" in result.data
|
||||
|
||||
|
||||
class TestGetMemoryToolService:
|
||||
"""Tests for get_memory_tool_service factory."""
|
||||
|
||||
async def test_creates_service(self) -> None:
|
||||
"""Factory should create a service."""
|
||||
mock_session = make_mock_session()
|
||||
service = await get_memory_tool_service(mock_session)
|
||||
assert isinstance(service, MemoryToolService)
|
||||
|
||||
async def test_accepts_embedding_generator(self) -> None:
|
||||
"""Factory should accept embedding generator."""
|
||||
mock_session = make_mock_session()
|
||||
mock_generator = MagicMock()
|
||||
service = await get_memory_tool_service(mock_session, mock_generator)
|
||||
assert service._embedding_generator is mock_generator
|
||||
420
backend/tests/unit/services/memory/mcp/test_tools.py
Normal file
420
backend/tests/unit/services/memory/mcp/test_tools.py
Normal file
@@ -0,0 +1,420 @@
|
||||
# tests/unit/services/memory/mcp/test_tools.py
|
||||
"""Tests for MCP tool definitions."""
|
||||
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
|
||||
from app.services.memory.mcp.tools import (
|
||||
MEMORY_TOOL_DEFINITIONS,
|
||||
AnalysisType,
|
||||
ForgetArgs,
|
||||
GetMemoryStatsArgs,
|
||||
MemoryToolDefinition,
|
||||
MemoryType,
|
||||
OutcomeType,
|
||||
RecallArgs,
|
||||
RecordOutcomeArgs,
|
||||
ReflectArgs,
|
||||
RememberArgs,
|
||||
SearchProceduresArgs,
|
||||
get_all_tool_schemas,
|
||||
get_tool_definition,
|
||||
)
|
||||
|
||||
|
||||
class TestMemoryType:
|
||||
"""Tests for MemoryType enum."""
|
||||
|
||||
def test_all_types_defined(self) -> None:
|
||||
"""All memory types should be defined."""
|
||||
assert MemoryType.WORKING == "working"
|
||||
assert MemoryType.EPISODIC == "episodic"
|
||||
assert MemoryType.SEMANTIC == "semantic"
|
||||
assert MemoryType.PROCEDURAL == "procedural"
|
||||
|
||||
def test_enum_values(self) -> None:
|
||||
"""Enum values should match strings."""
|
||||
assert MemoryType.WORKING.value == "working"
|
||||
assert MemoryType("episodic") == MemoryType.EPISODIC
|
||||
|
||||
|
||||
class TestAnalysisType:
|
||||
"""Tests for AnalysisType enum."""
|
||||
|
||||
def test_all_types_defined(self) -> None:
|
||||
"""All analysis types should be defined."""
|
||||
assert AnalysisType.RECENT_PATTERNS == "recent_patterns"
|
||||
assert AnalysisType.SUCCESS_FACTORS == "success_factors"
|
||||
assert AnalysisType.FAILURE_PATTERNS == "failure_patterns"
|
||||
assert AnalysisType.COMMON_PROCEDURES == "common_procedures"
|
||||
assert AnalysisType.LEARNING_PROGRESS == "learning_progress"
|
||||
|
||||
|
||||
class TestOutcomeType:
|
||||
"""Tests for OutcomeType enum."""
|
||||
|
||||
def test_all_outcomes_defined(self) -> None:
|
||||
"""All outcome types should be defined."""
|
||||
assert OutcomeType.SUCCESS == "success"
|
||||
assert OutcomeType.PARTIAL == "partial"
|
||||
assert OutcomeType.FAILURE == "failure"
|
||||
assert OutcomeType.ABANDONED == "abandoned"
|
||||
|
||||
|
||||
class TestRememberArgs:
|
||||
"""Tests for RememberArgs validation."""
|
||||
|
||||
def test_valid_working_memory_args(self) -> None:
|
||||
"""Valid working memory args should parse."""
|
||||
args = RememberArgs(
|
||||
memory_type=MemoryType.WORKING,
|
||||
content="Test content",
|
||||
key="test_key",
|
||||
ttl_seconds=3600,
|
||||
)
|
||||
assert args.memory_type == MemoryType.WORKING
|
||||
assert args.key == "test_key"
|
||||
assert args.ttl_seconds == 3600
|
||||
|
||||
def test_valid_semantic_args(self) -> None:
|
||||
"""Valid semantic memory args should parse."""
|
||||
args = RememberArgs(
|
||||
memory_type=MemoryType.SEMANTIC,
|
||||
content="User prefers dark mode",
|
||||
subject="User",
|
||||
predicate="prefers",
|
||||
object_value="dark mode",
|
||||
)
|
||||
assert args.subject == "User"
|
||||
assert args.predicate == "prefers"
|
||||
assert args.object_value == "dark mode"
|
||||
|
||||
def test_valid_procedural_args(self) -> None:
|
||||
"""Valid procedural memory args should parse."""
|
||||
args = RememberArgs(
|
||||
memory_type=MemoryType.PROCEDURAL,
|
||||
content="File creation procedure",
|
||||
trigger="When creating a new file",
|
||||
steps=[{"action": "check_exists"}, {"action": "create"}],
|
||||
)
|
||||
assert args.trigger == "When creating a new file"
|
||||
assert len(args.steps) == 2
|
||||
|
||||
def test_importance_validation(self) -> None:
|
||||
"""Importance must be between 0 and 1."""
|
||||
args = RememberArgs(
|
||||
memory_type=MemoryType.WORKING,
|
||||
content="Test",
|
||||
importance=0.8,
|
||||
)
|
||||
assert args.importance == 0.8
|
||||
|
||||
with pytest.raises(ValidationError):
|
||||
RememberArgs(
|
||||
memory_type=MemoryType.WORKING,
|
||||
content="Test",
|
||||
importance=1.5, # Invalid
|
||||
)
|
||||
|
||||
with pytest.raises(ValidationError):
|
||||
RememberArgs(
|
||||
memory_type=MemoryType.WORKING,
|
||||
content="Test",
|
||||
importance=-0.1, # Invalid
|
||||
)
|
||||
|
||||
def test_content_required(self) -> None:
|
||||
"""Content is required."""
|
||||
with pytest.raises(ValidationError):
|
||||
RememberArgs(
|
||||
memory_type=MemoryType.WORKING,
|
||||
content="", # Empty not allowed
|
||||
)
|
||||
|
||||
def test_ttl_validation(self) -> None:
|
||||
"""TTL must be within bounds."""
|
||||
with pytest.raises(ValidationError):
|
||||
RememberArgs(
|
||||
memory_type=MemoryType.WORKING,
|
||||
content="Test",
|
||||
ttl_seconds=0, # Too low
|
||||
)
|
||||
|
||||
with pytest.raises(ValidationError):
|
||||
RememberArgs(
|
||||
memory_type=MemoryType.WORKING,
|
||||
content="Test",
|
||||
ttl_seconds=86400 * 31, # Over 30 days
|
||||
)
|
||||
|
||||
def test_default_values(self) -> None:
|
||||
"""Default values should be set correctly."""
|
||||
args = RememberArgs(
|
||||
memory_type=MemoryType.WORKING,
|
||||
content="Test",
|
||||
)
|
||||
assert args.importance == 0.5
|
||||
assert args.ttl_seconds is None
|
||||
assert args.metadata == {}
|
||||
assert args.key is None
|
||||
|
||||
|
||||
class TestRecallArgs:
|
||||
"""Tests for RecallArgs validation."""
|
||||
|
||||
def test_valid_args(self) -> None:
|
||||
"""Valid recall args should parse."""
|
||||
args = RecallArgs(
|
||||
query="authentication errors",
|
||||
memory_types=[MemoryType.EPISODIC, MemoryType.SEMANTIC],
|
||||
limit=10,
|
||||
)
|
||||
assert args.query == "authentication errors"
|
||||
assert len(args.memory_types) == 2
|
||||
assert args.limit == 10
|
||||
|
||||
def test_default_memory_types(self) -> None:
|
||||
"""Default memory types should be episodic and semantic."""
|
||||
args = RecallArgs(query="test query")
|
||||
assert MemoryType.EPISODIC in args.memory_types
|
||||
assert MemoryType.SEMANTIC in args.memory_types
|
||||
|
||||
def test_limit_validation(self) -> None:
|
||||
"""Limit must be between 1 and 100."""
|
||||
with pytest.raises(ValidationError):
|
||||
RecallArgs(query="test", limit=0)
|
||||
|
||||
with pytest.raises(ValidationError):
|
||||
RecallArgs(query="test", limit=101)
|
||||
|
||||
def test_min_relevance_validation(self) -> None:
|
||||
"""Min relevance must be between 0 and 1."""
|
||||
args = RecallArgs(query="test", min_relevance=0.5)
|
||||
assert args.min_relevance == 0.5
|
||||
|
||||
with pytest.raises(ValidationError):
|
||||
RecallArgs(query="test", min_relevance=1.5)
|
||||
|
||||
|
||||
class TestForgetArgs:
|
||||
"""Tests for ForgetArgs validation."""
|
||||
|
||||
def test_valid_key_deletion(self) -> None:
|
||||
"""Valid key deletion args should parse."""
|
||||
args = ForgetArgs(
|
||||
memory_type=MemoryType.WORKING,
|
||||
key="temp_key",
|
||||
)
|
||||
assert args.memory_type == MemoryType.WORKING
|
||||
assert args.key == "temp_key"
|
||||
|
||||
def test_valid_id_deletion(self) -> None:
|
||||
"""Valid ID deletion args should parse."""
|
||||
args = ForgetArgs(
|
||||
memory_type=MemoryType.EPISODIC,
|
||||
memory_id="12345678-1234-1234-1234-123456789012",
|
||||
)
|
||||
assert args.memory_id is not None
|
||||
|
||||
def test_pattern_deletion_requires_confirm(self) -> None:
|
||||
"""Pattern deletion should parse but service should validate confirm."""
|
||||
args = ForgetArgs(
|
||||
memory_type=MemoryType.WORKING,
|
||||
pattern="cache_*",
|
||||
confirm_bulk=False,
|
||||
)
|
||||
assert args.pattern == "cache_*"
|
||||
assert args.confirm_bulk is False
|
||||
|
||||
|
||||
class TestReflectArgs:
|
||||
"""Tests for ReflectArgs validation."""
|
||||
|
||||
def test_valid_args(self) -> None:
|
||||
"""Valid reflect args should parse."""
|
||||
args = ReflectArgs(
|
||||
analysis_type=AnalysisType.SUCCESS_FACTORS,
|
||||
depth=3,
|
||||
)
|
||||
assert args.analysis_type == AnalysisType.SUCCESS_FACTORS
|
||||
assert args.depth == 3
|
||||
|
||||
def test_depth_validation(self) -> None:
|
||||
"""Depth must be between 1 and 5."""
|
||||
with pytest.raises(ValidationError):
|
||||
ReflectArgs(analysis_type=AnalysisType.SUCCESS_FACTORS, depth=0)
|
||||
|
||||
with pytest.raises(ValidationError):
|
||||
ReflectArgs(analysis_type=AnalysisType.SUCCESS_FACTORS, depth=6)
|
||||
|
||||
def test_default_values(self) -> None:
|
||||
"""Default values should be set correctly."""
|
||||
args = ReflectArgs(analysis_type=AnalysisType.RECENT_PATTERNS)
|
||||
assert args.depth == 3
|
||||
assert args.include_examples is True
|
||||
assert args.max_items == 10
|
||||
|
||||
|
||||
class TestGetMemoryStatsArgs:
|
||||
"""Tests for GetMemoryStatsArgs validation."""
|
||||
|
||||
def test_valid_args(self) -> None:
|
||||
"""Valid args should parse."""
|
||||
args = GetMemoryStatsArgs(
|
||||
include_breakdown=True,
|
||||
include_recent_activity=True,
|
||||
time_range_days=30,
|
||||
)
|
||||
assert args.include_breakdown is True
|
||||
assert args.time_range_days == 30
|
||||
|
||||
def test_time_range_validation(self) -> None:
|
||||
"""Time range must be between 1 and 90."""
|
||||
with pytest.raises(ValidationError):
|
||||
GetMemoryStatsArgs(time_range_days=0)
|
||||
|
||||
with pytest.raises(ValidationError):
|
||||
GetMemoryStatsArgs(time_range_days=91)
|
||||
|
||||
|
||||
class TestSearchProceduresArgs:
|
||||
"""Tests for SearchProceduresArgs validation."""
|
||||
|
||||
def test_valid_args(self) -> None:
|
||||
"""Valid args should parse."""
|
||||
args = SearchProceduresArgs(
|
||||
trigger="Deploying to production",
|
||||
min_success_rate=0.8,
|
||||
limit=5,
|
||||
)
|
||||
assert args.trigger == "Deploying to production"
|
||||
assert args.min_success_rate == 0.8
|
||||
|
||||
def test_trigger_required(self) -> None:
|
||||
"""Trigger is required."""
|
||||
with pytest.raises(ValidationError):
|
||||
SearchProceduresArgs(trigger="")
|
||||
|
||||
def test_success_rate_validation(self) -> None:
|
||||
"""Success rate must be between 0 and 1."""
|
||||
with pytest.raises(ValidationError):
|
||||
SearchProceduresArgs(trigger="test", min_success_rate=1.5)
|
||||
|
||||
|
||||
class TestRecordOutcomeArgs:
|
||||
"""Tests for RecordOutcomeArgs validation."""
|
||||
|
||||
def test_valid_success_args(self) -> None:
|
||||
"""Valid success args should parse."""
|
||||
args = RecordOutcomeArgs(
|
||||
task_type="code_review",
|
||||
outcome=OutcomeType.SUCCESS,
|
||||
lessons_learned="Breaking changes caught early",
|
||||
)
|
||||
assert args.task_type == "code_review"
|
||||
assert args.outcome == OutcomeType.SUCCESS
|
||||
|
||||
def test_valid_failure_args(self) -> None:
|
||||
"""Valid failure args should parse."""
|
||||
args = RecordOutcomeArgs(
|
||||
task_type="deployment",
|
||||
outcome=OutcomeType.FAILURE,
|
||||
error_details="Database migration timeout",
|
||||
duration_seconds=120.5,
|
||||
)
|
||||
assert args.outcome == OutcomeType.FAILURE
|
||||
assert args.error_details is not None
|
||||
|
||||
def test_task_type_required(self) -> None:
|
||||
"""Task type is required."""
|
||||
with pytest.raises(ValidationError):
|
||||
RecordOutcomeArgs(task_type="", outcome=OutcomeType.SUCCESS)
|
||||
|
||||
|
||||
class TestMemoryToolDefinition:
|
||||
"""Tests for MemoryToolDefinition class."""
|
||||
|
||||
def test_to_mcp_format(self) -> None:
|
||||
"""Tool should convert to MCP format."""
|
||||
tool = MemoryToolDefinition(
|
||||
name="test_tool",
|
||||
description="A test tool",
|
||||
args_schema=RememberArgs,
|
||||
)
|
||||
|
||||
mcp_format = tool.to_mcp_format()
|
||||
|
||||
assert mcp_format["name"] == "test_tool"
|
||||
assert mcp_format["description"] == "A test tool"
|
||||
assert "inputSchema" in mcp_format
|
||||
assert "properties" in mcp_format["inputSchema"]
|
||||
|
||||
def test_validate_args(self) -> None:
|
||||
"""Tool should validate args using schema."""
|
||||
tool = MemoryToolDefinition(
|
||||
name="remember",
|
||||
description="Store in memory",
|
||||
args_schema=RememberArgs,
|
||||
)
|
||||
|
||||
# Valid args
|
||||
validated = tool.validate_args({
|
||||
"memory_type": "working",
|
||||
"content": "Test content",
|
||||
})
|
||||
assert isinstance(validated, RememberArgs)
|
||||
|
||||
# Invalid args
|
||||
with pytest.raises(ValidationError):
|
||||
tool.validate_args({"memory_type": "invalid"})
|
||||
|
||||
|
||||
class TestToolDefinitions:
|
||||
"""Tests for the tool definitions dictionary."""
|
||||
|
||||
def test_all_tools_defined(self) -> None:
|
||||
"""All expected tools should be defined."""
|
||||
expected_tools = [
|
||||
"remember",
|
||||
"recall",
|
||||
"forget",
|
||||
"reflect",
|
||||
"get_memory_stats",
|
||||
"search_procedures",
|
||||
"record_outcome",
|
||||
]
|
||||
|
||||
for tool_name in expected_tools:
|
||||
assert tool_name in MEMORY_TOOL_DEFINITIONS
|
||||
assert isinstance(MEMORY_TOOL_DEFINITIONS[tool_name], MemoryToolDefinition)
|
||||
|
||||
def test_get_tool_definition(self) -> None:
|
||||
"""get_tool_definition should return correct tool."""
|
||||
tool = get_tool_definition("remember")
|
||||
assert tool is not None
|
||||
assert tool.name == "remember"
|
||||
|
||||
unknown = get_tool_definition("unknown_tool")
|
||||
assert unknown is None
|
||||
|
||||
def test_get_all_tool_schemas(self) -> None:
|
||||
"""get_all_tool_schemas should return MCP-formatted schemas."""
|
||||
schemas = get_all_tool_schemas()
|
||||
|
||||
assert len(schemas) == 7
|
||||
for schema in schemas:
|
||||
assert "name" in schema
|
||||
assert "description" in schema
|
||||
assert "inputSchema" in schema
|
||||
|
||||
def test_tool_descriptions_not_empty(self) -> None:
|
||||
"""All tools should have descriptions."""
|
||||
for name, tool in MEMORY_TOOL_DEFINITIONS.items():
|
||||
assert tool.description, f"Tool {name} has empty description"
|
||||
assert len(tool.description) > 50, f"Tool {name} description too short"
|
||||
|
||||
def test_input_schemas_have_properties(self) -> None:
|
||||
"""All tool schemas should have properties defined."""
|
||||
for name, tool in MEMORY_TOOL_DEFINITIONS.items():
|
||||
schema = tool.to_mcp_format()
|
||||
assert "properties" in schema["inputSchema"], f"Tool {name} missing properties"
|
||||
Reference in New Issue
Block a user