2 Commits

Author SHA1 Message Date
Felipe Cardoso
192237e69b fix(memory): unify Outcome enum and add ABANDONED support
- Add ABANDONED value to core Outcome enum in types.py
- Replace duplicate OutcomeType class in mcp/tools.py with alias to Outcome
- Simplify mcp/service.py to use outcome directly (no more silent mapping)
- Add migration 0006 to extend PostgreSQL episode_outcome enum
- Add missing constraints to migration 0005 (ix_facts_unique_triple_global)

This fixes the semantic issue where ABANDONED outcomes were silently
converted to FAILURE, losing information about task abandonment.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-06 01:46:48 +01:00
Felipe Cardoso
3edce9cd26 fix(memory): address critical bugs from multi-agent review
Bug Fixes:
- Remove singleton pattern from consolidation/reflection services to
  prevent stale database session bugs (session is now passed per-request)
- Add LRU eviction to MemoryToolService._working dict (max 1000 sessions)
  to prevent unbounded memory growth
- Replace O(n) list.remove() with O(1) OrderedDict.move_to_end() in
  RetrievalCache for better performance under load
- Use deque with maxlen for metrics histograms to prevent unbounded
  memory growth (circular buffer with 10k max samples)
- Use full UUID for checkpoint IDs instead of 8-char prefix to avoid
  collision risk at scale (birthday paradox at ~50k checkpoints)

Test Updates:
- Update checkpoint test to expect 36-char UUID
- Update reflection singleton tests to expect new factory behavior
- Add reset_memory_reflection() no-op for backwards compatibility

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-05 18:55:32 +01:00
12 changed files with 161 additions and 96 deletions

View File

@@ -300,6 +300,14 @@ def upgrade() -> None:
unique=True,
postgresql_where=sa.text("project_id IS NOT NULL"),
)
# Unique constraint for global facts (project_id IS NULL)
op.create_index(
"ix_facts_unique_triple_global",
"facts",
["subject", "predicate", "object"],
unique=True,
postgresql_where=sa.text("project_id IS NULL"),
)
# =========================================================================
# Create procedures table
@@ -396,6 +404,11 @@ def upgrade() -> None:
"facts",
"confidence >= 0.0 AND confidence <= 1.0",
)
op.create_check_constraint(
"ck_facts_reinforcement_positive",
"facts",
"reinforcement_count >= 1",
)
# Procedure constraints
op.create_check_constraint(
@@ -476,11 +489,15 @@ def downgrade() -> None:
# Drop check constraints first
op.drop_constraint("ck_procedures_failure_positive", "procedures", type_="check")
op.drop_constraint("ck_procedures_success_positive", "procedures", type_="check")
op.drop_constraint("ck_facts_reinforcement_positive", "facts", type_="check")
op.drop_constraint("ck_facts_confidence_range", "facts", type_="check")
op.drop_constraint("ck_episodes_tokens_positive", "episodes", type_="check")
op.drop_constraint("ck_episodes_duration_positive", "episodes", type_="check")
op.drop_constraint("ck_episodes_importance_range", "episodes", type_="check")
# Drop unique index for global facts
op.drop_index("ix_facts_unique_triple_global", "facts")
# Drop tables in reverse order (dependencies first)
op.drop_table("memory_consolidation_log")
op.drop_table("procedures")

View File

@@ -0,0 +1,52 @@
"""Add ABANDONED to episode_outcome enum
Revision ID: 0006
Revises: 0005
Create Date: 2025-01-06
This migration adds the 'abandoned' value to the episode_outcome enum type.
This allows episodes to track when a task was abandoned (not completed,
but not necessarily a failure either - e.g., user cancelled, session timeout).
"""
from collections.abc import Sequence
from alembic import op
# revision identifiers, used by Alembic.
revision: str = "0006"
down_revision: str | None = "0005"
branch_labels: str | Sequence[str] | None = None
depends_on: str | Sequence[str] | None = None
def upgrade() -> None:
"""Add 'abandoned' value to episode_outcome enum."""
# PostgreSQL ALTER TYPE ADD VALUE is safe and non-blocking
op.execute("ALTER TYPE episode_outcome ADD VALUE IF NOT EXISTS 'abandoned'")
def downgrade() -> None:
"""Remove 'abandoned' from episode_outcome enum.
Note: PostgreSQL doesn't support removing values from enums directly.
This downgrade converts any 'abandoned' episodes to 'failure' and
recreates the enum without 'abandoned'.
"""
# Convert any abandoned episodes to failure first
op.execute("""
UPDATE episodes
SET outcome = 'failure'
WHERE outcome = 'abandoned'
""")
# Recreate the enum without abandoned
# This is complex in PostgreSQL - requires creating new type, updating columns, dropping old
op.execute("ALTER TYPE episode_outcome RENAME TO episode_outcome_old")
op.execute("CREATE TYPE episode_outcome AS ENUM ('success', 'failure', 'partial')")
op.execute("""
ALTER TABLE episodes
ALTER COLUMN outcome TYPE episode_outcome
USING outcome::text::episode_outcome
""")
op.execute("DROP TYPE episode_outcome_old")

View File

@@ -892,27 +892,22 @@ class MemoryConsolidationService:
return result
# Singleton instance
_consolidation_service: MemoryConsolidationService | None = None
# Factory function - no singleton to avoid stale session issues
async def get_consolidation_service(
session: AsyncSession,
config: ConsolidationConfig | None = None,
) -> MemoryConsolidationService:
"""
Get or create the memory consolidation service.
Create a memory consolidation service for the given session.
Note: This creates a new instance each time to avoid stale session issues.
The service is lightweight and safe to recreate per-request.
Args:
session: Database session
session: Database session (must be active)
config: Optional configuration
Returns:
MemoryConsolidationService instance
"""
global _consolidation_service
if _consolidation_service is None:
_consolidation_service = MemoryConsolidationService(
session=session, config=config
)
return _consolidation_service
return MemoryConsolidationService(session=session, config=config)
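
A minimal usage sketch of the factory approach; the import path and the consolidate_memories() method name are illustrative, not taken from the diff:

from sqlalchemy.ext.asyncio import AsyncSession

# Import path assumed for illustration; adjust to the actual module layout.
from app.services.memory.consolidation import get_consolidation_service


async def run_consolidation(session: AsyncSession) -> None:
    # The factory now returns a fresh, lightweight instance bound to this
    # request's session, so no stale-session state leaks between requests.
    service = await get_consolidation_service(session)
    await service.consolidate_memories()  # illustrative method name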

View File

@@ -13,6 +13,7 @@ Provides hybrid retrieval capabilities combining:
import hashlib
import logging
from collections import OrderedDict
from dataclasses import dataclass, field
from datetime import UTC, datetime
from typing import Any, TypeVar
@@ -243,7 +244,8 @@ class RetrievalCache:
"""
In-memory cache for retrieval results.
Supports TTL-based expiration and LRU eviction.
Supports TTL-based expiration and LRU eviction with O(1) operations.
Uses OrderedDict for efficient LRU tracking.
"""
def __init__(
@@ -258,10 +260,10 @@ class RetrievalCache:
max_entries: Maximum cache entries
default_ttl_seconds: Default TTL for entries
"""
self._cache: dict[str, CacheEntry] = {}
# OrderedDict maintains insertion order; we use move_to_end for O(1) LRU
self._cache: OrderedDict[str, CacheEntry] = OrderedDict()
self._max_entries = max_entries
self._default_ttl = default_ttl_seconds
self._access_order: list[str] = []
logger.info(
f"Initialized RetrievalCache with max_entries={max_entries}, "
f"ttl={default_ttl_seconds}s"
@@ -283,14 +285,10 @@ class RetrievalCache:
entry = self._cache[query_key]
if entry.is_expired():
del self._cache[query_key]
if query_key in self._access_order:
self._access_order.remove(query_key)
return None
# Update access order (LRU)
if query_key in self._access_order:
self._access_order.remove(query_key)
self._access_order.append(query_key)
# Update access order (LRU) - O(1) with OrderedDict
self._cache.move_to_end(query_key)
logger.debug(f"Cache hit for {query_key}")
return entry.results
@@ -309,11 +307,9 @@ class RetrievalCache:
results: Results to cache
ttl_seconds: TTL for this entry (or default)
"""
# Evict if at capacity
while len(self._cache) >= self._max_entries and self._access_order:
oldest_key = self._access_order.pop(0)
if oldest_key in self._cache:
del self._cache[oldest_key]
# Evict oldest entries if at capacity - O(1) with popitem(last=False)
while len(self._cache) >= self._max_entries:
self._cache.popitem(last=False)
entry = CacheEntry(
results=results,
@@ -323,7 +319,6 @@ class RetrievalCache:
)
self._cache[query_key] = entry
self._access_order.append(query_key)
logger.debug(f"Cached {len(results)} results for {query_key}")
def invalidate(self, query_key: str) -> bool:
@@ -338,8 +333,6 @@ class RetrievalCache:
"""
if query_key in self._cache:
del self._cache[query_key]
if query_key in self._access_order:
self._access_order.remove(query_key)
return True
return False
@@ -376,7 +369,6 @@ class RetrievalCache:
"""
count = len(self._cache)
self._cache.clear()
self._access_order.clear()
logger.info(f"Cleared {count} cache entries")
return count
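
The OrderedDict pattern in isolation, as a self-contained toy sketch (not the service code): move_to_end() bumps a key to most-recently-used in O(1), and popitem(last=False) evicts the least-recently-used entry in O(1), which is what replaces the old O(n) list bookkeeping.

from collections import OrderedDict

cache: OrderedDict[str, str] = OrderedDict()
MAX_ENTRIES = 2

def put(key: str, value: str) -> None:
    while len(cache) >= MAX_ENTRIES:
        cache.popitem(last=False)   # evict least recently used (front of the dict)
    cache[key] = value              # newest entry goes to the back

def get(key: str) -> str | None:
    if key not in cache:
        return None
    cache.move_to_end(key)          # O(1) bump to most recently used
    return cache[key]

put("a", "1"); put("b", "2")
get("a")                            # "a" is now the most recently used
put("c", "3")                       # evicts "b", not "a"
assert list(cache) == ["a", "c"]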

View File

@@ -7,6 +7,7 @@ All tools are scoped to project/agent context for proper isolation.
"""
import logging
from collections import OrderedDict
from dataclasses import dataclass
from datetime import UTC, datetime, timedelta
from typing import Any
@@ -83,6 +84,9 @@ class MemoryToolService:
This service coordinates between different memory types.
"""
# Maximum number of working memory sessions to cache (LRU eviction)
MAX_WORKING_SESSIONS = 1000
def __init__(
self,
session: AsyncSession,
@@ -98,8 +102,8 @@ class MemoryToolService:
self._session = session
self._embedding_generator = embedding_generator
# Lazy-initialized memory services
self._working: dict[str, WorkingMemory] = {} # keyed by session_id
# Lazy-initialized memory services with LRU eviction for working memory
self._working: OrderedDict[str, WorkingMemory] = OrderedDict()
self._episodic: EpisodicMemory | None = None
self._semantic: SemanticMemory | None = None
self._procedural: ProceduralMemory | None = None
@@ -110,14 +114,28 @@ class MemoryToolService:
project_id: UUID | None = None,
agent_instance_id: UUID | None = None,
) -> WorkingMemory:
"""Get or create working memory for a session."""
if session_id not in self._working:
self._working[session_id] = await WorkingMemory.for_session(
session_id=session_id,
project_id=str(project_id) if project_id else None,
agent_instance_id=str(agent_instance_id) if agent_instance_id else None,
)
return self._working[session_id]
"""Get or create working memory for a session with LRU eviction."""
if session_id in self._working:
# Move to end (most recently used)
self._working.move_to_end(session_id)
return self._working[session_id]
# Evict oldest entries if at capacity
while len(self._working) >= self.MAX_WORKING_SESSIONS:
oldest_id, oldest_memory = self._working.popitem(last=False)
try:
await oldest_memory.close()
except Exception as e:
logger.warning(f"Error closing evicted working memory {oldest_id}: {e}")
# Create new working memory
working = await WorkingMemory.for_session(
session_id=session_id,
project_id=str(project_id) if project_id else None,
agent_instance_id=str(agent_instance_id) if agent_instance_id else None,
)
self._working[session_id] = working
return working
async def _get_episodic(self) -> EpisodicMemory:
"""Get or create episodic memory service."""
@@ -1006,15 +1024,8 @@ class MemoryToolService:
context: ToolContext,
) -> dict[str, Any]:
"""Execute the 'record_outcome' tool."""
# Map outcome type to memory Outcome
# Note: ABANDONED maps to FAILURE since core Outcome doesn't have ABANDONED
outcome_map = {
OutcomeType.SUCCESS: Outcome.SUCCESS,
OutcomeType.PARTIAL: Outcome.PARTIAL,
OutcomeType.FAILURE: Outcome.FAILURE,
OutcomeType.ABANDONED: Outcome.FAILURE, # No ABANDONED in core enum
}
outcome = outcome_map.get(args.outcome, Outcome.FAILURE)
# OutcomeType is now an alias for Outcome, use directly
outcome = args.outcome
# Record in episodic memory
episodic = await self._get_episodic()

View File

@@ -12,6 +12,9 @@ from typing import Any
from pydantic import BaseModel, Field
# OutcomeType alias - uses core Outcome enum from types module for consistency
from app.services.memory.types import Outcome as OutcomeType
class MemoryType(str, Enum):
"""Types of memory for storage operations."""
@@ -32,15 +35,6 @@ class AnalysisType(str, Enum):
LEARNING_PROGRESS = "learning_progress"
class OutcomeType(str, Enum):
"""Outcome types for record_outcome tool."""
SUCCESS = "success"
PARTIAL = "partial"
FAILURE = "failure"
ABANDONED = "abandoned"
# ============================================================================
# Tool Argument Schemas (Pydantic models for validation)
# ============================================================================
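
Because OutcomeType is now a plain alias, both names refer to the same enum object, so MCP argument values flow straight into episodic memory with no mapping step. A small sketch; the tools module path is assumed:

from app.services.memory.types import Outcome
# The exact import path of the MCP tools module is assumed here.
from app.services.memory.mcp.tools import OutcomeType

assert OutcomeType is Outcome                      # an alias, not a second enum
assert OutcomeType.ABANDONED is Outcome.ABANDONED  # no silent ABANDONED -> FAILURE mapping
assert Outcome.ABANDONED.value == "abandoned"      # matches the PostgreSQL enum label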

View File

@@ -7,7 +7,7 @@ Collects and exposes metrics for the memory system.
import asyncio
import logging
from collections import Counter, defaultdict
from collections import Counter, defaultdict, deque
from dataclasses import dataclass, field
from datetime import UTC, datetime
from enum import Enum
@@ -57,11 +57,17 @@ class MemoryMetrics:
- Embedding operations
"""
# Maximum samples to keep in histogram (circular buffer)
MAX_HISTOGRAM_SAMPLES = 10000
def __init__(self) -> None:
"""Initialize MemoryMetrics."""
self._counters: dict[str, Counter[str]] = defaultdict(Counter)
self._gauges: dict[str, dict[str, float]] = defaultdict(dict)
self._histograms: dict[str, list[float]] = defaultdict(list)
# Use deque with maxlen for bounded memory (circular buffer)
self._histograms: dict[str, deque[float]] = defaultdict(
lambda: deque(maxlen=self.MAX_HISTOGRAM_SAMPLES)
)
self._histogram_buckets: dict[str, list[HistogramBucket]] = {}
self._lock = asyncio.Lock()
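
The bounded-histogram behaviour in isolation: a deque created with maxlen acts as a circular buffer and silently discards the oldest sample once the cap is reached, so memory stays constant no matter how many samples are recorded. A tiny self-contained sketch:

from collections import deque

samples: deque[float] = deque(maxlen=3)
for value in (1.0, 2.0, 3.0, 4.0):
    samples.append(value)
assert list(samples) == [2.0, 3.0, 4.0]   # 1.0 was dropped; the buffer stays bounded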

View File

@@ -7,7 +7,6 @@ Implements pattern detection, success/failure analysis, anomaly detection,
and insight generation.
"""
import asyncio
import logging
import statistics
from collections import Counter, defaultdict
@@ -1426,36 +1425,27 @@ class MemoryReflection:
)
# Singleton instance with async-safe initialization
_memory_reflection: MemoryReflection | None = None
_reflection_lock = asyncio.Lock()
# Factory function - no singleton to avoid stale session issues
async def get_memory_reflection(
session: AsyncSession,
config: ReflectionConfig | None = None,
) -> MemoryReflection:
"""
Get or create the memory reflection service (async-safe).
Create a memory reflection service for the given session.
Note: This creates a new instance each time to avoid stale session issues.
The service is lightweight and safe to recreate per-request.
Args:
session: Database session
session: Database session (must be active)
config: Optional configuration
Returns:
MemoryReflection instance
"""
global _memory_reflection
if _memory_reflection is None:
async with _reflection_lock:
# Double-check locking pattern
if _memory_reflection is None:
_memory_reflection = MemoryReflection(session=session, config=config)
return _memory_reflection
return MemoryReflection(session=session, config=config)
async def reset_memory_reflection() -> None:
"""Reset the memory reflection singleton (async-safe)."""
global _memory_reflection
async with _reflection_lock:
_memory_reflection = None
"""No-op for backwards compatibility (singleton pattern removed)."""
return

View File

@@ -42,6 +42,7 @@ class Outcome(str, Enum):
SUCCESS = "success"
FAILURE = "failure"
PARTIAL = "partial"
ABANDONED = "abandoned"
class ConsolidationStatus(str, Enum):

View File

@@ -423,7 +423,8 @@ class WorkingMemory:
Returns:
Checkpoint ID for later restoration
"""
checkpoint_id = str(uuid.uuid4())[:8]
# Use the full UUID: an 8-char hex prefix reaches meaningful birthday-collision odds around ~50k checkpoints
checkpoint_id = str(uuid.uuid4())
checkpoint_key = f"{_CHECKPOINT_PREFIX}{checkpoint_id}"
# Capture all current state
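
For reference, the birthday approximation behind that comment: an 8-hex-character prefix gives only 2**32 distinct IDs, so collision odds climb quickly with the number of checkpoints, while a full UUIDv4 (122 random bits, 36-character string) makes them negligible. A back-of-the-envelope check:

import math

N = 16 ** 8                       # 2**32 possible 8-char hex prefixes
for n in (10_000, 50_000, 77_000):
    p = 1 - math.exp(-n * n / (2 * N))   # birthday approximation
    print(f"{n:>6} checkpoints -> collision probability ~ {p:.0%}")
# Roughly 1% at 10k, 25% at 50k and 50% near 77k checkpoints.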

View File

@@ -738,26 +738,32 @@ class TestComprehensiveReflection:
assert "Episodes analyzed" in summary
class TestSingleton:
"""Tests for singleton pattern."""
class TestFactoryFunction:
"""Tests for factory function behavior.
async def test_get_memory_reflection_returns_singleton(
Note: The singleton pattern was removed to avoid stale database session bugs.
Each call now creates a fresh instance, which is safer for request-scoped usage.
"""
async def test_get_memory_reflection_creates_new_instance(
self,
mock_session: MagicMock,
) -> None:
"""Should return same instance."""
"""Should create new instance each call (no singleton for session safety)."""
r1 = await get_memory_reflection(mock_session)
r2 = await get_memory_reflection(mock_session)
assert r1 is r2
async def test_reset_creates_new_instance(
self,
mock_session: MagicMock,
) -> None:
"""Should create new instance after reset."""
r1 = await get_memory_reflection(mock_session)
await reset_memory_reflection()
r2 = await get_memory_reflection(mock_session)
# Different instances to avoid stale session issues
assert r1 is not r2
async def test_reset_is_no_op(
self,
mock_session: MagicMock,
) -> None:
"""Reset should be a no-op (kept for API compatibility)."""
r1 = await get_memory_reflection(mock_session)
await reset_memory_reflection() # Should not raise
r2 = await get_memory_reflection(mock_session)
# Still creates new instances (reset is no-op now)
assert r1 is not r2

View File

@@ -276,7 +276,7 @@ class TestWorkingMemoryCheckpoints:
checkpoint_id = await memory.create_checkpoint("Test checkpoint")
assert checkpoint_id is not None
assert len(checkpoint_id) == 8 # UUID prefix
assert len(checkpoint_id) == 36 # Full UUID for collision safety
@pytest.mark.asyncio
async def test_restore_checkpoint(self, memory: WorkingMemory) -> None: