fix(memory): address critical bugs from multi-agent review

Bug Fixes:
- Remove singleton pattern from consolidation/reflection services to
  prevent stale database session bugs (session is now passed per-request)
- Add LRU eviction to MemoryToolService._working dict (max 1000 sessions)
  to prevent unbounded memory growth
- Replace O(n) list.remove() with O(1) OrderedDict.move_to_end() in
  RetrievalCache for better performance under load
- Use deque with maxlen for metrics histograms to prevent unbounded
  memory growth (circular buffer with 10k max samples)
- Use full UUID for checkpoint IDs instead of 8-char prefix to avoid
  collision risk at scale (8 hex chars give only 32 bits of entropy, so
  by the birthday bound collisions become likely around ~77k checkpoints)

Test Updates:
- Update checkpoint test to expect 36-char UUID
- Update reflection singleton tests to expect new factory behavior
- Add reset_memory_reflection() no-op for backwards compatibility

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
2026-01-05 18:55:32 +01:00
parent 35aea2d73a
commit 3edce9cd26
8 changed files with 86 additions and 78 deletions

View File

@@ -892,27 +892,22 @@ class MemoryConsolidationService:
return result return result
# Singleton instance # Factory function - no singleton to avoid stale session issues
_consolidation_service: MemoryConsolidationService | None = None
async def get_consolidation_service( async def get_consolidation_service(
session: AsyncSession, session: AsyncSession,
config: ConsolidationConfig | None = None, config: ConsolidationConfig | None = None,
) -> MemoryConsolidationService: ) -> MemoryConsolidationService:
""" """
Get or create the memory consolidation service. Create a memory consolidation service for the given session.
Note: This creates a new instance each time to avoid stale session issues.
The service is lightweight and safe to recreate per-request.
Args: Args:
session: Database session session: Database session (must be active)
config: Optional configuration config: Optional configuration
Returns: Returns:
MemoryConsolidationService instance MemoryConsolidationService instance
""" """
global _consolidation_service return MemoryConsolidationService(session=session, config=config)
if _consolidation_service is None:
_consolidation_service = MemoryConsolidationService(
session=session, config=config
)
return _consolidation_service

View File

@@ -13,6 +13,7 @@ Provides hybrid retrieval capabilities combining:
import hashlib import hashlib
import logging import logging
from collections import OrderedDict
from dataclasses import dataclass, field from dataclasses import dataclass, field
from datetime import UTC, datetime from datetime import UTC, datetime
from typing import Any, TypeVar from typing import Any, TypeVar
@@ -243,7 +244,8 @@ class RetrievalCache:
""" """
In-memory cache for retrieval results. In-memory cache for retrieval results.
Supports TTL-based expiration and LRU eviction. Supports TTL-based expiration and LRU eviction with O(1) operations.
Uses OrderedDict for efficient LRU tracking.
""" """
def __init__( def __init__(
@@ -258,10 +260,10 @@ class RetrievalCache:
max_entries: Maximum cache entries max_entries: Maximum cache entries
default_ttl_seconds: Default TTL for entries default_ttl_seconds: Default TTL for entries
""" """
self._cache: dict[str, CacheEntry] = {} # OrderedDict maintains insertion order; we use move_to_end for O(1) LRU
self._cache: OrderedDict[str, CacheEntry] = OrderedDict()
self._max_entries = max_entries self._max_entries = max_entries
self._default_ttl = default_ttl_seconds self._default_ttl = default_ttl_seconds
self._access_order: list[str] = []
logger.info( logger.info(
f"Initialized RetrievalCache with max_entries={max_entries}, " f"Initialized RetrievalCache with max_entries={max_entries}, "
f"ttl={default_ttl_seconds}s" f"ttl={default_ttl_seconds}s"
@@ -283,14 +285,10 @@ class RetrievalCache:
entry = self._cache[query_key] entry = self._cache[query_key]
if entry.is_expired(): if entry.is_expired():
del self._cache[query_key] del self._cache[query_key]
if query_key in self._access_order:
self._access_order.remove(query_key)
return None return None
# Update access order (LRU) # Update access order (LRU) - O(1) with OrderedDict
if query_key in self._access_order: self._cache.move_to_end(query_key)
self._access_order.remove(query_key)
self._access_order.append(query_key)
logger.debug(f"Cache hit for {query_key}") logger.debug(f"Cache hit for {query_key}")
return entry.results return entry.results
@@ -309,11 +307,9 @@ class RetrievalCache:
results: Results to cache results: Results to cache
ttl_seconds: TTL for this entry (or default) ttl_seconds: TTL for this entry (or default)
""" """
# Evict if at capacity # Evict oldest entries if at capacity - O(1) with popitem(last=False)
while len(self._cache) >= self._max_entries and self._access_order: while len(self._cache) >= self._max_entries:
oldest_key = self._access_order.pop(0) self._cache.popitem(last=False)
if oldest_key in self._cache:
del self._cache[oldest_key]
entry = CacheEntry( entry = CacheEntry(
results=results, results=results,
@@ -323,7 +319,6 @@ class RetrievalCache:
) )
self._cache[query_key] = entry self._cache[query_key] = entry
self._access_order.append(query_key)
logger.debug(f"Cached {len(results)} results for {query_key}") logger.debug(f"Cached {len(results)} results for {query_key}")
def invalidate(self, query_key: str) -> bool: def invalidate(self, query_key: str) -> bool:
@@ -338,8 +333,6 @@ class RetrievalCache:
""" """
if query_key in self._cache: if query_key in self._cache:
del self._cache[query_key] del self._cache[query_key]
if query_key in self._access_order:
self._access_order.remove(query_key)
return True return True
return False return False
@@ -376,7 +369,6 @@ class RetrievalCache:
""" """
count = len(self._cache) count = len(self._cache)
self._cache.clear() self._cache.clear()
self._access_order.clear()
logger.info(f"Cleared {count} cache entries") logger.info(f"Cleared {count} cache entries")
return count return count

View File

@@ -7,6 +7,7 @@ All tools are scoped to project/agent context for proper isolation.
""" """
import logging import logging
from collections import OrderedDict
from dataclasses import dataclass from dataclasses import dataclass
from datetime import UTC, datetime, timedelta from datetime import UTC, datetime, timedelta
from typing import Any from typing import Any
@@ -83,6 +84,9 @@ class MemoryToolService:
This service coordinates between different memory types. This service coordinates between different memory types.
""" """
# Maximum number of working memory sessions to cache (LRU eviction)
MAX_WORKING_SESSIONS = 1000
def __init__( def __init__(
self, self,
session: AsyncSession, session: AsyncSession,
@@ -98,8 +102,8 @@ class MemoryToolService:
self._session = session self._session = session
self._embedding_generator = embedding_generator self._embedding_generator = embedding_generator
# Lazy-initialized memory services # Lazy-initialized memory services with LRU eviction for working memory
self._working: dict[str, WorkingMemory] = {} # keyed by session_id self._working: OrderedDict[str, WorkingMemory] = OrderedDict()
self._episodic: EpisodicMemory | None = None self._episodic: EpisodicMemory | None = None
self._semantic: SemanticMemory | None = None self._semantic: SemanticMemory | None = None
self._procedural: ProceduralMemory | None = None self._procedural: ProceduralMemory | None = None
@@ -110,14 +114,28 @@ class MemoryToolService:
project_id: UUID | None = None, project_id: UUID | None = None,
agent_instance_id: UUID | None = None, agent_instance_id: UUID | None = None,
) -> WorkingMemory: ) -> WorkingMemory:
"""Get or create working memory for a session.""" """Get or create working memory for a session with LRU eviction."""
if session_id not in self._working: if session_id in self._working:
self._working[session_id] = await WorkingMemory.for_session( # Move to end (most recently used)
session_id=session_id, self._working.move_to_end(session_id)
project_id=str(project_id) if project_id else None, return self._working[session_id]
agent_instance_id=str(agent_instance_id) if agent_instance_id else None,
) # Evict oldest entries if at capacity
return self._working[session_id] while len(self._working) >= self.MAX_WORKING_SESSIONS:
oldest_id, oldest_memory = self._working.popitem(last=False)
try:
await oldest_memory.close()
except Exception as e:
logger.warning(f"Error closing evicted working memory {oldest_id}: {e}")
# Create new working memory
working = await WorkingMemory.for_session(
session_id=session_id,
project_id=str(project_id) if project_id else None,
agent_instance_id=str(agent_instance_id) if agent_instance_id else None,
)
self._working[session_id] = working
return working
async def _get_episodic(self) -> EpisodicMemory: async def _get_episodic(self) -> EpisodicMemory:
"""Get or create episodic memory service.""" """Get or create episodic memory service."""

View File

@@ -7,7 +7,7 @@ Collects and exposes metrics for the memory system.
import asyncio import asyncio
import logging import logging
from collections import Counter, defaultdict from collections import Counter, defaultdict, deque
from dataclasses import dataclass, field from dataclasses import dataclass, field
from datetime import UTC, datetime from datetime import UTC, datetime
from enum import Enum from enum import Enum
@@ -57,11 +57,17 @@ class MemoryMetrics:
- Embedding operations - Embedding operations
""" """
# Maximum samples to keep in histogram (circular buffer)
MAX_HISTOGRAM_SAMPLES = 10000
def __init__(self) -> None: def __init__(self) -> None:
"""Initialize MemoryMetrics.""" """Initialize MemoryMetrics."""
self._counters: dict[str, Counter[str]] = defaultdict(Counter) self._counters: dict[str, Counter[str]] = defaultdict(Counter)
self._gauges: dict[str, dict[str, float]] = defaultdict(dict) self._gauges: dict[str, dict[str, float]] = defaultdict(dict)
self._histograms: dict[str, list[float]] = defaultdict(list) # Use deque with maxlen for bounded memory (circular buffer)
self._histograms: dict[str, deque[float]] = defaultdict(
lambda: deque(maxlen=self.MAX_HISTOGRAM_SAMPLES)
)
self._histogram_buckets: dict[str, list[HistogramBucket]] = {} self._histogram_buckets: dict[str, list[HistogramBucket]] = {}
self._lock = asyncio.Lock() self._lock = asyncio.Lock()

View File

@@ -7,7 +7,6 @@ Implements pattern detection, success/failure analysis, anomaly detection,
and insight generation. and insight generation.
""" """
import asyncio
import logging import logging
import statistics import statistics
from collections import Counter, defaultdict from collections import Counter, defaultdict
@@ -1426,36 +1425,27 @@ class MemoryReflection:
) )
# Singleton instance with async-safe initialization # Factory function - no singleton to avoid stale session issues
_memory_reflection: MemoryReflection | None = None
_reflection_lock = asyncio.Lock()
async def get_memory_reflection( async def get_memory_reflection(
session: AsyncSession, session: AsyncSession,
config: ReflectionConfig | None = None, config: ReflectionConfig | None = None,
) -> MemoryReflection: ) -> MemoryReflection:
""" """
Get or create the memory reflection service (async-safe). Create a memory reflection service for the given session.
Note: This creates a new instance each time to avoid stale session issues.
The service is lightweight and safe to recreate per-request.
Args: Args:
session: Database session session: Database session (must be active)
config: Optional configuration config: Optional configuration
Returns: Returns:
MemoryReflection instance MemoryReflection instance
""" """
global _memory_reflection return MemoryReflection(session=session, config=config)
if _memory_reflection is None:
async with _reflection_lock:
# Double-check locking pattern
if _memory_reflection is None:
_memory_reflection = MemoryReflection(session=session, config=config)
return _memory_reflection
async def reset_memory_reflection() -> None: async def reset_memory_reflection() -> None:
"""Reset the memory reflection singleton (async-safe).""" """No-op for backwards compatibility (singleton pattern removed)."""
global _memory_reflection return
async with _reflection_lock:
_memory_reflection = None

View File

@@ -423,7 +423,8 @@ class WorkingMemory:
Returns: Returns:
Checkpoint ID for later restoration Checkpoint ID for later restoration
""" """
checkpoint_id = str(uuid.uuid4())[:8] # Use full UUID to avoid collision risk (8 chars has ~50k collision at birthday paradox)
checkpoint_id = str(uuid.uuid4())
checkpoint_key = f"{_CHECKPOINT_PREFIX}{checkpoint_id}" checkpoint_key = f"{_CHECKPOINT_PREFIX}{checkpoint_id}"
# Capture all current state # Capture all current state

View File

@@ -738,26 +738,32 @@ class TestComprehensiveReflection:
assert "Episodes analyzed" in summary assert "Episodes analyzed" in summary
class TestSingleton: class TestFactoryFunction:
"""Tests for singleton pattern.""" """Tests for factory function behavior.
async def test_get_memory_reflection_returns_singleton( Note: The singleton pattern was removed to avoid stale database session bugs.
Each call now creates a fresh instance, which is safer for request-scoped usage.
"""
async def test_get_memory_reflection_creates_new_instance(
self, self,
mock_session: MagicMock, mock_session: MagicMock,
) -> None: ) -> None:
"""Should return same instance.""" """Should create new instance each call (no singleton for session safety)."""
r1 = await get_memory_reflection(mock_session) r1 = await get_memory_reflection(mock_session)
r2 = await get_memory_reflection(mock_session) r2 = await get_memory_reflection(mock_session)
assert r1 is r2 # Different instances to avoid stale session issues
assert r1 is not r2
async def test_reset_creates_new_instance(
self, async def test_reset_is_no_op(
mock_session: MagicMock, self,
) -> None: mock_session: MagicMock,
"""Should create new instance after reset.""" ) -> None:
r1 = await get_memory_reflection(mock_session) """Reset should be a no-op (kept for API compatibility)."""
await reset_memory_reflection() r1 = await get_memory_reflection(mock_session)
r2 = await get_memory_reflection(mock_session) await reset_memory_reflection() # Should not raise
r2 = await get_memory_reflection(mock_session)
# Still creates new instances (reset is no-op now)
assert r1 is not r2 assert r1 is not r2

View File

@@ -276,7 +276,7 @@ class TestWorkingMemoryCheckpoints:
checkpoint_id = await memory.create_checkpoint("Test checkpoint") checkpoint_id = await memory.create_checkpoint("Test checkpoint")
assert checkpoint_id is not None assert checkpoint_id is not None
assert len(checkpoint_id) == 8 # UUID prefix assert len(checkpoint_id) == 36 # Full UUID for collision safety
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_restore_checkpoint(self, memory: WorkingMemory) -> None: async def test_restore_checkpoint(self, memory: WorkingMemory) -> None: