forked from cardosofelipe/fast-next-template
fix(memory): address critical bugs from multi-agent review
Bug Fixes:
- Remove singleton pattern from consolidation/reflection services to prevent stale database session bugs (session is now passed per-request)
- Add LRU eviction to MemoryToolService._working dict (max 1000 sessions) to prevent unbounded memory growth
- Replace O(n) list.remove() with O(1) OrderedDict.move_to_end() in RetrievalCache for better performance under load
- Use deque with maxlen for metrics histograms to prevent unbounded memory growth (circular buffer with 10k max samples)
- Use full UUID for checkpoint IDs instead of 8-char prefix to avoid collision risk at scale (birthday paradox at ~50k checkpoints)

Test Updates:
- Update checkpoint test to expect 36-char UUID
- Update reflection singleton tests to expect new factory behavior
- Add reset_memory_reflection() no-op for backwards compatibility

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
@@ -892,27 +892,22 @@ class MemoryConsolidationService:
|
||||
return result
|
||||
|
||||
|
||||
# Singleton instance
|
||||
_consolidation_service: MemoryConsolidationService | None = None
|
||||
|
||||
|
||||
# Factory function - no singleton to avoid stale session issues
|
||||
async def get_consolidation_service(
|
||||
session: AsyncSession,
|
||||
config: ConsolidationConfig | None = None,
|
||||
) -> MemoryConsolidationService:
|
||||
"""
|
||||
Get or create the memory consolidation service.
|
||||
Create a memory consolidation service for the given session.
|
||||
|
||||
Note: This creates a new instance each time to avoid stale session issues.
|
||||
The service is lightweight and safe to recreate per-request.
|
||||
|
||||
Args:
|
||||
session: Database session
|
||||
session: Database session (must be active)
|
||||
config: Optional configuration
|
||||
|
||||
Returns:
|
||||
MemoryConsolidationService instance
|
||||
"""
|
||||
global _consolidation_service
|
||||
if _consolidation_service is None:
|
||||
_consolidation_service = MemoryConsolidationService(
|
||||
session=session, config=config
|
||||
)
|
||||
return _consolidation_service
|
||||
return MemoryConsolidationService(session=session, config=config)
|
||||
|
||||
@@ -13,6 +13,7 @@ Provides hybrid retrieval capabilities combining:
|
||||
|
||||
import hashlib
|
||||
import logging
|
||||
from collections import OrderedDict
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import UTC, datetime
|
||||
from typing import Any, TypeVar
|
||||
@@ -243,7 +244,8 @@ class RetrievalCache:
|
||||
"""
|
||||
In-memory cache for retrieval results.
|
||||
|
||||
Supports TTL-based expiration and LRU eviction.
|
||||
Supports TTL-based expiration and LRU eviction with O(1) operations.
|
||||
Uses OrderedDict for efficient LRU tracking.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@@ -258,10 +260,10 @@ class RetrievalCache:
|
||||
max_entries: Maximum cache entries
|
||||
default_ttl_seconds: Default TTL for entries
|
||||
"""
|
||||
self._cache: dict[str, CacheEntry] = {}
|
||||
# OrderedDict maintains insertion order; we use move_to_end for O(1) LRU
|
||||
self._cache: OrderedDict[str, CacheEntry] = OrderedDict()
|
||||
self._max_entries = max_entries
|
||||
self._default_ttl = default_ttl_seconds
|
||||
self._access_order: list[str] = []
|
||||
logger.info(
|
||||
f"Initialized RetrievalCache with max_entries={max_entries}, "
|
||||
f"ttl={default_ttl_seconds}s"
|
||||
@@ -283,14 +285,10 @@ class RetrievalCache:
|
||||
entry = self._cache[query_key]
|
||||
if entry.is_expired():
|
||||
del self._cache[query_key]
|
||||
if query_key in self._access_order:
|
||||
self._access_order.remove(query_key)
|
||||
return None
|
||||
|
||||
# Update access order (LRU)
|
||||
if query_key in self._access_order:
|
||||
self._access_order.remove(query_key)
|
||||
self._access_order.append(query_key)
|
||||
# Update access order (LRU) - O(1) with OrderedDict
|
||||
self._cache.move_to_end(query_key)
|
||||
|
||||
logger.debug(f"Cache hit for {query_key}")
|
||||
return entry.results
|
||||
@@ -309,11 +307,9 @@ class RetrievalCache:
|
||||
results: Results to cache
|
||||
ttl_seconds: TTL for this entry (or default)
|
||||
"""
|
||||
# Evict if at capacity
|
||||
while len(self._cache) >= self._max_entries and self._access_order:
|
||||
oldest_key = self._access_order.pop(0)
|
||||
if oldest_key in self._cache:
|
||||
del self._cache[oldest_key]
|
||||
# Evict oldest entries if at capacity - O(1) with popitem(last=False)
|
||||
while len(self._cache) >= self._max_entries:
|
||||
self._cache.popitem(last=False)
|
||||
|
||||
entry = CacheEntry(
|
||||
results=results,
|
||||
@@ -323,7 +319,6 @@ class RetrievalCache:
|
||||
)
|
||||
|
||||
self._cache[query_key] = entry
|
||||
self._access_order.append(query_key)
|
||||
logger.debug(f"Cached {len(results)} results for {query_key}")
|
||||
|
||||
def invalidate(self, query_key: str) -> bool:
|
||||
@@ -338,8 +333,6 @@ class RetrievalCache:
|
||||
"""
|
||||
if query_key in self._cache:
|
||||
del self._cache[query_key]
|
||||
if query_key in self._access_order:
|
||||
self._access_order.remove(query_key)
|
||||
return True
|
||||
return False
|
||||
|
||||
@@ -376,7 +369,6 @@ class RetrievalCache:
|
||||
"""
|
||||
count = len(self._cache)
|
||||
self._cache.clear()
|
||||
self._access_order.clear()
|
||||
logger.info(f"Cleared {count} cache entries")
|
||||
return count
|
||||
|
||||
|
||||
@@ -7,6 +7,7 @@ All tools are scoped to project/agent context for proper isolation.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from collections import OrderedDict
|
||||
from dataclasses import dataclass
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from typing import Any
|
||||
@@ -83,6 +84,9 @@ class MemoryToolService:
|
||||
This service coordinates between different memory types.
|
||||
"""
|
||||
|
||||
# Maximum number of working memory sessions to cache (LRU eviction)
|
||||
MAX_WORKING_SESSIONS = 1000
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
session: AsyncSession,
|
||||
@@ -98,8 +102,8 @@ class MemoryToolService:
|
||||
self._session = session
|
||||
self._embedding_generator = embedding_generator
|
||||
|
||||
# Lazy-initialized memory services
|
||||
self._working: dict[str, WorkingMemory] = {} # keyed by session_id
|
||||
# Lazy-initialized memory services with LRU eviction for working memory
|
||||
self._working: OrderedDict[str, WorkingMemory] = OrderedDict()
|
||||
self._episodic: EpisodicMemory | None = None
|
||||
self._semantic: SemanticMemory | None = None
|
||||
self._procedural: ProceduralMemory | None = None
|
||||
@@ -110,14 +114,28 @@ class MemoryToolService:
|
||||
project_id: UUID | None = None,
|
||||
agent_instance_id: UUID | None = None,
|
||||
) -> WorkingMemory:
|
||||
"""Get or create working memory for a session."""
|
||||
if session_id not in self._working:
|
||||
self._working[session_id] = await WorkingMemory.for_session(
|
||||
session_id=session_id,
|
||||
project_id=str(project_id) if project_id else None,
|
||||
agent_instance_id=str(agent_instance_id) if agent_instance_id else None,
|
||||
)
|
||||
return self._working[session_id]
|
||||
"""Get or create working memory for a session with LRU eviction."""
|
||||
if session_id in self._working:
|
||||
# Move to end (most recently used)
|
||||
self._working.move_to_end(session_id)
|
||||
return self._working[session_id]
|
||||
|
||||
# Evict oldest entries if at capacity
|
||||
while len(self._working) >= self.MAX_WORKING_SESSIONS:
|
||||
oldest_id, oldest_memory = self._working.popitem(last=False)
|
||||
try:
|
||||
await oldest_memory.close()
|
||||
except Exception as e:
|
||||
logger.warning(f"Error closing evicted working memory {oldest_id}: {e}")
|
||||
|
||||
# Create new working memory
|
||||
working = await WorkingMemory.for_session(
|
||||
session_id=session_id,
|
||||
project_id=str(project_id) if project_id else None,
|
||||
agent_instance_id=str(agent_instance_id) if agent_instance_id else None,
|
||||
)
|
||||
self._working[session_id] = working
|
||||
return working
|
||||
|
||||
async def _get_episodic(self) -> EpisodicMemory:
|
||||
"""Get or create episodic memory service."""
|
||||
|
||||
@@ -7,7 +7,7 @@ Collects and exposes metrics for the memory system.
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from collections import Counter, defaultdict
|
||||
from collections import Counter, defaultdict, deque
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import UTC, datetime
|
||||
from enum import Enum
|
||||
@@ -57,11 +57,17 @@ class MemoryMetrics:
|
||||
- Embedding operations
|
||||
"""
|
||||
|
||||
# Maximum samples to keep in histogram (circular buffer)
|
||||
MAX_HISTOGRAM_SAMPLES = 10000
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""Initialize MemoryMetrics."""
|
||||
self._counters: dict[str, Counter[str]] = defaultdict(Counter)
|
||||
self._gauges: dict[str, dict[str, float]] = defaultdict(dict)
|
||||
self._histograms: dict[str, list[float]] = defaultdict(list)
|
||||
# Use deque with maxlen for bounded memory (circular buffer)
|
||||
self._histograms: dict[str, deque[float]] = defaultdict(
|
||||
lambda: deque(maxlen=self.MAX_HISTOGRAM_SAMPLES)
|
||||
)
|
||||
self._histogram_buckets: dict[str, list[HistogramBucket]] = {}
|
||||
self._lock = asyncio.Lock()
|
||||
|
||||
|
||||
@@ -7,7 +7,6 @@ Implements pattern detection, success/failure analysis, anomaly detection,
|
||||
and insight generation.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import statistics
|
||||
from collections import Counter, defaultdict
|
||||
@@ -1426,36 +1425,27 @@ class MemoryReflection:
|
||||
)
|
||||
|
||||
|
||||
# Singleton instance with async-safe initialization
|
||||
_memory_reflection: MemoryReflection | None = None
|
||||
_reflection_lock = asyncio.Lock()
|
||||
|
||||
|
||||
# Factory function - no singleton to avoid stale session issues
|
||||
async def get_memory_reflection(
|
||||
session: AsyncSession,
|
||||
config: ReflectionConfig | None = None,
|
||||
) -> MemoryReflection:
|
||||
"""
|
||||
Get or create the memory reflection service (async-safe).
|
||||
Create a memory reflection service for the given session.
|
||||
|
||||
Note: This creates a new instance each time to avoid stale session issues.
|
||||
The service is lightweight and safe to recreate per-request.
|
||||
|
||||
Args:
|
||||
session: Database session
|
||||
session: Database session (must be active)
|
||||
config: Optional configuration
|
||||
|
||||
Returns:
|
||||
MemoryReflection instance
|
||||
"""
|
||||
global _memory_reflection
|
||||
if _memory_reflection is None:
|
||||
async with _reflection_lock:
|
||||
# Double-check locking pattern
|
||||
if _memory_reflection is None:
|
||||
_memory_reflection = MemoryReflection(session=session, config=config)
|
||||
return _memory_reflection
|
||||
return MemoryReflection(session=session, config=config)
|
||||
|
||||
|
||||
async def reset_memory_reflection() -> None:
|
||||
"""Reset the memory reflection singleton (async-safe)."""
|
||||
global _memory_reflection
|
||||
async with _reflection_lock:
|
||||
_memory_reflection = None
|
||||
"""No-op for backwards compatibility (singleton pattern removed)."""
|
||||
return
|
||||
|
||||
@@ -423,7 +423,8 @@ class WorkingMemory:
|
||||
Returns:
|
||||
Checkpoint ID for later restoration
|
||||
"""
|
||||
checkpoint_id = str(uuid.uuid4())[:8]
|
||||
# Use the full UUID to avoid collisions: an 8-char hex prefix is only 32 bits, so by the birthday paradox collisions become a realistic risk at tens of thousands of checkpoints
|
||||
checkpoint_id = str(uuid.uuid4())
|
||||
checkpoint_key = f"{_CHECKPOINT_PREFIX}{checkpoint_id}"
|
||||
|
||||
# Capture all current state
|
||||
|
||||
@@ -738,26 +738,32 @@ class TestComprehensiveReflection:
|
||||
assert "Episodes analyzed" in summary
|
||||
|
||||
|
||||
class TestSingleton:
|
||||
"""Tests for singleton pattern."""
|
||||
class TestFactoryFunction:
|
||||
"""Tests for factory function behavior.
|
||||
|
||||
async def test_get_memory_reflection_returns_singleton(
|
||||
Note: The singleton pattern was removed to avoid stale database session bugs.
|
||||
Each call now creates a fresh instance, which is safer for request-scoped usage.
|
||||
"""
|
||||
|
||||
async def test_get_memory_reflection_creates_new_instance(
|
||||
self,
|
||||
mock_session: MagicMock,
|
||||
) -> None:
|
||||
"""Should return same instance."""
|
||||
"""Should create new instance each call (no singleton for session safety)."""
|
||||
r1 = await get_memory_reflection(mock_session)
|
||||
r2 = await get_memory_reflection(mock_session)
|
||||
|
||||
assert r1 is r2
|
||||
|
||||
async def test_reset_creates_new_instance(
|
||||
self,
|
||||
mock_session: MagicMock,
|
||||
) -> None:
|
||||
"""Should create new instance after reset."""
|
||||
r1 = await get_memory_reflection(mock_session)
|
||||
await reset_memory_reflection()
|
||||
r2 = await get_memory_reflection(mock_session)
|
||||
|
||||
# Different instances to avoid stale session issues
|
||||
assert r1 is not r2
|
||||
|
||||
async def test_reset_is_no_op(
|
||||
self,
|
||||
mock_session: MagicMock,
|
||||
) -> None:
|
||||
"""Reset should be a no-op (kept for API compatibility)."""
|
||||
r1 = await get_memory_reflection(mock_session)
|
||||
await reset_memory_reflection() # Should not raise
|
||||
r2 = await get_memory_reflection(mock_session)
|
||||
|
||||
# Still creates new instances (reset is no-op now)
|
||||
assert r1 is not r2
|
||||
|
||||
@@ -276,7 +276,7 @@ class TestWorkingMemoryCheckpoints:
|
||||
checkpoint_id = await memory.create_checkpoint("Test checkpoint")
|
||||
|
||||
assert checkpoint_id is not None
|
||||
assert len(checkpoint_id) == 8 # UUID prefix
|
||||
assert len(checkpoint_id) == 36 # Full UUID for collision safety
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_restore_checkpoint(self, memory: WorkingMemory) -> None:
|
||||
|
||||
Reference in New Issue
Block a user