2 Commits

Author SHA1 Message Date
Felipe Cardoso
192237e69b fix(memory): unify Outcome enum and add ABANDONED support
- Add ABANDONED value to core Outcome enum in types.py
- Replace duplicate OutcomeType class in mcp/tools.py with alias to Outcome
- Simplify mcp/service.py to use outcome directly (no more silent mapping)
- Add migration 0006 to extend PostgreSQL episode_outcome enum
- Add missing constraints to migration 0005 (ix_facts_unique_triple_global)

This fixes the semantic issue where ABANDONED outcomes were silently
converted to FAILURE, losing information about task abandonment.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-06 01:46:48 +01:00
Felipe Cardoso
3edce9cd26 fix(memory): address critical bugs from multi-agent review
Bug Fixes:
- Remove singleton pattern from consolidation/reflection services to
  prevent stale database session bugs (session is now passed per-request)
- Add LRU eviction to MemoryToolService._working dict (max 1000 sessions)
  to prevent unbounded memory growth
- Replace O(n) list.remove() with O(1) OrderedDict.move_to_end() in
  RetrievalCache for better performance under load
- Use deque with maxlen for metrics histograms to prevent unbounded
  memory growth (circular buffer with 10k max samples)
- Use full UUID for checkpoint IDs instead of 8-char prefix to avoid
  collision risk at scale (birthday paradox at ~50k checkpoints)

Test Updates:
- Update checkpoint test to expect 36-char UUID
- Update reflection singleton tests to expect new factory behavior
- Add reset_memory_reflection() no-op for backwards compatibility

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-05 18:55:32 +01:00
12 changed files with 161 additions and 96 deletions

View File

@@ -300,6 +300,14 @@ def upgrade() -> None:
         unique=True,
         postgresql_where=sa.text("project_id IS NOT NULL"),
     )
+    # Unique constraint for global facts (project_id IS NULL)
+    op.create_index(
+        "ix_facts_unique_triple_global",
+        "facts",
+        ["subject", "predicate", "object"],
+        unique=True,
+        postgresql_where=sa.text("project_id IS NULL"),
+    )
 
     # =========================================================================
     # Create procedures table
@@ -396,6 +404,11 @@ def upgrade() -> None:
         "facts",
         "confidence >= 0.0 AND confidence <= 1.0",
     )
+    op.create_check_constraint(
+        "ck_facts_reinforcement_positive",
+        "facts",
+        "reinforcement_count >= 1",
+    )
 
     # Procedure constraints
     op.create_check_constraint(
@@ -476,11 +489,15 @@ def downgrade() -> None:
     # Drop check constraints first
     op.drop_constraint("ck_procedures_failure_positive", "procedures", type_="check")
     op.drop_constraint("ck_procedures_success_positive", "procedures", type_="check")
+    op.drop_constraint("ck_facts_reinforcement_positive", "facts", type_="check")
    op.drop_constraint("ck_facts_confidence_range", "facts", type_="check")
     op.drop_constraint("ck_episodes_tokens_positive", "episodes", type_="check")
     op.drop_constraint("ck_episodes_duration_positive", "episodes", type_="check")
     op.drop_constraint("ck_episodes_importance_range", "episodes", type_="check")
 
+    # Drop unique indexes for global facts
+    op.drop_index("ix_facts_unique_triple_global", "facts")
+
     # Drop tables in reverse order (dependencies first)
     op.drop_table("memory_consolidation_log")
     op.drop_table("procedures")
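
The reason a second partial index is needed: PostgreSQL never treats two NULLs as equal in a unique index, so the project-scoped index (project_id IS NOT NULL) cannot deduplicate global facts on its own. A minimal sketch of the behaviour the new index enforces (illustrative only; the real facts table has more columns than shown, and the async connection plumbing is assumed):

    from sqlalchemy import text
    from sqlalchemy.ext.asyncio import AsyncConnection

    async def demo_global_uniqueness(conn: AsyncConnection) -> None:
        insert = text(
            "INSERT INTO facts (subject, predicate, object, project_id) "
            "VALUES (:s, :p, :o, NULL)"
        )
        params = {"s": "python", "p": "is_a", "o": "language"}
        await conn.execute(insert, params)
        # Without ix_facts_unique_triple_global this second insert would silently
        # create a duplicate global triple; now it raises IntegrityError.
        await conn.execute(insert, params)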

View File

@@ -0,0 +1,52 @@
"""Add ABANDONED to episode_outcome enum
Revision ID: 0006
Revises: 0005
Create Date: 2025-01-06
This migration adds the 'abandoned' value to the episode_outcome enum type.
This allows episodes to track when a task was abandoned (not completed,
but not necessarily a failure either - e.g., user cancelled, session timeout).
"""
from collections.abc import Sequence
from alembic import op
# revision identifiers, used by Alembic.
revision: str = "0006"
down_revision: str | None = "0005"
branch_labels: str | Sequence[str] | None = None
depends_on: str | Sequence[str] | None = None
def upgrade() -> None:
"""Add 'abandoned' value to episode_outcome enum."""
# PostgreSQL ALTER TYPE ADD VALUE is safe and non-blocking
op.execute("ALTER TYPE episode_outcome ADD VALUE IF NOT EXISTS 'abandoned'")
def downgrade() -> None:
"""Remove 'abandoned' from episode_outcome enum.
Note: PostgreSQL doesn't support removing values from enums directly.
This downgrade converts any 'abandoned' episodes to 'failure' and
recreates the enum without 'abandoned'.
"""
# Convert any abandoned episodes to failure first
op.execute("""
UPDATE episodes
SET outcome = 'failure'
WHERE outcome = 'abandoned'
""")
# Recreate the enum without abandoned
# This is complex in PostgreSQL - requires creating new type, updating columns, dropping old
op.execute("ALTER TYPE episode_outcome RENAME TO episode_outcome_old")
op.execute("CREATE TYPE episode_outcome AS ENUM ('success', 'failure', 'partial')")
op.execute("""
ALTER TABLE episodes
ALTER COLUMN outcome TYPE episode_outcome
USING outcome::text::episode_outcome
""")
op.execute("DROP TYPE episode_outcome_old")

View File

@@ -892,27 +892,22 @@ class MemoryConsolidationService:
         return result
 
 
-# Singleton instance
-_consolidation_service: MemoryConsolidationService | None = None
+# Factory function - no singleton to avoid stale session issues
 
 
 async def get_consolidation_service(
     session: AsyncSession,
     config: ConsolidationConfig | None = None,
 ) -> MemoryConsolidationService:
     """
-    Get or create the memory consolidation service.
+    Create a memory consolidation service for the given session.
+
+    Note: This creates a new instance each time to avoid stale session issues.
+    The service is lightweight and safe to recreate per-request.
 
     Args:
-        session: Database session
+        session: Database session (must be active)
         config: Optional configuration
 
     Returns:
         MemoryConsolidationService instance
     """
-    global _consolidation_service
-    if _consolidation_service is None:
-        _consolidation_service = MemoryConsolidationService(
-            session=session, config=config
-        )
-    return _consolidation_service
+    return MemoryConsolidationService(session=session, config=config)
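
The failure mode this removes: with the old module-level singleton, whichever session arrived first was captured for the lifetime of the process, and later requests kept using it even after it had been closed or rolled back. A test-style sketch of the new per-request contract (fixture names are hypothetical):

    from sqlalchemy.ext.asyncio import AsyncSession

    async def test_service_is_not_cached(
        session_a: AsyncSession, session_b: AsyncSession
    ) -> None:
        s1 = await get_consolidation_service(session_a)
        s2 = await get_consolidation_service(session_b)
        # Each call returns a fresh instance bound to the session it was given,
        # so a closed session_a can never leak into a request using session_b.
        assert s1 is not s2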

View File

@@ -13,6 +13,7 @@ Provides hybrid retrieval capabilities combining:
 import hashlib
 import logging
+from collections import OrderedDict
 from dataclasses import dataclass, field
 from datetime import UTC, datetime
 from typing import Any, TypeVar
@@ -243,7 +244,8 @@ class RetrievalCache:
     """
     In-memory cache for retrieval results.
 
-    Supports TTL-based expiration and LRU eviction.
+    Supports TTL-based expiration and LRU eviction with O(1) operations.
+    Uses OrderedDict for efficient LRU tracking.
     """
 
     def __init__(
@@ -258,10 +260,10 @@ class RetrievalCache:
             max_entries: Maximum cache entries
             default_ttl_seconds: Default TTL for entries
         """
-        self._cache: dict[str, CacheEntry] = {}
+        # OrderedDict maintains insertion order; we use move_to_end for O(1) LRU
+        self._cache: OrderedDict[str, CacheEntry] = OrderedDict()
         self._max_entries = max_entries
         self._default_ttl = default_ttl_seconds
-        self._access_order: list[str] = []
 
         logger.info(
             f"Initialized RetrievalCache with max_entries={max_entries}, "
             f"ttl={default_ttl_seconds}s"
@@ -283,14 +285,10 @@ class RetrievalCache:
         entry = self._cache[query_key]
 
         if entry.is_expired():
             del self._cache[query_key]
-            if query_key in self._access_order:
-                self._access_order.remove(query_key)
             return None
 
-        # Update access order (LRU)
-        if query_key in self._access_order:
-            self._access_order.remove(query_key)
-        self._access_order.append(query_key)
+        # Update access order (LRU) - O(1) with OrderedDict
+        self._cache.move_to_end(query_key)
 
         logger.debug(f"Cache hit for {query_key}")
         return entry.results
@@ -309,11 +307,9 @@ class RetrievalCache:
             results: Results to cache
             ttl_seconds: TTL for this entry (or default)
         """
-        # Evict if at capacity
-        while len(self._cache) >= self._max_entries and self._access_order:
-            oldest_key = self._access_order.pop(0)
-            if oldest_key in self._cache:
-                del self._cache[oldest_key]
+        # Evict oldest entries if at capacity - O(1) with popitem(last=False)
+        while len(self._cache) >= self._max_entries:
+            self._cache.popitem(last=False)
 
         entry = CacheEntry(
             results=results,
@@ -323,7 +319,6 @@ class RetrievalCache:
         )
 
         self._cache[query_key] = entry
-        self._access_order.append(query_key)
 
         logger.debug(f"Cached {len(results)} results for {query_key}")
 
     def invalidate(self, query_key: str) -> bool:
@@ -338,8 +333,6 @@ class RetrievalCache:
         """
         if query_key in self._cache:
             del self._cache[query_key]
-            if query_key in self._access_order:
-                self._access_order.remove(query_key)
             return True
         return False
@@ -376,7 +369,6 @@ class RetrievalCache:
         """
         count = len(self._cache)
         self._cache.clear()
-        self._access_order.clear()
         logger.info(f"Cleared {count} cache entries")
         return count
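
The OrderedDict-based LRU pattern used above, isolated into a standalone sketch (simplified: no TTL handling and plain values instead of CacheEntry):

    from collections import OrderedDict

    class LRUCache:
        def __init__(self, max_entries: int = 1000) -> None:
            self._data: OrderedDict[str, object] = OrderedDict()
            self._max_entries = max_entries

        def get(self, key: str) -> object | None:
            if key not in self._data:
                return None
            self._data.move_to_end(key)  # mark as most recently used, O(1)
            return self._data[key]

        def put(self, key: str, value: object) -> None:
            # Evict least recently used entries until there is room, O(1) each.
            while len(self._data) >= self._max_entries:
                self._data.popitem(last=False)
            self._data[key] = value  # new keys land at the most-recent end

Both hot paths avoid the O(n) list.remove() scans of the previous implementation.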

View File

@@ -7,6 +7,7 @@ All tools are scoped to project/agent context for proper isolation.
 """
 
 import logging
+from collections import OrderedDict
 from dataclasses import dataclass
 from datetime import UTC, datetime, timedelta
 from typing import Any
@@ -83,6 +84,9 @@ class MemoryToolService:
     This service coordinates between different memory types.
     """
 
+    # Maximum number of working memory sessions to cache (LRU eviction)
+    MAX_WORKING_SESSIONS = 1000
+
     def __init__(
         self,
         session: AsyncSession,
@@ -98,8 +102,8 @@ class MemoryToolService:
         self._session = session
         self._embedding_generator = embedding_generator
 
-        # Lazy-initialized memory services
-        self._working: dict[str, WorkingMemory] = {}  # keyed by session_id
+        # Lazy-initialized memory services with LRU eviction for working memory
+        self._working: OrderedDict[str, WorkingMemory] = OrderedDict()
         self._episodic: EpisodicMemory | None = None
         self._semantic: SemanticMemory | None = None
         self._procedural: ProceduralMemory | None = None
@@ -110,14 +114,28 @@ class MemoryToolService:
         project_id: UUID | None = None,
         agent_instance_id: UUID | None = None,
     ) -> WorkingMemory:
-        """Get or create working memory for a session."""
-        if session_id not in self._working:
-            self._working[session_id] = await WorkingMemory.for_session(
-                session_id=session_id,
-                project_id=str(project_id) if project_id else None,
-                agent_instance_id=str(agent_instance_id) if agent_instance_id else None,
-            )
-        return self._working[session_id]
+        """Get or create working memory for a session with LRU eviction."""
+        if session_id in self._working:
+            # Move to end (most recently used)
+            self._working.move_to_end(session_id)
+            return self._working[session_id]
+
+        # Evict oldest entries if at capacity
+        while len(self._working) >= self.MAX_WORKING_SESSIONS:
+            oldest_id, oldest_memory = self._working.popitem(last=False)
+            try:
+                await oldest_memory.close()
+            except Exception as e:
+                logger.warning(f"Error closing evicted working memory {oldest_id}: {e}")
+
+        # Create new working memory
+        working = await WorkingMemory.for_session(
+            session_id=session_id,
+            project_id=str(project_id) if project_id else None,
+            agent_instance_id=str(agent_instance_id) if agent_instance_id else None,
+        )
+        self._working[session_id] = working
+        return working
 
     async def _get_episodic(self) -> EpisodicMemory:
         """Get or create episodic memory service."""
@@ -1006,15 +1024,8 @@ class MemoryToolService:
         context: ToolContext,
     ) -> dict[str, Any]:
         """Execute the 'record_outcome' tool."""
-        # Map outcome type to memory Outcome
-        # Note: ABANDONED maps to FAILURE since core Outcome doesn't have ABANDONED
-        outcome_map = {
-            OutcomeType.SUCCESS: Outcome.SUCCESS,
-            OutcomeType.PARTIAL: Outcome.PARTIAL,
-            OutcomeType.FAILURE: Outcome.FAILURE,
-            OutcomeType.ABANDONED: Outcome.FAILURE,  # No ABANDONED in core enum
-        }
-        outcome = outcome_map.get(args.outcome, Outcome.FAILURE)
+        # OutcomeType is now an alias for Outcome, use directly
+        outcome = args.outcome
 
         # Record in episodic memory
         episodic = await self._get_episodic()

View File

@@ -12,6 +12,9 @@ from typing import Any
 from pydantic import BaseModel, Field
 
+# OutcomeType alias - uses core Outcome enum from types module for consistency
+from app.services.memory.types import Outcome as OutcomeType
+
 
 class MemoryType(str, Enum):
     """Types of memory for storage operations."""
@@ -32,15 +35,6 @@ class AnalysisType(str, Enum):
     LEARNING_PROGRESS = "learning_progress"
 
 
-class OutcomeType(str, Enum):
-    """Outcome types for record_outcome tool."""
-
-    SUCCESS = "success"
-    PARTIAL = "partial"
-    FAILURE = "failure"
-    ABANDONED = "abandoned"
-
-
 # ============================================================================
 # Tool Argument Schemas (Pydantic models for validation)
 # ============================================================================
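
With the alias in place, OutcomeType and Outcome are literally the same enum object, which is what lets record_outcome pass the value through untouched. A quick illustration (the import path for the tools module and the argument model are assumptions, not taken from the repository):

    from pydantic import BaseModel

    from app.services.memory.types import Outcome
    from app.mcp.tools import OutcomeType  # import path assumed

    assert OutcomeType is Outcome
    assert OutcomeType.ABANDONED is Outcome.ABANDONED

    class RecordOutcomeArgs(BaseModel):  # hypothetical stand-in for the real schema
        outcome: OutcomeType

    args = RecordOutcomeArgs(outcome="abandoned")
    # No mapping step, and no silent downgrade of ABANDONED to FAILURE.
    assert args.outcome is Outcome.ABANDONED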

View File

@@ -7,7 +7,7 @@ Collects and exposes metrics for the memory system.
 import asyncio
 import logging
-from collections import Counter, defaultdict
+from collections import Counter, defaultdict, deque
 from dataclasses import dataclass, field
 from datetime import UTC, datetime
 from enum import Enum
@@ -57,11 +57,17 @@ class MemoryMetrics:
     - Embedding operations
     """
 
+    # Maximum samples to keep in histogram (circular buffer)
+    MAX_HISTOGRAM_SAMPLES = 10000
+
     def __init__(self) -> None:
         """Initialize MemoryMetrics."""
         self._counters: dict[str, Counter[str]] = defaultdict(Counter)
         self._gauges: dict[str, dict[str, float]] = defaultdict(dict)
-        self._histograms: dict[str, list[float]] = defaultdict(list)
+        # Use deque with maxlen for bounded memory (circular buffer)
+        self._histograms: dict[str, deque[float]] = defaultdict(
+            lambda: deque(maxlen=self.MAX_HISTOGRAM_SAMPLES)
+        )
         self._histogram_buckets: dict[str, list[HistogramBucket]] = {}
         self._lock = asyncio.Lock()
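
The bounded-histogram behaviour in miniature: a deque with maxlen acts as a circular buffer, silently discarding the oldest sample once the cap is reached. Illustrative only:

    from collections import deque

    samples: deque[float] = deque(maxlen=3)
    for value in (1.0, 2.0, 3.0, 4.0):
        samples.append(value)  # appending past maxlen drops the oldest sample
    assert list(samples) == [2.0, 3.0, 4.0]
    assert len(samples) == 3  # memory stays bounded regardless of sample volume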

View File

@@ -7,7 +7,6 @@ Implements pattern detection, success/failure analysis, anomaly detection,
 and insight generation.
 """
 
-import asyncio
 import logging
 import statistics
 from collections import Counter, defaultdict
@@ -1426,36 +1425,27 @@ class MemoryReflection:
     )
 
 
-# Singleton instance with async-safe initialization
-_memory_reflection: MemoryReflection | None = None
-_reflection_lock = asyncio.Lock()
+# Factory function - no singleton to avoid stale session issues
 
 
 async def get_memory_reflection(
     session: AsyncSession,
     config: ReflectionConfig | None = None,
 ) -> MemoryReflection:
     """
-    Get or create the memory reflection service (async-safe).
+    Create a memory reflection service for the given session.
+
+    Note: This creates a new instance each time to avoid stale session issues.
+    The service is lightweight and safe to recreate per-request.
 
     Args:
-        session: Database session
+        session: Database session (must be active)
         config: Optional configuration
 
     Returns:
         MemoryReflection instance
     """
-    global _memory_reflection
-    if _memory_reflection is None:
-        async with _reflection_lock:
-            # Double-check locking pattern
-            if _memory_reflection is None:
-                _memory_reflection = MemoryReflection(session=session, config=config)
-    return _memory_reflection
+    return MemoryReflection(session=session, config=config)
 
 
 async def reset_memory_reflection() -> None:
-    """Reset the memory reflection singleton (async-safe)."""
-    global _memory_reflection
-    async with _reflection_lock:
-        _memory_reflection = None
+    """No-op for backwards compatibility (singleton pattern removed)."""
+    return

View File

@@ -42,6 +42,7 @@ class Outcome(str, Enum):
     SUCCESS = "success"
     FAILURE = "failure"
     PARTIAL = "partial"
+    ABANDONED = "abandoned"
 
 
 class ConsolidationStatus(str, Enum):

View File

@@ -423,7 +423,8 @@ class WorkingMemory:
         Returns:
             Checkpoint ID for later restoration
         """
-        checkpoint_id = str(uuid.uuid4())[:8]
+        # Use full UUID to avoid collision risk (8 chars has ~50k collision at birthday paradox)
+        checkpoint_id = str(uuid.uuid4())
         checkpoint_key = f"{_CHECKPOINT_PREFIX}{checkpoint_id}"
 
         # Capture all current state
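
To put a number on the collision risk mentioned in the commit message: an 8-character hex prefix carries 32 bits of randomness, and the birthday bound makes collisions likely far earlier than intuition suggests. A back-of-the-envelope check (assumes uniformly random identifiers):

    import math

    def collision_probability(n_items: int, bits: int) -> float:
        # Birthday-bound approximation: p ~= 1 - exp(-n^2 / (2 * 2^bits))
        return 1.0 - math.exp(-(n_items ** 2) / (2.0 * 2.0 ** bits))

    print(f"{collision_probability(50_000, 32):.0%}")   # roughly 25% for an 8-char hex prefix
    print(f"{collision_probability(77_000, 32):.0%}")   # roughly 50% near 77k checkpoints
    print(f"{collision_probability(50_000, 122):.1e}")  # negligible for a full UUIDv4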

View File

@@ -738,26 +738,32 @@ class TestComprehensiveReflection:
         assert "Episodes analyzed" in summary
 
 
-class TestSingleton:
-    """Tests for singleton pattern."""
+class TestFactoryFunction:
+    """Tests for factory function behavior.
+
+    Note: The singleton pattern was removed to avoid stale database session bugs.
+    Each call now creates a fresh instance, which is safer for request-scoped usage.
+    """
 
-    async def test_get_memory_reflection_returns_singleton(
+    async def test_get_memory_reflection_creates_new_instance(
         self,
         mock_session: MagicMock,
     ) -> None:
-        """Should return same instance."""
+        """Should create new instance each call (no singleton for session safety)."""
         r1 = await get_memory_reflection(mock_session)
         r2 = await get_memory_reflection(mock_session)
-        assert r1 is r2
+        # Different instances to avoid stale session issues
+        assert r1 is not r2
 
-    async def test_reset_creates_new_instance(
+    async def test_reset_is_no_op(
         self,
         mock_session: MagicMock,
     ) -> None:
-        """Should create new instance after reset."""
+        """Reset should be a no-op (kept for API compatibility)."""
         r1 = await get_memory_reflection(mock_session)
-        await reset_memory_reflection()
+        await reset_memory_reflection()  # Should not raise
         r2 = await get_memory_reflection(mock_session)
+        # Still creates new instances (reset is no-op now)
         assert r1 is not r2

View File

@@ -276,7 +276,7 @@ class TestWorkingMemoryCheckpoints:
         checkpoint_id = await memory.create_checkpoint("Test checkpoint")
 
         assert checkpoint_id is not None
-        assert len(checkpoint_id) == 8  # UUID prefix
+        assert len(checkpoint_id) == 36  # Full UUID for collision safety
 
     @pytest.mark.asyncio
     async def test_restore_checkpoint(self, memory: WorkingMemory) -> None: