fix(memory): address critical bugs from multi-agent review

Bug Fixes:
- Remove singleton pattern from consolidation/reflection services to
  prevent stale database session bugs (a fresh instance is now created per
  request, bound to that request's session)
- Add LRU eviction to MemoryToolService._working dict (max 1000 sessions)
  to prevent unbounded memory growth
- Replace O(n) list.remove() with O(1) OrderedDict.move_to_end() in
  RetrievalCache for better performance under load
- Use deque with maxlen for metrics histograms to prevent unbounded
  memory growth (circular buffer with 10k max samples)
- Use full UUID for checkpoint IDs instead of an 8-char prefix to avoid
  collision risk at scale (birthday-paradox collision odds reach ~25% at
  50k checkpoints)

Test Updates:
- Update checkpoint test to expect 36-char UUID
- Update reflection singleton tests to expect new factory behavior
- Add reset_memory_reflection() no-op for backwards compatibility

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-05 18:55:32 +01:00
parent 35aea2d73a
commit 3edce9cd26
8 changed files with 86 additions and 78 deletions


@@ -892,27 +892,22 @@ class MemoryConsolidationService:
         return result
-# Singleton instance
-_consolidation_service: MemoryConsolidationService | None = None
+# Factory function - no singleton to avoid stale session issues
 async def get_consolidation_service(
     session: AsyncSession,
     config: ConsolidationConfig | None = None,
 ) -> MemoryConsolidationService:
     """
-    Get or create the memory consolidation service.
+    Create a memory consolidation service for the given session.
+    Note: This creates a new instance each time to avoid stale session issues.
+    The service is lightweight and safe to recreate per-request.
     Args:
-        session: Database session
+        session: Database session (must be active)
         config: Optional configuration
     Returns:
         MemoryConsolidationService instance
     """
-    global _consolidation_service
-    if _consolidation_service is None:
-        _consolidation_service = MemoryConsolidationService(
-            session=session, config=config
-        )
-    return _consolidation_service
+    return MemoryConsolidationService(session=session, config=config)
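
Note: the factory is now meant to be called with a request-scoped session rather than cached globally. A minimal usage sketch, assuming a SQLAlchemy async_sessionmaker; the handle_request wrapper is illustrative and not part of this commit:

from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker

async def handle_request(session_maker: async_sessionmaker[AsyncSession]) -> None:
    # Each request builds its own service around its own session, so the
    # service can never outlive or go stale relative to that session.
    async with session_maker() as session:
        service = await get_consolidation_service(session)
        ...  # use `service` only within this block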


@@ -13,6 +13,7 @@ Provides hybrid retrieval capabilities combining:
 import hashlib
 import logging
+from collections import OrderedDict
 from dataclasses import dataclass, field
 from datetime import UTC, datetime
 from typing import Any, TypeVar
@@ -243,7 +244,8 @@ class RetrievalCache:
"""
In-memory cache for retrieval results.
Supports TTL-based expiration and LRU eviction.
Supports TTL-based expiration and LRU eviction with O(1) operations.
Uses OrderedDict for efficient LRU tracking.
"""
def __init__(
@@ -258,10 +260,10 @@ class RetrievalCache:
             max_entries: Maximum cache entries
             default_ttl_seconds: Default TTL for entries
         """
-        self._cache: dict[str, CacheEntry] = {}
+        # OrderedDict maintains insertion order; we use move_to_end for O(1) LRU
+        self._cache: OrderedDict[str, CacheEntry] = OrderedDict()
         self._max_entries = max_entries
         self._default_ttl = default_ttl_seconds
-        self._access_order: list[str] = []
         logger.info(
             f"Initialized RetrievalCache with max_entries={max_entries}, "
             f"ttl={default_ttl_seconds}s"
@@ -283,14 +285,10 @@ class RetrievalCache:
         entry = self._cache[query_key]
         if entry.is_expired():
             del self._cache[query_key]
-            if query_key in self._access_order:
-                self._access_order.remove(query_key)
             return None
-        # Update access order (LRU)
-        if query_key in self._access_order:
-            self._access_order.remove(query_key)
-        self._access_order.append(query_key)
+        # Update access order (LRU) - O(1) with OrderedDict
+        self._cache.move_to_end(query_key)
         logger.debug(f"Cache hit for {query_key}")
         return entry.results
@@ -309,11 +307,9 @@ class RetrievalCache:
             results: Results to cache
             ttl_seconds: TTL for this entry (or default)
         """
-        # Evict if at capacity
-        while len(self._cache) >= self._max_entries and self._access_order:
-            oldest_key = self._access_order.pop(0)
-            if oldest_key in self._cache:
-                del self._cache[oldest_key]
+        # Evict oldest entries if at capacity - O(1) with popitem(last=False)
+        while len(self._cache) >= self._max_entries:
+            self._cache.popitem(last=False)
         entry = CacheEntry(
             results=results,
@@ -323,7 +319,6 @@ class RetrievalCache:
         )
         self._cache[query_key] = entry
-        self._access_order.append(query_key)
         logger.debug(f"Cached {len(results)} results for {query_key}")
     def invalidate(self, query_key: str) -> bool:
@@ -338,8 +333,6 @@ class RetrievalCache:
"""
if query_key in self._cache:
del self._cache[query_key]
if query_key in self._access_order:
self._access_order.remove(query_key)
return True
return False
@@ -376,7 +369,6 @@ class RetrievalCache:
"""
count = len(self._cache)
self._cache.clear()
self._access_order.clear()
logger.info(f"Cleared {count} cache entries")
return count
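
The O(1) claims come down to two OrderedDict primitives: move_to_end() refreshes recency and popitem(last=False) drops the oldest entry. A standalone sketch of the same pattern (toy cache, not the RetrievalCache class above):

from collections import OrderedDict

lru: OrderedDict[str, int] = OrderedDict()
MAX_ENTRIES = 2

def put(key: str, value: int) -> None:
    while len(lru) >= MAX_ENTRIES:
        lru.popitem(last=False)  # evict least-recently-used entry, O(1)
    lru[key] = value

def get(key: str) -> int | None:
    if key not in lru:
        return None
    lru.move_to_end(key)  # mark as most-recently-used, O(1)
    return lru[key]

put("a", 1)
put("b", 2)
get("a")        # "a" is now the most recently used entry
put("c", 3)     # evicts "b", the least recently used
assert list(lru) == ["a", "c"]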


@@ -7,6 +7,7 @@ All tools are scoped to project/agent context for proper isolation.
"""
import logging
from collections import OrderedDict
from dataclasses import dataclass
from datetime import UTC, datetime, timedelta
from typing import Any
@@ -83,6 +84,9 @@ class MemoryToolService:
     This service coordinates between different memory types.
     """
+    # Maximum number of working memory sessions to cache (LRU eviction)
+    MAX_WORKING_SESSIONS = 1000
     def __init__(
         self,
         session: AsyncSession,
@@ -98,8 +102,8 @@ class MemoryToolService:
         self._session = session
         self._embedding_generator = embedding_generator
-        # Lazy-initialized memory services
-        self._working: dict[str, WorkingMemory] = {}  # keyed by session_id
+        # Lazy-initialized memory services with LRU eviction for working memory
+        self._working: OrderedDict[str, WorkingMemory] = OrderedDict()
         self._episodic: EpisodicMemory | None = None
         self._semantic: SemanticMemory | None = None
         self._procedural: ProceduralMemory | None = None
@@ -110,14 +114,28 @@ class MemoryToolService:
         project_id: UUID | None = None,
         agent_instance_id: UUID | None = None,
     ) -> WorkingMemory:
-        """Get or create working memory for a session."""
-        if session_id not in self._working:
-            self._working[session_id] = await WorkingMemory.for_session(
-                session_id=session_id,
-                project_id=str(project_id) if project_id else None,
-                agent_instance_id=str(agent_instance_id) if agent_instance_id else None,
-            )
-        return self._working[session_id]
+        """Get or create working memory for a session with LRU eviction."""
+        if session_id in self._working:
+            # Move to end (most recently used)
+            self._working.move_to_end(session_id)
+            return self._working[session_id]
+        # Evict oldest entries if at capacity
+        while len(self._working) >= self.MAX_WORKING_SESSIONS:
+            oldest_id, oldest_memory = self._working.popitem(last=False)
+            try:
+                await oldest_memory.close()
+            except Exception as e:
+                logger.warning(f"Error closing evicted working memory {oldest_id}: {e}")
+        # Create new working memory
+        working = await WorkingMemory.for_session(
+            session_id=session_id,
+            project_id=str(project_id) if project_id else None,
+            agent_instance_id=str(agent_instance_id) if agent_instance_id else None,
+        )
+        self._working[session_id] = working
+        return working
     async def _get_episodic(self) -> EpisodicMemory:
         """Get or create episodic memory service."""


@@ -7,7 +7,7 @@ Collects and exposes metrics for the memory system.
 import asyncio
 import logging
-from collections import Counter, defaultdict
+from collections import Counter, defaultdict, deque
 from dataclasses import dataclass, field
 from datetime import UTC, datetime
 from enum import Enum
@@ -57,11 +57,17 @@ class MemoryMetrics:
     - Embedding operations
     """
+    # Maximum samples to keep in histogram (circular buffer)
+    MAX_HISTOGRAM_SAMPLES = 10000
     def __init__(self) -> None:
         """Initialize MemoryMetrics."""
         self._counters: dict[str, Counter[str]] = defaultdict(Counter)
         self._gauges: dict[str, dict[str, float]] = defaultdict(dict)
-        self._histograms: dict[str, list[float]] = defaultdict(list)
+        # Use deque with maxlen for bounded memory (circular buffer)
+        self._histograms: dict[str, deque[float]] = defaultdict(
+            lambda: deque(maxlen=self.MAX_HISTOGRAM_SAMPLES)
+        )
         self._histogram_buckets: dict[str, list[HistogramBucket]] = {}
         self._lock = asyncio.Lock()
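
deque(maxlen=N) gives the circular-buffer behaviour described in the comment: once the deque is full, each append discards the oldest sample, so a histogram can never grow past N entries. A quick illustration (toy values, not the real metric names):

from collections import deque

samples: deque[float] = deque(maxlen=3)
for latency_ms in (10.0, 12.0, 11.0, 50.0):
    samples.append(latency_ms)

# The oldest sample (10.0) was dropped when the fourth value arrived.
assert list(samples) == [12.0, 11.0, 50.0]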


@@ -7,7 +7,6 @@ Implements pattern detection, success/failure analysis, anomaly detection,
 and insight generation.
 """
-import asyncio
 import logging
 import statistics
 from collections import Counter, defaultdict
@@ -1426,36 +1425,27 @@ class MemoryReflection:
         )
-# Singleton instance with async-safe initialization
-_memory_reflection: MemoryReflection | None = None
-_reflection_lock = asyncio.Lock()
+# Factory function - no singleton to avoid stale session issues
 async def get_memory_reflection(
     session: AsyncSession,
     config: ReflectionConfig | None = None,
 ) -> MemoryReflection:
     """
-    Get or create the memory reflection service (async-safe).
+    Create a memory reflection service for the given session.
+    Note: This creates a new instance each time to avoid stale session issues.
+    The service is lightweight and safe to recreate per-request.
     Args:
-        session: Database session
+        session: Database session (must be active)
         config: Optional configuration
     Returns:
         MemoryReflection instance
     """
-    global _memory_reflection
-    if _memory_reflection is None:
-        async with _reflection_lock:
-            # Double-check locking pattern
-            if _memory_reflection is None:
-                _memory_reflection = MemoryReflection(session=session, config=config)
-    return _memory_reflection
+    return MemoryReflection(session=session, config=config)
 async def reset_memory_reflection() -> None:
-    """Reset the memory reflection singleton (async-safe)."""
-    global _memory_reflection
-    async with _reflection_lock:
-        _memory_reflection = None
+    """No-op for backwards compatibility (singleton pattern removed)."""
+    return


@@ -423,7 +423,8 @@ class WorkingMemory:
         Returns:
             Checkpoint ID for later restoration
         """
-        checkpoint_id = str(uuid.uuid4())[:8]
+        # Use full UUID to avoid collision risk (8 chars has ~50k collision at birthday paradox)
+        checkpoint_id = str(uuid.uuid4())
         checkpoint_key = f"{_CHECKPOINT_PREFIX}{checkpoint_id}"
         # Capture all current state
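
To put numbers on the comment above: the first 8 hex characters of a UUID carry 32 random bits, so with the birthday approximation p ≈ 1 - exp(-n^2 / (2 * 2^bits)) the collision probability is already about 25% at 50k checkpoints and passes 50% around 80k, while a full UUID4 (122 random bits) stays negligible at that scale. A back-of-the-envelope check (approximation only; the helper name is illustrative):

import math

def collision_probability(n: int, bits: int) -> float:
    # Birthday approximation: p ~= 1 - exp(-n^2 / (2 * 2^bits)).
    # expm1 keeps precision for the tiny full-UUID probabilities.
    return -math.expm1(-(n * n) / (2.0 * 2.0 ** bits))

print(collision_probability(50_000, 32))   # 8 hex chars -> ~0.25
print(collision_probability(80_000, 32))   # 8 hex chars -> ~0.53
print(collision_probability(50_000, 122))  # full uuid4  -> ~2e-28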


@@ -738,26 +738,32 @@ class TestComprehensiveReflection:
assert "Episodes analyzed" in summary
class TestSingleton:
"""Tests for singleton pattern."""
class TestFactoryFunction:
"""Tests for factory function behavior.
async def test_get_memory_reflection_returns_singleton(
Note: The singleton pattern was removed to avoid stale database session bugs.
Each call now creates a fresh instance, which is safer for request-scoped usage.
"""
async def test_get_memory_reflection_creates_new_instance(
self,
mock_session: MagicMock,
) -> None:
"""Should return same instance."""
"""Should create new instance each call (no singleton for session safety)."""
r1 = await get_memory_reflection(mock_session)
r2 = await get_memory_reflection(mock_session)
assert r1 is r2
async def test_reset_creates_new_instance(
self,
mock_session: MagicMock,
) -> None:
"""Should create new instance after reset."""
r1 = await get_memory_reflection(mock_session)
await reset_memory_reflection()
r2 = await get_memory_reflection(mock_session)
# Different instances to avoid stale session issues
assert r1 is not r2
async def test_reset_is_no_op(
self,
mock_session: MagicMock,
) -> None:
"""Reset should be a no-op (kept for API compatibility)."""
r1 = await get_memory_reflection(mock_session)
await reset_memory_reflection() # Should not raise
r2 = await get_memory_reflection(mock_session)
# Still creates new instances (reset is no-op now)
assert r1 is not r2


@@ -276,7 +276,7 @@ class TestWorkingMemoryCheckpoints:
         checkpoint_id = await memory.create_checkpoint("Test checkpoint")
         assert checkpoint_id is not None
-        assert len(checkpoint_id) == 8  # UUID prefix
+        assert len(checkpoint_id) == 36  # Full UUID for collision safety
     @pytest.mark.asyncio
     async def test_restore_checkpoint(self, memory: WorkingMemory) -> None: