# app/services/memory/working/memory.py
"""
Working Memory Implementation.

Provides session-scoped ephemeral memory with:
- Key-value storage with TTL
- Task state tracking
- Scratchpad for reasoning steps
- Checkpoint/snapshot support
"""

import logging
import uuid
from dataclasses import asdict
from datetime import UTC, datetime
from typing import Any

from app.services.memory.config import get_memory_settings
from app.services.memory.exceptions import (
    MemoryConnectionError,
    MemoryNotFoundError,
)
from app.services.memory.types import ScopeContext, ScopeLevel, TaskState

from .storage import InMemoryStorage, RedisStorage, WorkingMemoryStorage

logger = logging.getLogger(__name__)

# Reserved key prefixes for internal use. Every key starting with "_" is
# internal: user-facing set/delete reject it and list_keys/get_all hide it.
_TASK_STATE_KEY = "_task_state"
_SCRATCHPAD_KEY = "_scratchpad"
_CHECKPOINT_PREFIX = "_checkpoint:"
_METADATA_KEY = "_metadata"


class WorkingMemory:
    """
    Session-scoped working memory.

    Provides ephemeral storage for agent's current task context:
    - Variables and intermediate data
    - Task state (current step, status, progress)
    - Scratchpad for reasoning steps
    - Checkpoints for recovery

    Uses Redis as primary storage with in-memory fallback.
    """

    def __init__(
        self,
        scope: ScopeContext,
        storage: WorkingMemoryStorage,
        default_ttl_seconds: int | None = None,
    ) -> None:
        """
        Initialize working memory for a scope.

        Args:
            scope: The scope context (session, agent instance, etc.)
            storage: Storage backend (use create() factory for auto-configuration)
            default_ttl_seconds: Default TTL for keys (None = no expiration)
        """
        self._scope = scope
        self._storage: WorkingMemoryStorage = storage
        self._default_ttl = default_ttl_seconds
        self._using_fallback = False
        self._initialized = False

    @classmethod
    async def create(
        cls,
        scope: ScopeContext,
        default_ttl_seconds: int | None = None,
    ) -> "WorkingMemory":
        """
        Factory method to create WorkingMemory with auto-configured storage.

        When the configured backend is "redis", Redis is tried first. The
        factory falls back to in-memory storage if Redis raises
        MemoryConnectionError or reports itself unhealthy. Any other backend
        setting goes straight to in-memory storage.

        Args:
            scope: The scope context the memory belongs to
            default_ttl_seconds: Default TTL for keys; when falsy, the
                configured working_memory_default_ttl_seconds is used

        Returns:
            An initialized WorkingMemory instance
        """
        settings = get_memory_settings()
        key_prefix = f"wm:{scope.to_key_prefix()}:"
        # `or` (rather than an `is None` check) is intentional and matches the
        # previous behavior: a TTL of 0 also falls back to the configured one.
        ttl = default_ttl_seconds or settings.working_memory_default_ttl_seconds

        # Try Redis first
        if settings.working_memory_backend == "redis":
            redis_storage = RedisStorage(key_prefix=key_prefix)
            try:
                if await redis_storage.is_healthy():
                    logger.debug("Using Redis storage for scope %s", scope.scope_id)
                    instance = cls(
                        scope=scope,
                        storage=redis_storage,
                        default_ttl_seconds=ttl,
                    )
                    await instance._initialize()
                    return instance
                logger.warning("Redis unhealthy, falling back to in-memory storage")
            except MemoryConnectionError:
                logger.warning("Redis unavailable, falling back to in-memory storage")
            # Release the unused Redis client on every fallback path, not only
            # when a connection error was raised.
            await redis_storage.close()

        # Fall back to in-memory
        storage: WorkingMemoryStorage = InMemoryStorage(
            max_keys=settings.working_memory_max_items_per_session
        )
        instance = cls(
            scope=scope,
            storage=storage,
            default_ttl_seconds=ttl,
        )
        instance._using_fallback = True
        await instance._initialize()
        return instance

    @classmethod
    async def for_session(
        cls,
        session_id: str,
        project_id: str | None = None,
        agent_instance_id: str | None = None,
    ) -> "WorkingMemory":
        """
        Convenience factory for session-scoped working memory.

        Builds the scope chain project -> agent instance -> session (skipping
        the levels that were not supplied) and delegates to create().

        Args:
            session_id: Unique session identifier
            project_id: Optional project context
            agent_instance_id: Optional agent instance context

        Returns:
            An initialized WorkingMemory instance for the session scope
        """
        # Build scope hierarchy
        parent = None
        if project_id:
            parent = ScopeContext(
                scope_type=ScopeLevel.PROJECT,
                scope_id=project_id,
            )
        if agent_instance_id:
            parent = ScopeContext(
                scope_type=ScopeLevel.AGENT_INSTANCE,
                scope_id=agent_instance_id,
                parent=parent,
            )

        scope = ScopeContext(
            scope_type=ScopeLevel.SESSION,
            scope_id=session_id,
            parent=parent,
        )

        return await cls.create(scope=scope)

    async def _initialize(self) -> None:
        """Write the metadata record once per instance (idempotent per instance)."""
        if self._initialized:
            return

        metadata = {
            "scope_type": self._scope.scope_type.value,
            "scope_id": self._scope.scope_id,
            "created_at": datetime.now(UTC).isoformat(),
            "using_fallback": self._using_fallback,
        }
        await self._storage.set(_METADATA_KEY, metadata)
        self._initialized = True

    @property
    def scope(self) -> ScopeContext:
        """Get the scope context."""
        return self._scope

    @property
    def is_using_fallback(self) -> bool:
        """Check if using fallback in-memory storage."""
        return self._using_fallback

    # =========================================================================
    # Basic Key-Value Operations
    # =========================================================================

    async def set(
        self,
        key: str,
        value: Any,
        ttl_seconds: int | None = None,
    ) -> None:
        """
        Store a value.

        Args:
            key: The key to store under
            value: The value to store (must be JSON-serializable)
            ttl_seconds: Optional TTL (uses default if not specified)

        Raises:
            ValueError: If the key starts with '_' (reserved)
        """
        if key.startswith("_"):
            raise ValueError("Keys starting with '_' are reserved for internal use")

        ttl = ttl_seconds if ttl_seconds is not None else self._default_ttl
        await self._storage.set(key, value, ttl)

    async def get(self, key: str, default: Any = None) -> Any:
        """
        Get a value.

        Args:
            key: The key to retrieve
            default: Default value if key not found

        Returns:
            The stored value or default (a stored None also yields default)
        """
        result = await self._storage.get(key)
        return result if result is not None else default

    async def delete(self, key: str) -> bool:
        """
        Delete a key.

        Args:
            key: The key to delete

        Returns:
            True if the key existed

        Raises:
            ValueError: If the key starts with '_' (reserved)
        """
        if key.startswith("_"):
            raise ValueError("Cannot delete internal keys directly")
        return await self._storage.delete(key)

    async def exists(self, key: str) -> bool:
        """
        Check if a key exists.

        Args:
            key: The key to check

        Returns:
            True if the key exists
        """
        return await self._storage.exists(key)

    async def list_keys(self, pattern: str = "*") -> list[str]:
        """
        List keys matching a pattern.

        Args:
            pattern: Glob-style pattern (default "*" for all)

        Returns:
            List of matching keys (excludes internal keys)
        """
        all_keys = await self._storage.list_keys(pattern)
        return [k for k in all_keys if not k.startswith("_")]

    async def get_all(self) -> dict[str, Any]:
        """
        Get all user key-value pairs.

        Returns:
            Dictionary of all key-value pairs (excludes internal keys)
        """
        all_data = await self._storage.get_all()
        return {k: v for k, v in all_data.items() if not k.startswith("_")}

    async def clear(self) -> int:
        """
        Clear all user keys (preserves internal state).

        Task state, scratchpad, metadata and checkpoints (every key starting
        with '_') are left untouched.

        Returns:
            Number of keys deleted
        """
        # Delete only user keys. Wiping the whole store and restoring a fixed
        # set of internal keys would also destroy every checkpoint.
        deleted = 0
        for key in await self.list_keys():
            if await self._storage.delete(key):
                deleted += 1
        return deleted

    # =========================================================================
    # Task State Operations
    # =========================================================================

    async def set_task_state(self, state: TaskState) -> None:
        """
        Set the current task state.

        Side effect: stamps ``state.updated_at`` with the current UTC time.

        Args:
            state: The task state to store
        """
        state.updated_at = datetime.now(UTC)
        await self._storage.set(_TASK_STATE_KEY, asdict(state))

    async def get_task_state(self) -> TaskState | None:
        """
        Get the current task state.

        Returns:
            The current TaskState or None if not set
        """
        data = await self._storage.get(_TASK_STATE_KEY)
        if data is None:
            return None

        # Work on a copy in case the backend hands back its stored object by
        # reference; converting fields in place would then corrupt the store.
        fields = dict(data)

        # Convert datetime strings back to datetime objects
        for name in ("started_at", "updated_at"):
            if isinstance(fields.get(name), str):
                fields[name] = datetime.fromisoformat(fields[name])

        return TaskState(**fields)

    async def update_task_progress(
        self,
        current_step: int | None = None,
        progress_percent: float | None = None,
        status: str | None = None,
    ) -> TaskState | None:
        """
        Update task progress fields.

        Args:
            current_step: New current step number
            progress_percent: New progress percentage (clamped to 0.0..100.0)
            status: New status string

        Returns:
            Updated TaskState or None if no task state exists
        """
        state = await self.get_task_state()
        if state is None:
            return None

        if current_step is not None:
            state.current_step = current_step
        if progress_percent is not None:
            state.progress_percent = min(100.0, max(0.0, progress_percent))
        if status is not None:
            state.status = status

        await self.set_task_state(state)
        return state

    # =========================================================================
    # Scratchpad Operations
    # =========================================================================

    async def append_scratchpad(self, content: str) -> None:
        """
        Append content to the scratchpad.

        The scratchpad is capped at working_memory_max_items_per_session
        entries. When full, the oldest entries are dropped to make room.

        Args:
            content: Text to append
        """
        settings = get_memory_settings()
        max_entries = settings.working_memory_max_items_per_session
        entries = await self._storage.get(_SCRATCHPAD_KEY) or []

        # Check capacity: keep at most (max_entries - 1) old entries so the new
        # one fits. Slice from the front; `entries[-keep:]` with keep == 0
        # would be `entries[-0:]`, i.e. the whole list, and the cap would never
        # apply.
        if len(entries) >= max_entries:
            keep = max(0, max_entries - 1)
            entries = entries[len(entries) - keep :]

        entry = {
            "content": content,
            "timestamp": datetime.now(UTC).isoformat(),
        }
        entries.append(entry)
        await self._storage.set(_SCRATCHPAD_KEY, entries)

    async def get_scratchpad(self) -> list[str]:
        """
        Get all scratchpad entries.

        Returns:
            List of scratchpad content strings (ordered by time)
        """
        entries = await self._storage.get(_SCRATCHPAD_KEY) or []
        return [e["content"] for e in entries]

    async def get_scratchpad_with_timestamps(self) -> list[dict[str, Any]]:
        """
        Get all scratchpad entries with timestamps.

        Returns:
            List of dicts with 'content' and 'timestamp' keys
        """
        return await self._storage.get(_SCRATCHPAD_KEY) or []

    async def clear_scratchpad(self) -> int:
        """
        Clear the scratchpad.

        Returns:
            Number of entries cleared
        """
        entries = await self._storage.get(_SCRATCHPAD_KEY) or []
        count = len(entries)
        await self._storage.set(_SCRATCHPAD_KEY, [])
        return count

    # =========================================================================
    # Checkpoint Operations
    # =========================================================================

    async def create_checkpoint(self, description: str = "") -> str:
        """
        Create a checkpoint of current state.

        The snapshot covers user keys and internal state (task state,
        scratchpad, metadata) but not other checkpoints.

        Args:
            description: Optional description of the checkpoint

        Returns:
            Checkpoint ID for later restoration
        """
        # A full UUID4 (122 random bits) makes collisions negligible. A
        # truncated 8-hex-char id has only 32 bits, which gives ~50% collision
        # odds after ~77k checkpoints (birthday bound).
        checkpoint_id = str(uuid.uuid4())
        checkpoint_key = f"{_CHECKPOINT_PREFIX}{checkpoint_id}"

        # Capture current state without other checkpoints. Nesting them would
        # make every snapshot embed all earlier ones, so storage would roughly
        # double with each new checkpoint.
        all_data = await self._storage.get_all()
        snapshot = {
            k: v for k, v in all_data.items() if not k.startswith(_CHECKPOINT_PREFIX)
        }

        checkpoint = {
            "id": checkpoint_id,
            "description": description,
            "created_at": datetime.now(UTC).isoformat(),
            "data": snapshot,
        }

        await self._storage.set(checkpoint_key, checkpoint)
        logger.debug("Created checkpoint %s", checkpoint_id)
        return checkpoint_id

    async def restore_checkpoint(self, checkpoint_id: str) -> None:
        """
        Restore state from a checkpoint.

        All non-checkpoint keys are replaced by the snapshot's data. Existing
        checkpoints, including ones created after this one, are kept.
        Restored values are written without a TTL.

        Args:
            checkpoint_id: ID of the checkpoint to restore

        Raises:
            MemoryNotFoundError: If checkpoint not found
        """
        checkpoint_key = f"{_CHECKPOINT_PREFIX}{checkpoint_id}"
        checkpoint = await self._storage.get(checkpoint_key)

        if checkpoint is None:
            raise MemoryNotFoundError(f"Checkpoint {checkpoint_id} not found")

        # Drop current state so keys created after the checkpoint do not
        # survive the restore. Leave checkpoints alone; a full storage.clear()
        # would delete them.
        for key in await self._storage.list_keys("*"):
            if not key.startswith(_CHECKPOINT_PREFIX):
                await self._storage.delete(key)

        # Restore snapshot data. Snapshots written by older versions may embed
        # checkpoint entries; skip them so live checkpoints are not replaced
        # by stale copies.
        for key, value in checkpoint["data"].items():
            if not key.startswith(_CHECKPOINT_PREFIX):
                await self._storage.set(key, value)

        logger.debug("Restored checkpoint %s", checkpoint_id)

    async def list_checkpoints(self) -> list[dict[str, Any]]:
        """
        List all available checkpoints.

        Returns:
            List of checkpoint metadata (id, description, created_at), sorted
            by creation time
        """
        checkpoint_keys = await self._storage.list_keys(f"{_CHECKPOINT_PREFIX}*")

        checkpoints = []
        for key in checkpoint_keys:
            data = await self._storage.get(key)
            if data:
                checkpoints.append(
                    {
                        "id": data["id"],
                        "description": data["description"],
                        "created_at": data["created_at"],
                    }
                )

        # Sort by creation time (ISO-8601 strings with the same offset sort
        # chronologically)
        checkpoints.sort(key=lambda x: x["created_at"])
        return checkpoints

    async def delete_checkpoint(self, checkpoint_id: str) -> bool:
        """
        Delete a checkpoint.

        Args:
            checkpoint_id: ID of the checkpoint to delete

        Returns:
            True if checkpoint existed
        """
        checkpoint_key = f"{_CHECKPOINT_PREFIX}{checkpoint_id}"
        return await self._storage.delete(checkpoint_key)

    # =========================================================================
    # Health and Lifecycle
    # =========================================================================

    async def is_healthy(self) -> bool:
        """Check if the working memory storage is healthy."""
        return await self._storage.is_healthy()

    async def close(self) -> None:
        """Close the working memory storage."""
        # Identity check, not truthiness: a storage object that defines
        # __len__ would be falsy when empty and would then never be closed.
        if self._storage is not None:
            await self._storage.close()

    async def get_stats(self) -> dict[str, Any]:
        """
        Get working memory statistics.

        Returns:
            Dictionary with stats about current state
        """
        all_keys = await self._storage.list_keys("*")
        user_keys = [k for k in all_keys if not k.startswith("_")]
        checkpoint_keys = [k for k in all_keys if k.startswith(_CHECKPOINT_PREFIX)]
        scratchpad = await self._storage.get(_SCRATCHPAD_KEY) or []

        return {
            "scope_type": self._scope.scope_type.value,
            "scope_id": self._scope.scope_id,
            "using_fallback": self._using_fallback,
            "total_keys": len(all_keys),
            "user_keys": len(user_keys),
            "checkpoint_count": len(checkpoint_keys),
            "scratchpad_entries": len(scratchpad),
            "has_task_state": await self._storage.exists(_TASK_STATE_KEY),
        }