# app/services/memory/working/memory.py
"""
Working Memory Implementation.

Provides session-scoped ephemeral memory with:
- Key-value storage with TTL
- Task state tracking
- Scratchpad for reasoning steps
- Checkpoint/snapshot support
"""

import logging
import uuid
from dataclasses import asdict
from datetime import UTC, datetime
from typing import Any

from app.services.memory.config import get_memory_settings
from app.services.memory.exceptions import (
    MemoryConnectionError,
    MemoryNotFoundError,
)
from app.services.memory.types import ScopeContext, ScopeLevel, TaskState

from .storage import InMemoryStorage, RedisStorage, WorkingMemoryStorage

logger = logging.getLogger(__name__)


# Reserved key prefixes for internal use
_TASK_STATE_KEY = "_task_state"
_SCRATCHPAD_KEY = "_scratchpad"
_CHECKPOINT_PREFIX = "_checkpoint:"
_METADATA_KEY = "_metadata"


class WorkingMemory:
    """
    Session-scoped working memory.

    Provides ephemeral storage for an agent's current task context:
    - Variables and intermediate data
    - Task state (current step, status, progress)
    - Scratchpad for reasoning steps
    - Checkpoints for recovery

    Uses Redis as primary storage with in-memory fallback.
    """

    def __init__(
        self,
        scope: ScopeContext,
        storage: WorkingMemoryStorage,
        default_ttl_seconds: int | None = None,
    ) -> None:
        """
        Initialize working memory for a scope.

        Args:
            scope: The scope context (session, agent instance, etc.)
            storage: Storage backend (use create() factory for auto-configuration)
            default_ttl_seconds: Default TTL for keys (None = no expiration)
        """
        self._scope = scope
        self._storage: WorkingMemoryStorage = storage
        self._default_ttl = default_ttl_seconds
        self._using_fallback = False
        self._initialized = False

    @classmethod
    async def create(
        cls,
        scope: ScopeContext,
        default_ttl_seconds: int | None = None,
    ) -> "WorkingMemory":
        """
        Factory method to create WorkingMemory with auto-configured storage.

        Attempts Redis first, falls back to in-memory if unavailable.
        """
        settings = get_memory_settings()
        key_prefix = f"wm:{scope.to_key_prefix()}:"
        storage: WorkingMemoryStorage

        # Try Redis first
        if settings.working_memory_backend == "redis":
            redis_storage = RedisStorage(key_prefix=key_prefix)
            try:
                if await redis_storage.is_healthy():
                    logger.debug(f"Using Redis storage for scope {scope.scope_id}")
                    instance = cls(
                        scope=scope,
                        storage=redis_storage,
                        default_ttl_seconds=default_ttl_seconds
                        or settings.working_memory_default_ttl_seconds,
                    )
                    await instance._initialize()
                    return instance
                # Health check failed without raising: release the unused Redis client
                await redis_storage.close()
            except MemoryConnectionError:
                logger.warning("Redis unavailable, falling back to in-memory storage")
                await redis_storage.close()

        # Fall back to in-memory
        storage = InMemoryStorage(
            max_keys=settings.working_memory_max_items_per_session
        )
        instance = cls(
            scope=scope,
            storage=storage,
            default_ttl_seconds=default_ttl_seconds
            or settings.working_memory_default_ttl_seconds,
        )
        instance._using_fallback = True
        await instance._initialize()
        return instance

    @classmethod
    async def for_session(
        cls,
        session_id: str,
        project_id: str | None = None,
        agent_instance_id: str | None = None,
    ) -> "WorkingMemory":
        """
        Convenience factory for session-scoped working memory.

        Args:
            session_id: Unique session identifier
            project_id: Optional project context
            agent_instance_id: Optional agent instance context
        """
        # Build scope hierarchy
        parent = None
        if project_id:
            parent = ScopeContext(
                scope_type=ScopeLevel.PROJECT,
                scope_id=project_id,
            )
        if agent_instance_id:
            parent = ScopeContext(
                scope_type=ScopeLevel.AGENT_INSTANCE,
                scope_id=agent_instance_id,
                parent=parent,
            )

        scope = ScopeContext(
            scope_type=ScopeLevel.SESSION,
            scope_id=session_id,
            parent=parent,
        )

        return await cls.create(scope=scope)

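    # ------------------------------------------------------------------
    # Usage sketch (illustrative only, not executed): creating a
    # session-scoped working memory and using the basic key-value API.
    # The session/project IDs below are made-up placeholders.
    #
    #     wm = await WorkingMemory.for_session(
    #         session_id="session-123",
    #         project_id="project-abc",
    #     )
    #     await wm.set("draft_plan", {"steps": 3}, ttl_seconds=600)
    #     plan = await wm.get("draft_plan", default={})
    #     await wm.delete("draft_plan")
    #     await wm.close()
    # ------------------------------------------------------------------
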
    async def _initialize(self) -> None:
        """Initialize working memory metadata."""
        if self._initialized:
            return

        metadata = {
            "scope_type": self._scope.scope_type.value,
            "scope_id": self._scope.scope_id,
            "created_at": datetime.now(UTC).isoformat(),
            "using_fallback": self._using_fallback,
        }
        await self._storage.set(_METADATA_KEY, metadata)
        self._initialized = True

    @property
    def scope(self) -> ScopeContext:
        """Get the scope context."""
        return self._scope

    @property
    def is_using_fallback(self) -> bool:
        """Check if using fallback in-memory storage."""
        return self._using_fallback

    # =========================================================================
    # Basic Key-Value Operations
    # =========================================================================

    async def set(
        self,
        key: str,
        value: Any,
        ttl_seconds: int | None = None,
    ) -> None:
        """
        Store a value.

        Args:
            key: The key to store under
            value: The value to store (must be JSON-serializable)
            ttl_seconds: Optional TTL (uses default if not specified)
        """
        if key.startswith("_"):
            raise ValueError("Keys starting with '_' are reserved for internal use")

        ttl = ttl_seconds if ttl_seconds is not None else self._default_ttl
        await self._storage.set(key, value, ttl)

    async def get(self, key: str, default: Any = None) -> Any:
        """
        Get a value.

        Args:
            key: The key to retrieve
            default: Default value if key not found

        Returns:
            The stored value, or the default if the key is missing
            (a stored value of None is also treated as missing)
        """
        result = await self._storage.get(key)
        return result if result is not None else default

    async def delete(self, key: str) -> bool:
        """
        Delete a key.

        Args:
            key: The key to delete

        Returns:
            True if the key existed
        """
        if key.startswith("_"):
            raise ValueError("Cannot delete internal keys directly")
        return await self._storage.delete(key)

    async def exists(self, key: str) -> bool:
        """
        Check if a key exists.

        Args:
            key: The key to check

        Returns:
            True if the key exists
        """
        return await self._storage.exists(key)

    async def list_keys(self, pattern: str = "*") -> list[str]:
        """
        List keys matching a pattern.

        Args:
            pattern: Glob-style pattern (default "*" for all)

        Returns:
            List of matching keys (excludes internal keys)
        """
        all_keys = await self._storage.list_keys(pattern)
        return [k for k in all_keys if not k.startswith("_")]

    async def get_all(self) -> dict[str, Any]:
        """
        Get all user key-value pairs.

        Returns:
            Dictionary of all key-value pairs (excludes internal keys)
        """
        all_data = await self._storage.get_all()
        return {k: v for k, v in all_data.items() if not k.startswith("_")}

    async def clear(self) -> int:
        """
        Clear all user keys (preserves task state, scratchpad, and metadata;
        checkpoints are not preserved).

        Returns:
            Number of keys deleted
        """
        # Save internal state
        task_state = await self._storage.get(_TASK_STATE_KEY)
        scratchpad = await self._storage.get(_SCRATCHPAD_KEY)
        metadata = await self._storage.get(_METADATA_KEY)

        count = await self._storage.clear()

        # Restore internal state
        if metadata is not None:
            await self._storage.set(_METADATA_KEY, metadata)
        if task_state is not None:
            await self._storage.set(_TASK_STATE_KEY, task_state)
        if scratchpad is not None:
            await self._storage.set(_SCRATCHPAD_KEY, scratchpad)

        # Adjust count for preserved keys
        preserved = sum(1 for x in [task_state, scratchpad, metadata] if x is not None)
        return max(0, count - preserved)

    # =========================================================================
    # Task State Operations
    # =========================================================================

    async def set_task_state(self, state: TaskState) -> None:
        """
        Set the current task state.

        Args:
            state: The task state to store
        """
        state.updated_at = datetime.now(UTC)
        await self._storage.set(_TASK_STATE_KEY, asdict(state))

    async def get_task_state(self) -> TaskState | None:
        """
        Get the current task state.

        Returns:
            The current TaskState or None if not set
        """
        data = await self._storage.get(_TASK_STATE_KEY)
        if data is None:
            return None

        # Convert datetime strings back to datetime objects
        if isinstance(data.get("started_at"), str):
            data["started_at"] = datetime.fromisoformat(data["started_at"])
        if isinstance(data.get("updated_at"), str):
            data["updated_at"] = datetime.fromisoformat(data["updated_at"])

        return TaskState(**data)

    async def update_task_progress(
        self,
        current_step: int | None = None,
        progress_percent: float | None = None,
        status: str | None = None,
    ) -> TaskState | None:
        """
        Update task progress fields.

        Args:
            current_step: New current step number
            progress_percent: New progress percentage (0.0 to 100.0)
            status: New status string

        Returns:
            Updated TaskState or None if no task state exists
        """
        state = await self.get_task_state()
        if state is None:
            return None

        if current_step is not None:
            state.current_step = current_step
        if progress_percent is not None:
            state.progress_percent = min(100.0, max(0.0, progress_percent))
        if status is not None:
            state.status = status

        await self.set_task_state(state)
        return state

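    # ------------------------------------------------------------------
    # Usage sketch (illustrative only, not executed): tracking task
    # progress. The TaskState constructor arguments shown here are
    # assumptions; the real dataclass fields live in
    # app.services.memory.types.
    #
    #     state = TaskState(status="running", current_step=1, progress_percent=0.0)
    #     await wm.set_task_state(state)
    #     await wm.update_task_progress(current_step=2, progress_percent=50.0)
    #     current = await wm.get_task_state()
    # ------------------------------------------------------------------
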
    # =========================================================================
    # Scratchpad Operations
    # =========================================================================

    async def append_scratchpad(self, content: str) -> None:
        """
        Append content to the scratchpad.

        Args:
            content: Text to append
        """
        settings = get_memory_settings()
        entries = await self._storage.get(_SCRATCHPAD_KEY) or []

        # Check capacity
        if len(entries) >= settings.working_memory_max_items_per_session:
            # Remove oldest entries
            entries = entries[-(settings.working_memory_max_items_per_session - 1) :]

        entry = {
            "content": content,
            "timestamp": datetime.now(UTC).isoformat(),
        }
        entries.append(entry)
        await self._storage.set(_SCRATCHPAD_KEY, entries)

    async def get_scratchpad(self) -> list[str]:
        """
        Get all scratchpad entries.

        Returns:
            List of scratchpad content strings (ordered by time)
        """
        entries = await self._storage.get(_SCRATCHPAD_KEY) or []
        return [e["content"] for e in entries]

    async def get_scratchpad_with_timestamps(self) -> list[dict[str, Any]]:
        """
        Get all scratchpad entries with timestamps.

        Returns:
            List of dicts with 'content' and 'timestamp' keys
        """
        return await self._storage.get(_SCRATCHPAD_KEY) or []

    async def clear_scratchpad(self) -> int:
        """
        Clear the scratchpad.

        Returns:
            Number of entries cleared
        """
        entries = await self._storage.get(_SCRATCHPAD_KEY) or []
        count = len(entries)
        await self._storage.set(_SCRATCHPAD_KEY, [])
        return count

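    # ------------------------------------------------------------------
    # Usage sketch (illustrative only, not executed): recording and
    # reading reasoning steps on the scratchpad.
    #
    #     await wm.append_scratchpad("Parsed the user request")
    #     await wm.append_scratchpad("Chose strategy A over B")
    #     steps = await wm.get_scratchpad()
    #     timed = await wm.get_scratchpad_with_timestamps()
    #     await wm.clear_scratchpad()
    # ------------------------------------------------------------------
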
    # =========================================================================
    # Checkpoint Operations
    # =========================================================================

    async def create_checkpoint(self, description: str = "") -> str:
        """
        Create a checkpoint of current state.

        Args:
            description: Optional description of the checkpoint

        Returns:
            Checkpoint ID for later restoration
        """
        # Use a full UUID for the checkpoint ID: an 8-char prefix carries a real
        # collision risk at ~50k checkpoints (birthday paradox)
        checkpoint_id = str(uuid.uuid4())
        checkpoint_key = f"{_CHECKPOINT_PREFIX}{checkpoint_id}"

        # Capture all current state
        all_data = await self._storage.get_all()

        checkpoint = {
            "id": checkpoint_id,
            "description": description,
            "created_at": datetime.now(UTC).isoformat(),
            "data": all_data,
        }

        await self._storage.set(checkpoint_key, checkpoint)
        logger.debug(f"Created checkpoint {checkpoint_id}")
        return checkpoint_id

    async def restore_checkpoint(self, checkpoint_id: str) -> None:
        """
        Restore state from a checkpoint.

        Args:
            checkpoint_id: ID of the checkpoint to restore

        Raises:
            MemoryNotFoundError: If checkpoint not found
        """
        checkpoint_key = f"{_CHECKPOINT_PREFIX}{checkpoint_id}"
        checkpoint = await self._storage.get(checkpoint_key)

        if checkpoint is None:
            raise MemoryNotFoundError(f"Checkpoint {checkpoint_id} not found")

        # Clear current state
        await self._storage.clear()

        # Restore all data from checkpoint
        for key, value in checkpoint["data"].items():
            await self._storage.set(key, value)

        # Keep the checkpoint itself
        await self._storage.set(checkpoint_key, checkpoint)

        logger.debug(f"Restored checkpoint {checkpoint_id}")

    async def list_checkpoints(self) -> list[dict[str, Any]]:
        """
        List all available checkpoints.

        Returns:
            List of checkpoint metadata (id, description, created_at)
        """
        checkpoint_keys = await self._storage.list_keys(f"{_CHECKPOINT_PREFIX}*")
        checkpoints = []

        for key in checkpoint_keys:
            data = await self._storage.get(key)
            if data:
                checkpoints.append(
                    {
                        "id": data["id"],
                        "description": data["description"],
                        "created_at": data["created_at"],
                    }
                )

        # Sort by creation time
        checkpoints.sort(key=lambda x: x["created_at"])
        return checkpoints

    async def delete_checkpoint(self, checkpoint_id: str) -> bool:
        """
        Delete a checkpoint.

        Args:
            checkpoint_id: ID of the checkpoint to delete

        Returns:
            True if checkpoint existed
        """
        checkpoint_key = f"{_CHECKPOINT_PREFIX}{checkpoint_id}"
        return await self._storage.delete(checkpoint_key)

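    # ------------------------------------------------------------------
    # Usage sketch (illustrative only, not executed): snapshotting state
    # before a risky operation and rolling back on failure.
    # "revised_plan" is a made-up placeholder value.
    #
    #     cp_id = await wm.create_checkpoint("before refactor step")
    #     try:
    #         await wm.set("plan", revised_plan)   # risky update
    #     except Exception:
    #         await wm.restore_checkpoint(cp_id)   # roll back to the snapshot
    #     finally:
    #         await wm.delete_checkpoint(cp_id)
    # ------------------------------------------------------------------
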
    # =========================================================================
    # Health and Lifecycle
    # =========================================================================

    async def is_healthy(self) -> bool:
        """Check if the working memory storage is healthy."""
        return await self._storage.is_healthy()

    async def close(self) -> None:
        """Close the working memory storage."""
        if self._storage:
            await self._storage.close()

    async def get_stats(self) -> dict[str, Any]:
        """
        Get working memory statistics.

        Returns:
            Dictionary with stats about current state
        """
        all_keys = await self._storage.list_keys("*")
        user_keys = [k for k in all_keys if not k.startswith("_")]
        checkpoint_keys = [k for k in all_keys if k.startswith(_CHECKPOINT_PREFIX)]
        scratchpad = await self._storage.get(_SCRATCHPAD_KEY) or []

        return {
            "scope_type": self._scope.scope_type.value,
            "scope_id": self._scope.scope_id,
            "using_fallback": self._using_fallback,
            "total_keys": len(all_keys),
            "user_keys": len(user_keys),
            "checkpoint_count": len(checkpoint_keys),
            "scratchpad_entries": len(scratchpad),
            "has_task_state": await self._storage.exists(_TASK_STATE_KEY),
        }
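

# ----------------------------------------------------------------------
# Minimal end-to-end sketch (assumes the surrounding app package is
# importable and that Redis is either reachable or the in-memory
# fallback is acceptable). Illustrative only; not part of the service.
if __name__ == "__main__":
    import asyncio

    async def _demo() -> None:
        # Create session-scoped memory, store a value with a TTL, read it back,
        # print storage statistics, then release the backend connection.
        wm = await WorkingMemory.for_session(session_id="demo-session")
        await wm.set("greeting", "hello", ttl_seconds=60)
        print(await wm.get("greeting"))
        print(await wm.get_stats())
        await wm.close()

    asyncio.run(_demo())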