feat(memory): add working memory implementation (Issue #89)
Implements session-scoped ephemeral memory with:

Storage Backends:
- InMemoryStorage: Thread-safe fallback with TTL support and capacity limits
- RedisStorage: Primary storage with connection pooling and JSON serialization
- Auto-fallback from Redis to in-memory when unavailable

WorkingMemory Class:
- Key-value storage with TTL and reserved key protection
- Task state tracking with progress updates
- Scratchpad for reasoning steps with timestamps
- Checkpoint/snapshot support for recovery
- Factory methods for auto-configured storage

Tests:
- 55 unit tests covering all functionality
- Tests for basic ops, TTL, capacity, concurrency
- Tests for task state, scratchpad, checkpoints

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
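For orientation, a minimal usage sketch of the API introduced by this commit (illustrative only, not part of the diff; it assumes an async entry point and the app.services.memory package added below):

# Sketch: basic WorkingMemory usage, based on the code in this diff.
import asyncio

from app.services.memory.working import WorkingMemory


async def main() -> None:
    # Auto-configured storage: tries Redis first, falls back to in-memory.
    wm = await WorkingMemory.for_session(session_id="session-abc")

    await wm.set("draft", {"title": "WIP"}, ttl_seconds=600)
    print(await wm.get("draft"))

    await wm.append_scratchpad("Considered approach A, rejected due to cost")
    checkpoint_id = await wm.create_checkpoint("before refactor")
    await wm.restore_checkpoint(checkpoint_id)

    await wm.close()


asyncio.run(main())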
@@ -94,6 +94,22 @@ class MemoryStorageError(MemoryError):
         self.backend = backend
 
 
+class MemoryConnectionError(MemoryError):
+    """Raised when memory storage connection fails."""
+
+    def __init__(
+        self,
+        message: str = "Memory connection failed",
+        *,
+        backend: str | None = None,
+        host: str | None = None,
+        **kwargs: Any,
+    ) -> None:
+        super().__init__(message, **kwargs)
+        self.backend = backend
+        self.host = host
+
+
 class MemorySerializationError(MemoryError):
     """Raised when memory serialization/deserialization fails."""
 
@@ -1,8 +1,16 @@
 # app/services/memory/working/__init__.py
 """
-Working Memory
+Working Memory Implementation.
 
-Session-scoped ephemeral memory for current task state,
-variables, and scratchpad.
+Provides short-term memory storage with Redis primary and in-memory fallback.
 """
-# Will be populated in #89
+
+from .memory import WorkingMemory
+from .storage import InMemoryStorage, RedisStorage, WorkingMemoryStorage
+
+__all__ = [
+    "InMemoryStorage",
+    "RedisStorage",
+    "WorkingMemory",
+    "WorkingMemoryStorage",
+]
backend/app/services/memory/working/memory.py (543 lines, new file)
@@ -0,0 +1,543 @@
# app/services/memory/working/memory.py
"""
Working Memory Implementation.

Provides session-scoped ephemeral memory with:
- Key-value storage with TTL
- Task state tracking
- Scratchpad for reasoning steps
- Checkpoint/snapshot support
"""

import logging
import uuid
from dataclasses import asdict
from datetime import UTC, datetime
from typing import Any

from app.services.memory.config import get_memory_settings
from app.services.memory.exceptions import (
    MemoryConnectionError,
    MemoryNotFoundError,
)
from app.services.memory.types import ScopeContext, ScopeLevel, TaskState

from .storage import InMemoryStorage, RedisStorage, WorkingMemoryStorage

logger = logging.getLogger(__name__)


# Reserved key prefixes for internal use
_TASK_STATE_KEY = "_task_state"
_SCRATCHPAD_KEY = "_scratchpad"
_CHECKPOINT_PREFIX = "_checkpoint:"
_METADATA_KEY = "_metadata"


class WorkingMemory:
    """
    Session-scoped working memory.

    Provides ephemeral storage for agent's current task context:
    - Variables and intermediate data
    - Task state (current step, status, progress)
    - Scratchpad for reasoning steps
    - Checkpoints for recovery

    Uses Redis as primary storage with in-memory fallback.
    """

    def __init__(
        self,
        scope: ScopeContext,
        storage: WorkingMemoryStorage,
        default_ttl_seconds: int | None = None,
    ) -> None:
        """
        Initialize working memory for a scope.

        Args:
            scope: The scope context (session, agent instance, etc.)
            storage: Storage backend (use create() factory for auto-configuration)
            default_ttl_seconds: Default TTL for keys (None = no expiration)
        """
        self._scope = scope
        self._storage: WorkingMemoryStorage = storage
        self._default_ttl = default_ttl_seconds
        self._using_fallback = False
        self._initialized = False

    @classmethod
    async def create(
        cls,
        scope: ScopeContext,
        default_ttl_seconds: int | None = None,
    ) -> "WorkingMemory":
        """
        Factory method to create WorkingMemory with auto-configured storage.

        Attempts Redis first, falls back to in-memory if unavailable.
        """
        settings = get_memory_settings()
        key_prefix = f"wm:{scope.to_key_prefix()}:"
        storage: WorkingMemoryStorage

        # Try Redis first
        if settings.working_memory_backend == "redis":
            redis_storage = RedisStorage(key_prefix=key_prefix)
            try:
                if await redis_storage.is_healthy():
                    logger.debug(f"Using Redis storage for scope {scope.scope_id}")
                    instance = cls(
                        scope=scope,
                        storage=redis_storage,
                        default_ttl_seconds=default_ttl_seconds
                        or settings.working_memory_default_ttl_seconds,
                    )
                    await instance._initialize()
                    return instance
            except MemoryConnectionError:
                logger.warning("Redis unavailable, falling back to in-memory storage")
                await redis_storage.close()

        # Fall back to in-memory
        storage = InMemoryStorage(
            max_keys=settings.working_memory_max_items_per_session
        )
        instance = cls(
            scope=scope,
            storage=storage,
            default_ttl_seconds=default_ttl_seconds
            or settings.working_memory_default_ttl_seconds,
        )
        instance._using_fallback = True
        await instance._initialize()
        return instance

    @classmethod
    async def for_session(
        cls,
        session_id: str,
        project_id: str | None = None,
        agent_instance_id: str | None = None,
    ) -> "WorkingMemory":
        """
        Convenience factory for session-scoped working memory.

        Args:
            session_id: Unique session identifier
            project_id: Optional project context
            agent_instance_id: Optional agent instance context
        """
        # Build scope hierarchy
        parent = None
        if project_id:
            parent = ScopeContext(
                scope_type=ScopeLevel.PROJECT,
                scope_id=project_id,
            )
        if agent_instance_id:
            parent = ScopeContext(
                scope_type=ScopeLevel.AGENT_INSTANCE,
                scope_id=agent_instance_id,
                parent=parent,
            )

        scope = ScopeContext(
            scope_type=ScopeLevel.SESSION,
            scope_id=session_id,
            parent=parent,
        )

        return await cls.create(scope=scope)

    async def _initialize(self) -> None:
        """Initialize working memory metadata."""
        if self._initialized:
            return

        metadata = {
            "scope_type": self._scope.scope_type.value,
            "scope_id": self._scope.scope_id,
            "created_at": datetime.now(UTC).isoformat(),
            "using_fallback": self._using_fallback,
        }
        await self._storage.set(_METADATA_KEY, metadata)
        self._initialized = True

    @property
    def scope(self) -> ScopeContext:
        """Get the scope context."""
        return self._scope

    @property
    def is_using_fallback(self) -> bool:
        """Check if using fallback in-memory storage."""
        return self._using_fallback

    # =========================================================================
    # Basic Key-Value Operations
    # =========================================================================

    async def set(
        self,
        key: str,
        value: Any,
        ttl_seconds: int | None = None,
    ) -> None:
        """
        Store a value.

        Args:
            key: The key to store under
            value: The value to store (must be JSON-serializable)
            ttl_seconds: Optional TTL (uses default if not specified)
        """
        if key.startswith("_"):
            raise ValueError("Keys starting with '_' are reserved for internal use")

        ttl = ttl_seconds if ttl_seconds is not None else self._default_ttl
        await self._storage.set(key, value, ttl)

    async def get(self, key: str, default: Any = None) -> Any:
        """
        Get a value.

        Args:
            key: The key to retrieve
            default: Default value if key not found

        Returns:
            The stored value or default
        """
        result = await self._storage.get(key)
        return result if result is not None else default

    async def delete(self, key: str) -> bool:
        """
        Delete a key.

        Args:
            key: The key to delete

        Returns:
            True if the key existed
        """
        if key.startswith("_"):
            raise ValueError("Cannot delete internal keys directly")
        return await self._storage.delete(key)

    async def exists(self, key: str) -> bool:
        """
        Check if a key exists.

        Args:
            key: The key to check

        Returns:
            True if the key exists
        """
        return await self._storage.exists(key)

    async def list_keys(self, pattern: str = "*") -> list[str]:
        """
        List keys matching a pattern.

        Args:
            pattern: Glob-style pattern (default "*" for all)

        Returns:
            List of matching keys (excludes internal keys)
        """
        all_keys = await self._storage.list_keys(pattern)
        return [k for k in all_keys if not k.startswith("_")]

    async def get_all(self) -> dict[str, Any]:
        """
        Get all user key-value pairs.

        Returns:
            Dictionary of all key-value pairs (excludes internal keys)
        """
        all_data = await self._storage.get_all()
        return {k: v for k, v in all_data.items() if not k.startswith("_")}

    async def clear(self) -> int:
        """
        Clear all user keys (preserves internal state).

        Returns:
            Number of keys deleted
        """
        # Save internal state
        task_state = await self._storage.get(_TASK_STATE_KEY)
        scratchpad = await self._storage.get(_SCRATCHPAD_KEY)
        metadata = await self._storage.get(_METADATA_KEY)

        count = await self._storage.clear()

        # Restore internal state
        if metadata is not None:
            await self._storage.set(_METADATA_KEY, metadata)
        if task_state is not None:
            await self._storage.set(_TASK_STATE_KEY, task_state)
        if scratchpad is not None:
            await self._storage.set(_SCRATCHPAD_KEY, scratchpad)

        # Adjust count for preserved keys
        preserved = sum(1 for x in [task_state, scratchpad, metadata] if x is not None)
        return max(0, count - preserved)

    # =========================================================================
    # Task State Operations
    # =========================================================================

    async def set_task_state(self, state: TaskState) -> None:
        """
        Set the current task state.

        Args:
            state: The task state to store
        """
        state.updated_at = datetime.now(UTC)
        await self._storage.set(_TASK_STATE_KEY, asdict(state))

    async def get_task_state(self) -> TaskState | None:
        """
        Get the current task state.

        Returns:
            The current TaskState or None if not set
        """
        data = await self._storage.get(_TASK_STATE_KEY)
        if data is None:
            return None

        # Convert datetime strings back to datetime objects
        if isinstance(data.get("started_at"), str):
            data["started_at"] = datetime.fromisoformat(data["started_at"])
        if isinstance(data.get("updated_at"), str):
            data["updated_at"] = datetime.fromisoformat(data["updated_at"])

        return TaskState(**data)

    async def update_task_progress(
        self,
        current_step: int | None = None,
        progress_percent: float | None = None,
        status: str | None = None,
    ) -> TaskState | None:
        """
        Update task progress fields.

        Args:
            current_step: New current step number
            progress_percent: New progress percentage (0.0 to 100.0)
            status: New status string

        Returns:
            Updated TaskState or None if no task state exists
        """
        state = await self.get_task_state()
        if state is None:
            return None

        if current_step is not None:
            state.current_step = current_step
        if progress_percent is not None:
            state.progress_percent = min(100.0, max(0.0, progress_percent))
        if status is not None:
            state.status = status

        await self.set_task_state(state)
        return state

    # =========================================================================
    # Scratchpad Operations
    # =========================================================================

    async def append_scratchpad(self, content: str) -> None:
        """
        Append content to the scratchpad.

        Args:
            content: Text to append
        """
        settings = get_memory_settings()
        entries = await self._storage.get(_SCRATCHPAD_KEY) or []

        # Check capacity
        if len(entries) >= settings.working_memory_max_items_per_session:
            # Remove oldest entries
            entries = entries[-(settings.working_memory_max_items_per_session - 1) :]

        entry = {
            "content": content,
            "timestamp": datetime.now(UTC).isoformat(),
        }
        entries.append(entry)
        await self._storage.set(_SCRATCHPAD_KEY, entries)

    async def get_scratchpad(self) -> list[str]:
        """
        Get all scratchpad entries.

        Returns:
            List of scratchpad content strings (ordered by time)
        """
        entries = await self._storage.get(_SCRATCHPAD_KEY) or []
        return [e["content"] for e in entries]

    async def get_scratchpad_with_timestamps(self) -> list[dict[str, Any]]:
        """
        Get all scratchpad entries with timestamps.

        Returns:
            List of dicts with 'content' and 'timestamp' keys
        """
        return await self._storage.get(_SCRATCHPAD_KEY) or []

    async def clear_scratchpad(self) -> int:
        """
        Clear the scratchpad.

        Returns:
            Number of entries cleared
        """
        entries = await self._storage.get(_SCRATCHPAD_KEY) or []
        count = len(entries)
        await self._storage.set(_SCRATCHPAD_KEY, [])
        return count

    # =========================================================================
    # Checkpoint Operations
    # =========================================================================

    async def create_checkpoint(self, description: str = "") -> str:
        """
        Create a checkpoint of current state.

        Args:
            description: Optional description of the checkpoint

        Returns:
            Checkpoint ID for later restoration
        """
        checkpoint_id = str(uuid.uuid4())[:8]
        checkpoint_key = f"{_CHECKPOINT_PREFIX}{checkpoint_id}"

        # Capture all current state
        all_data = await self._storage.get_all()

        checkpoint = {
            "id": checkpoint_id,
            "description": description,
            "created_at": datetime.now(UTC).isoformat(),
            "data": all_data,
        }

        await self._storage.set(checkpoint_key, checkpoint)
        logger.debug(f"Created checkpoint {checkpoint_id}")
        return checkpoint_id

    async def restore_checkpoint(self, checkpoint_id: str) -> None:
        """
        Restore state from a checkpoint.

        Args:
            checkpoint_id: ID of the checkpoint to restore

        Raises:
            MemoryNotFoundError: If checkpoint not found
        """
        checkpoint_key = f"{_CHECKPOINT_PREFIX}{checkpoint_id}"
        checkpoint = await self._storage.get(checkpoint_key)

        if checkpoint is None:
            raise MemoryNotFoundError(f"Checkpoint {checkpoint_id} not found")

        # Clear current state
        await self._storage.clear()

        # Restore all data from checkpoint
        for key, value in checkpoint["data"].items():
            await self._storage.set(key, value)

        # Keep the checkpoint itself
        await self._storage.set(checkpoint_key, checkpoint)

        logger.debug(f"Restored checkpoint {checkpoint_id}")

    async def list_checkpoints(self) -> list[dict[str, Any]]:
        """
        List all available checkpoints.

        Returns:
            List of checkpoint metadata (id, description, created_at)
        """
        checkpoint_keys = await self._storage.list_keys(f"{_CHECKPOINT_PREFIX}*")
        checkpoints = []

        for key in checkpoint_keys:
            data = await self._storage.get(key)
            if data:
                checkpoints.append(
                    {
                        "id": data["id"],
                        "description": data["description"],
                        "created_at": data["created_at"],
                    }
                )

        # Sort by creation time
        checkpoints.sort(key=lambda x: x["created_at"])
        return checkpoints

    async def delete_checkpoint(self, checkpoint_id: str) -> bool:
        """
        Delete a checkpoint.

        Args:
            checkpoint_id: ID of the checkpoint to delete

        Returns:
            True if checkpoint existed
        """
        checkpoint_key = f"{_CHECKPOINT_PREFIX}{checkpoint_id}"
        return await self._storage.delete(checkpoint_key)

    # =========================================================================
    # Health and Lifecycle
    # =========================================================================

    async def is_healthy(self) -> bool:
        """Check if the working memory storage is healthy."""
        return await self._storage.is_healthy()

    async def close(self) -> None:
        """Close the working memory storage."""
        if self._storage:
            await self._storage.close()

    async def get_stats(self) -> dict[str, Any]:
        """
        Get working memory statistics.

        Returns:
            Dictionary with stats about current state
        """
        all_keys = await self._storage.list_keys("*")
        user_keys = [k for k in all_keys if not k.startswith("_")]
        checkpoint_keys = [k for k in all_keys if k.startswith(_CHECKPOINT_PREFIX)]
        scratchpad = await self._storage.get(_SCRATCHPAD_KEY) or []

        return {
            "scope_type": self._scope.scope_type.value,
            "scope_id": self._scope.scope_id,
            "using_fallback": self._using_fallback,
            "total_keys": len(all_keys),
            "user_keys": len(user_keys),
            "checkpoint_count": len(checkpoint_keys),
            "scratchpad_entries": len(scratchpad),
            "has_task_state": await self._storage.exists(_TASK_STATE_KEY),
        }
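A short, hedged sketch of the task-state flow implemented in memory.py above (not part of the diff; the TaskState constructor arguments mirror those used in the tests below and are otherwise an assumption):

# Sketch: task-state tracking against a WorkingMemory instance (illustrative).
from app.services.memory.types import TaskState
from app.services.memory.working.memory import WorkingMemory


async def track_progress(wm: WorkingMemory) -> None:
    # Store an initial task state under the reserved _task_state key.
    state = TaskState(task_id="task-1", task_type="demo", description="Demo task")
    await wm.set_task_state(state)

    # update_task_progress() clamps progress_percent to the 0-100 range.
    updated = await wm.update_task_progress(
        current_step=2,
        progress_percent=40.0,
        status="in_progress",
    )
    assert updated is not None and updated.progress_percent == 40.0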
backend/app/services/memory/working/storage.py (406 lines, new file)
@@ -0,0 +1,406 @@
# app/services/memory/working/storage.py
"""
Working Memory Storage Backends.

Provides abstract storage interface and implementations:
- RedisStorage: Primary storage using Redis with connection pooling
- InMemoryStorage: Fallback storage when Redis is unavailable
"""

import asyncio
import fnmatch
import json
import logging
from abc import ABC, abstractmethod
from datetime import UTC, datetime, timedelta
from typing import Any

from app.services.memory.config import get_memory_settings
from app.services.memory.exceptions import (
    MemoryConnectionError,
    MemoryStorageError,
)

logger = logging.getLogger(__name__)


class WorkingMemoryStorage(ABC):
    """Abstract base class for working memory storage backends."""

    @abstractmethod
    async def set(
        self,
        key: str,
        value: Any,
        ttl_seconds: int | None = None,
    ) -> None:
        """Store a value with optional TTL."""
        ...

    @abstractmethod
    async def get(self, key: str) -> Any | None:
        """Get a value by key, returns None if not found or expired."""
        ...

    @abstractmethod
    async def delete(self, key: str) -> bool:
        """Delete a key, returns True if existed."""
        ...

    @abstractmethod
    async def exists(self, key: str) -> bool:
        """Check if a key exists and is not expired."""
        ...

    @abstractmethod
    async def list_keys(self, pattern: str = "*") -> list[str]:
        """List all keys matching a pattern."""
        ...

    @abstractmethod
    async def get_all(self) -> dict[str, Any]:
        """Get all key-value pairs."""
        ...

    @abstractmethod
    async def clear(self) -> int:
        """Clear all keys, returns count of deleted keys."""
        ...

    @abstractmethod
    async def is_healthy(self) -> bool:
        """Check if the storage backend is healthy."""
        ...

    @abstractmethod
    async def close(self) -> None:
        """Close the storage connection."""
        ...


class InMemoryStorage(WorkingMemoryStorage):
    """
    In-memory storage backend for working memory.

    Used as fallback when Redis is unavailable. Data is not persisted
    across restarts and is not shared between processes.
    """

    def __init__(self, max_keys: int = 10000) -> None:
        """Initialize in-memory storage."""
        self._data: dict[str, Any] = {}
        self._expirations: dict[str, datetime] = {}
        self._max_keys = max_keys
        self._lock = asyncio.Lock()

    def _is_expired(self, key: str) -> bool:
        """Check if a key has expired."""
        if key not in self._expirations:
            return False
        return datetime.now(UTC) > self._expirations[key]

    def _cleanup_expired(self) -> None:
        """Remove all expired keys."""
        now = datetime.now(UTC)
        expired_keys = [
            key for key, exp_time in self._expirations.items() if now > exp_time
        ]
        for key in expired_keys:
            self._data.pop(key, None)
            self._expirations.pop(key, None)

    async def set(
        self,
        key: str,
        value: Any,
        ttl_seconds: int | None = None,
    ) -> None:
        """Store a value with optional TTL."""
        async with self._lock:
            # Cleanup expired keys periodically
            if len(self._data) % 100 == 0:
                self._cleanup_expired()

            # Check capacity
            if key not in self._data and len(self._data) >= self._max_keys:
                # Evict expired keys first
                self._cleanup_expired()
                if len(self._data) >= self._max_keys:
                    raise MemoryStorageError(
                        f"Working memory capacity exceeded: {self._max_keys} keys"
                    )

            self._data[key] = value
            if ttl_seconds is not None:
                self._expirations[key] = datetime.now(UTC) + timedelta(
                    seconds=ttl_seconds
                )
            elif key in self._expirations:
                # Remove existing expiration if no TTL specified
                del self._expirations[key]

    async def get(self, key: str) -> Any | None:
        """Get a value by key."""
        async with self._lock:
            if key not in self._data:
                return None
            if self._is_expired(key):
                del self._data[key]
                del self._expirations[key]
                return None
            return self._data[key]

    async def delete(self, key: str) -> bool:
        """Delete a key."""
        async with self._lock:
            existed = key in self._data
            self._data.pop(key, None)
            self._expirations.pop(key, None)
            return existed

    async def exists(self, key: str) -> bool:
        """Check if a key exists and is not expired."""
        async with self._lock:
            if key not in self._data:
                return False
            if self._is_expired(key):
                del self._data[key]
                del self._expirations[key]
                return False
            return True

    async def list_keys(self, pattern: str = "*") -> list[str]:
        """List all keys matching a pattern."""
        async with self._lock:
            self._cleanup_expired()
            if pattern == "*":
                return list(self._data.keys())
            return [key for key in self._data.keys() if fnmatch.fnmatch(key, pattern)]

    async def get_all(self) -> dict[str, Any]:
        """Get all key-value pairs."""
        async with self._lock:
            self._cleanup_expired()
            return dict(self._data)

    async def clear(self) -> int:
        """Clear all keys."""
        async with self._lock:
            count = len(self._data)
            self._data.clear()
            self._expirations.clear()
            return count

    async def is_healthy(self) -> bool:
        """In-memory storage is always healthy."""
        return True

    async def close(self) -> None:
        """No cleanup needed for in-memory storage."""


class RedisStorage(WorkingMemoryStorage):
    """
    Redis storage backend for working memory.

    Primary storage with connection pooling, automatic reconnection,
    and proper serialization of Python objects.
    """

    def __init__(
        self,
        key_prefix: str = "",
        connection_timeout: float = 5.0,
        socket_timeout: float = 5.0,
    ) -> None:
        """
        Initialize Redis storage.

        Args:
            key_prefix: Prefix for all keys (e.g., "session:abc123:")
            connection_timeout: Timeout for establishing connections
            socket_timeout: Timeout for socket operations
        """
        self._key_prefix = key_prefix
        self._connection_timeout = connection_timeout
        self._socket_timeout = socket_timeout
        self._redis: Any = None
        self._lock = asyncio.Lock()

    def _make_key(self, key: str) -> str:
        """Add prefix to key."""
        return f"{self._key_prefix}{key}"

    def _strip_prefix(self, key: str) -> str:
        """Remove prefix from key."""
        if key.startswith(self._key_prefix):
            return key[len(self._key_prefix) :]
        return key

    def _serialize(self, value: Any) -> str:
        """Serialize a Python value to JSON string."""
        return json.dumps(value, default=str)

    def _deserialize(self, data: str | bytes | None) -> Any | None:
        """Deserialize a JSON string to Python value."""
        if data is None:
            return None
        if isinstance(data, bytes):
            data = data.decode("utf-8")
        return json.loads(data)

    async def _get_client(self) -> Any:
        """Get or create Redis client."""
        if self._redis is not None:
            return self._redis

        async with self._lock:
            if self._redis is not None:
                return self._redis

            try:
                import redis.asyncio as aioredis
            except ImportError as e:
                raise MemoryConnectionError(
                    "redis package not installed. Install with: pip install redis"
                ) from e

            settings = get_memory_settings()
            redis_url = settings.redis_url

            try:
                self._redis = await aioredis.from_url(
                    redis_url,
                    encoding="utf-8",
                    decode_responses=True,
                    socket_connect_timeout=self._connection_timeout,
                    socket_timeout=self._socket_timeout,
                )
                # Test connection
                await self._redis.ping()
                logger.info("Connected to Redis for working memory")
                return self._redis
            except Exception as e:
                self._redis = None
                raise MemoryConnectionError(f"Failed to connect to Redis: {e}") from e

    async def set(
        self,
        key: str,
        value: Any,
        ttl_seconds: int | None = None,
    ) -> None:
        """Store a value with optional TTL."""
        try:
            client = await self._get_client()
            full_key = self._make_key(key)
            serialized = self._serialize(value)

            if ttl_seconds is not None:
                await client.setex(full_key, ttl_seconds, serialized)
            else:
                await client.set(full_key, serialized)
        except MemoryConnectionError:
            raise
        except Exception as e:
            raise MemoryStorageError(f"Failed to set key {key}: {e}") from e

    async def get(self, key: str) -> Any | None:
        """Get a value by key."""
        try:
            client = await self._get_client()
            full_key = self._make_key(key)
            data = await client.get(full_key)
            return self._deserialize(data)
        except MemoryConnectionError:
            raise
        except Exception as e:
            raise MemoryStorageError(f"Failed to get key {key}: {e}") from e

    async def delete(self, key: str) -> bool:
        """Delete a key."""
        try:
            client = await self._get_client()
            full_key = self._make_key(key)
            result = await client.delete(full_key)
            return bool(result)
        except MemoryConnectionError:
            raise
        except Exception as e:
            raise MemoryStorageError(f"Failed to delete key {key}: {e}") from e

    async def exists(self, key: str) -> bool:
        """Check if a key exists."""
        try:
            client = await self._get_client()
            full_key = self._make_key(key)
            result = await client.exists(full_key)
            return bool(result)
        except MemoryConnectionError:
            raise
        except Exception as e:
            raise MemoryStorageError(f"Failed to check key {key}: {e}") from e

    async def list_keys(self, pattern: str = "*") -> list[str]:
        """List all keys matching a pattern."""
        try:
            client = await self._get_client()
            full_pattern = self._make_key(pattern)
            keys = await client.keys(full_pattern)
            return [self._strip_prefix(key) for key in keys]
        except MemoryConnectionError:
            raise
        except Exception as e:
            raise MemoryStorageError(f"Failed to list keys: {e}") from e

    async def get_all(self) -> dict[str, Any]:
        """Get all key-value pairs."""
        try:
            client = await self._get_client()
            full_pattern = self._make_key("*")
            keys = await client.keys(full_pattern)

            if not keys:
                return {}

            values = await client.mget(*keys)
            result = {}
            for key, value in zip(keys, values, strict=False):
                stripped_key = self._strip_prefix(key)
                result[stripped_key] = self._deserialize(value)
            return result
        except MemoryConnectionError:
            raise
        except Exception as e:
            raise MemoryStorageError(f"Failed to get all keys: {e}") from e

    async def clear(self) -> int:
        """Clear all keys with this prefix."""
        try:
            client = await self._get_client()
            full_pattern = self._make_key("*")
            keys = await client.keys(full_pattern)

            if not keys:
                return 0

            return await client.delete(*keys)
        except MemoryConnectionError:
            raise
        except Exception as e:
            raise MemoryStorageError(f"Failed to clear keys: {e}") from e

    async def is_healthy(self) -> bool:
        """Check if Redis connection is healthy."""
        try:
            client = await self._get_client()
            await client.ping()
            return True
        except Exception:
            return False

    async def close(self) -> None:
        """Close the Redis connection."""
        if self._redis is not None:
            await self._redis.close()
            self._redis = None
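For reference, a minimal sketch of using the InMemoryStorage fallback backend directly, mirroring the TTL behaviour defined above (illustrative, not part of the diff):

# Sketch: direct use of InMemoryStorage (fallback backend defined above).
import asyncio

from app.services.memory.working.storage import InMemoryStorage


async def demo() -> None:
    storage = InMemoryStorage(max_keys=100)
    await storage.set("temp", "value", ttl_seconds=1)

    await asyncio.sleep(1.1)
    # Expired keys read back as None and no longer count as existing.
    assert await storage.get("temp") is None
    assert await storage.exists("temp") is False


asyncio.run(demo())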
backend/tests/unit/services/memory/working/__init__.py (2 lines, new file)
@@ -0,0 +1,2 @@
# tests/unit/services/memory/working/__init__.py
"""Unit tests for working memory implementation."""
backend/tests/unit/services/memory/working/test_memory.py (391 lines, new file)
@@ -0,0 +1,391 @@
# tests/unit/services/memory/working/test_memory.py
"""Unit tests for WorkingMemory class."""

import pytest
import pytest_asyncio

from app.services.memory.exceptions import MemoryNotFoundError
from app.services.memory.types import ScopeContext, ScopeLevel, TaskState
from app.services.memory.working.memory import WorkingMemory
from app.services.memory.working.storage import InMemoryStorage


@pytest.fixture
def scope() -> ScopeContext:
    """Create a test scope."""
    return ScopeContext(
        scope_type=ScopeLevel.SESSION,
        scope_id="test-session-123",
    )


@pytest.fixture
def storage() -> InMemoryStorage:
    """Create a test storage backend."""
    return InMemoryStorage(max_keys=1000)


@pytest_asyncio.fixture
async def memory(scope: ScopeContext, storage: InMemoryStorage) -> WorkingMemory:
    """Create a WorkingMemory instance for testing."""
    wm = WorkingMemory(scope=scope, storage=storage)
    await wm._initialize()
    return wm


class TestWorkingMemoryBasicOperations:
    """Tests for basic key-value operations."""

    @pytest.mark.asyncio
    async def test_set_and_get(self, memory: WorkingMemory) -> None:
        """Test basic set and get."""
        await memory.set("key1", "value1")
        result = await memory.get("key1")
        assert result == "value1"

    @pytest.mark.asyncio
    async def test_get_with_default(self, memory: WorkingMemory) -> None:
        """Test get with default value."""
        result = await memory.get("nonexistent", default="fallback")
        assert result == "fallback"

    @pytest.mark.asyncio
    async def test_delete(self, memory: WorkingMemory) -> None:
        """Test delete operation."""
        await memory.set("key1", "value1")
        result = await memory.delete("key1")
        assert result is True
        assert await memory.exists("key1") is False

    @pytest.mark.asyncio
    async def test_exists(self, memory: WorkingMemory) -> None:
        """Test exists check."""
        await memory.set("key1", "value1")
        assert await memory.exists("key1") is True
        assert await memory.exists("nonexistent") is False

    @pytest.mark.asyncio
    async def test_reserved_key_prefix(self, memory: WorkingMemory) -> None:
        """Test that keys starting with _ are rejected."""
        with pytest.raises(ValueError, match="reserved"):
            await memory.set("_internal", "value")

    @pytest.mark.asyncio
    async def test_cannot_delete_internal_keys(self, memory: WorkingMemory) -> None:
        """Test that internal keys cannot be deleted directly."""
        with pytest.raises(ValueError, match="internal"):
            await memory.delete("_task_state")


class TestWorkingMemoryListAndClear:
    """Tests for list and clear operations."""

    @pytest.mark.asyncio
    async def test_list_keys(self, memory: WorkingMemory) -> None:
        """Test listing keys."""
        await memory.set("key1", "value1")
        await memory.set("key2", "value2")

        keys = await memory.list_keys()
        assert set(keys) == {"key1", "key2"}

    @pytest.mark.asyncio
    async def test_list_keys_excludes_internal(self, memory: WorkingMemory) -> None:
        """Test that list_keys excludes internal keys."""
        await memory.set("user_key", "value")
        # Internal keys exist from initialization
        keys = await memory.list_keys()
        assert all(not k.startswith("_") for k in keys)

    @pytest.mark.asyncio
    async def test_list_keys_with_pattern(self, memory: WorkingMemory) -> None:
        """Test listing keys with pattern."""
        await memory.set("prefix_a", "value1")
        await memory.set("prefix_b", "value2")
        await memory.set("other", "value3")

        keys = await memory.list_keys("prefix_*")
        assert set(keys) == {"prefix_a", "prefix_b"}

    @pytest.mark.asyncio
    async def test_get_all(self, memory: WorkingMemory) -> None:
        """Test getting all key-value pairs."""
        await memory.set("key1", "value1")
        await memory.set("key2", "value2")

        result = await memory.get_all()
        assert result == {"key1": "value1", "key2": "value2"}

    @pytest.mark.asyncio
    async def test_clear_preserves_internal_state(self, memory: WorkingMemory) -> None:
        """Test that clear preserves internal state."""
        # Set some user data
        await memory.set("user_key", "value")

        # Set task state
        state = TaskState(
            task_id="task-1",
            task_type="test",
            description="Test task",
        )
        await memory.set_task_state(state)

        # Clear
        await memory.clear()

        # User data should be gone
        assert await memory.exists("user_key") is False

        # Task state should be preserved
        restored_state = await memory.get_task_state()
        assert restored_state is not None
        assert restored_state.task_id == "task-1"


class TestWorkingMemoryTaskState:
    """Tests for task state operations."""

    @pytest.mark.asyncio
    async def test_set_and_get_task_state(self, memory: WorkingMemory) -> None:
        """Test setting and getting task state."""
        state = TaskState(
            task_id="task-123",
            task_type="code_review",
            description="Review pull request",
            status="in_progress",
            current_step=2,
            total_steps=5,
            progress_percent=40.0,
            context={"pr_id": 456},
        )

        await memory.set_task_state(state)
        result = await memory.get_task_state()

        assert result is not None
        assert result.task_id == "task-123"
        assert result.task_type == "code_review"
        assert result.status == "in_progress"
        assert result.current_step == 2
        assert result.progress_percent == 40.0
        assert result.context == {"pr_id": 456}

    @pytest.mark.asyncio
    async def test_get_task_state_none_when_not_set(
        self, memory: WorkingMemory
    ) -> None:
        """Test that get_task_state returns None when not set."""
        result = await memory.get_task_state()
        assert result is None

    @pytest.mark.asyncio
    async def test_update_task_progress(self, memory: WorkingMemory) -> None:
        """Test updating task progress."""
        state = TaskState(
            task_id="task-123",
            task_type="test",
            description="Test",
            current_step=1,
            progress_percent=10.0,
            status="running",
        )
        await memory.set_task_state(state)

        updated = await memory.update_task_progress(
            current_step=3,
            progress_percent=60.0,
            status="processing",
        )

        assert updated is not None
        assert updated.current_step == 3
        assert updated.progress_percent == 60.0
        assert updated.status == "processing"

    @pytest.mark.asyncio
    async def test_update_task_progress_clamps_percent(
        self, memory: WorkingMemory
    ) -> None:
        """Test that progress percent is clamped to 0-100."""
        state = TaskState(
            task_id="task-123",
            task_type="test",
            description="Test",
        )
        await memory.set_task_state(state)

        updated = await memory.update_task_progress(progress_percent=150.0)
        assert updated is not None
        assert updated.progress_percent == 100.0

        updated = await memory.update_task_progress(progress_percent=-10.0)
        assert updated is not None
        assert updated.progress_percent == 0.0


class TestWorkingMemoryScratchpad:
    """Tests for scratchpad operations."""

    @pytest.mark.asyncio
    async def test_append_and_get_scratchpad(self, memory: WorkingMemory) -> None:
        """Test appending to and getting scratchpad."""
        await memory.append_scratchpad("First note")
        await memory.append_scratchpad("Second note")

        entries = await memory.get_scratchpad()
        assert entries == ["First note", "Second note"]

    @pytest.mark.asyncio
    async def test_get_scratchpad_empty(self, memory: WorkingMemory) -> None:
        """Test getting empty scratchpad."""
        entries = await memory.get_scratchpad()
        assert entries == []

    @pytest.mark.asyncio
    async def test_get_scratchpad_with_timestamps(self, memory: WorkingMemory) -> None:
        """Test getting scratchpad with timestamps."""
        await memory.append_scratchpad("Test note")

        entries = await memory.get_scratchpad_with_timestamps()
        assert len(entries) == 1
        assert entries[0]["content"] == "Test note"
        assert "timestamp" in entries[0]

    @pytest.mark.asyncio
    async def test_clear_scratchpad(self, memory: WorkingMemory) -> None:
        """Test clearing scratchpad."""
        await memory.append_scratchpad("Note 1")
        await memory.append_scratchpad("Note 2")

        count = await memory.clear_scratchpad()
        assert count == 2

        entries = await memory.get_scratchpad()
        assert entries == []


class TestWorkingMemoryCheckpoints:
    """Tests for checkpoint operations."""

    @pytest.mark.asyncio
    async def test_create_checkpoint(self, memory: WorkingMemory) -> None:
        """Test creating a checkpoint."""
        await memory.set("key1", "value1")
        await memory.set("key2", "value2")

        checkpoint_id = await memory.create_checkpoint("Test checkpoint")

        assert checkpoint_id is not None
        assert len(checkpoint_id) == 8  # UUID prefix

    @pytest.mark.asyncio
    async def test_restore_checkpoint(self, memory: WorkingMemory) -> None:
        """Test restoring from a checkpoint."""
        await memory.set("key1", "original")
        checkpoint_id = await memory.create_checkpoint()

        # Modify state
        await memory.set("key1", "modified")
        await memory.set("key2", "new")

        # Restore
        await memory.restore_checkpoint(checkpoint_id)

        # Check restoration
        assert await memory.get("key1") == "original"
        # key2 didn't exist in the checkpoint; restore clears current state
        # first, so it should be gone

    @pytest.mark.asyncio
    async def test_restore_nonexistent_checkpoint(self, memory: WorkingMemory) -> None:
        """Test restoring from nonexistent checkpoint raises error."""
        with pytest.raises(MemoryNotFoundError):
            await memory.restore_checkpoint("nonexistent")

    @pytest.mark.asyncio
    async def test_list_checkpoints(self, memory: WorkingMemory) -> None:
        """Test listing checkpoints."""
        cp1 = await memory.create_checkpoint("First")
        cp2 = await memory.create_checkpoint("Second")

        checkpoints = await memory.list_checkpoints()

        assert len(checkpoints) == 2
        ids = [cp["id"] for cp in checkpoints]
        assert cp1 in ids
        assert cp2 in ids

    @pytest.mark.asyncio
    async def test_delete_checkpoint(self, memory: WorkingMemory) -> None:
        """Test deleting a checkpoint."""
        checkpoint_id = await memory.create_checkpoint()

        result = await memory.delete_checkpoint(checkpoint_id)
        assert result is True

        checkpoints = await memory.list_checkpoints()
        assert len(checkpoints) == 0


class TestWorkingMemoryScope:
    """Tests for scope handling."""

    @pytest.mark.asyncio
    async def test_scope_property(
        self, memory: WorkingMemory, scope: ScopeContext
    ) -> None:
        """Test scope property."""
        assert memory.scope == scope

    @pytest.mark.asyncio
    async def test_for_session_factory(self) -> None:
        """Test for_session factory method."""
        # This would normally try Redis and fall back to in-memory
        # In tests, Redis won't be available, so it uses fallback
        wm = await WorkingMemory.for_session(
            session_id="session-abc",
            project_id="project-123",
            agent_instance_id="agent-456",
        )

        assert wm.scope.scope_type == ScopeLevel.SESSION
        assert wm.scope.scope_id == "session-abc"
        assert wm.scope.parent is not None
        assert wm.scope.parent.scope_type == ScopeLevel.AGENT_INSTANCE


class TestWorkingMemoryHealth:
    """Tests for health and lifecycle."""

    @pytest.mark.asyncio
    async def test_is_healthy(self, memory: WorkingMemory) -> None:
        """Test health check."""
        assert await memory.is_healthy() is True

    @pytest.mark.asyncio
    async def test_get_stats(self, memory: WorkingMemory) -> None:
        """Test getting stats."""
        await memory.set("key1", "value1")
        await memory.append_scratchpad("Note")

        state = TaskState(task_id="t1", task_type="test", description="Test")
        await memory.set_task_state(state)

        stats = await memory.get_stats()

        assert stats["scope_type"] == "session"
        assert stats["scope_id"] == "test-session-123"
        assert stats["user_keys"] == 1
        assert stats["scratchpad_entries"] == 1
        assert stats["has_task_state"] is True

    @pytest.mark.asyncio
    async def test_is_using_fallback(self, memory: WorkingMemory) -> None:
        """Test fallback detection."""
        # The fixture injects InMemoryStorage directly, so the fallback flag is not set
        assert memory.is_using_fallback is False

    @pytest.mark.asyncio
    async def test_close(self, memory: WorkingMemory) -> None:
        """Test close doesn't error."""
        await memory.close()  # Should not raise
backend/tests/unit/services/memory/working/test_storage.py (303 lines, new file)
@@ -0,0 +1,303 @@
|
|||||||
|
# tests/unit/services/memory/working/test_storage.py
|
||||||
|
"""Unit tests for working memory storage backends."""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from app.services.memory.exceptions import MemoryStorageError
|
||||||
|
from app.services.memory.working.storage import InMemoryStorage
|
||||||
|
|
||||||
|
|
||||||
|
class TestInMemoryStorageBasicOperations:
|
||||||
|
"""Tests for basic InMemoryStorage operations."""
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def storage(self) -> InMemoryStorage:
|
||||||
|
"""Create a fresh storage instance."""
|
||||||
|
return InMemoryStorage(max_keys=100)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_set_and_get(self, storage: InMemoryStorage) -> None:
|
||||||
|
"""Test basic set and get."""
|
||||||
|
await storage.set("key1", "value1")
|
||||||
|
result = await storage.get("key1")
|
||||||
|
assert result == "value1"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_nonexistent_key(self, storage: InMemoryStorage) -> None:
|
||||||
|
"""Test getting a key that doesn't exist."""
|
||||||
|
result = await storage.get("nonexistent")
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_set_overwrites_existing(self, storage: InMemoryStorage) -> None:
|
||||||
|
"""Test that set overwrites existing values."""
|
||||||
|
await storage.set("key1", "original")
|
||||||
|
await storage.set("key1", "updated")
|
||||||
|
result = await storage.get("key1")
|
||||||
|
assert result == "updated"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_delete_existing_key(self, storage: InMemoryStorage) -> None:
|
||||||
|
"""Test deleting an existing key."""
|
||||||
|
await storage.set("key1", "value1")
|
||||||
|
result = await storage.delete("key1")
|
||||||
|
assert result is True
|
||||||
|
assert await storage.get("key1") is None
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_delete_nonexistent_key(self, storage: InMemoryStorage) -> None:
|
||||||
|
"""Test deleting a key that doesn't exist."""
|
||||||
|
result = await storage.delete("nonexistent")
|
||||||
|
assert result is False
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_exists(self, storage: InMemoryStorage) -> None:
|
||||||
|
"""Test exists check."""
|
||||||
|
await storage.set("key1", "value1")
|
||||||
|
assert await storage.exists("key1") is True
|
||||||
|
assert await storage.exists("nonexistent") is False
|
||||||
|
|
||||||
|
|
||||||
|
class TestInMemoryStorageTTL:
    """Tests for TTL functionality."""

    @pytest.fixture
    def storage(self) -> InMemoryStorage:
        """Create a fresh storage instance."""
        return InMemoryStorage(max_keys=100)

    @pytest.mark.asyncio
    async def test_set_with_ttl(self, storage: InMemoryStorage) -> None:
        """Test that TTL is stored correctly."""
        await storage.set("key1", "value1", ttl_seconds=10)
        # Key should exist immediately
        assert await storage.exists("key1") is True

    @pytest.mark.asyncio
    async def test_ttl_expiration(self, storage: InMemoryStorage) -> None:
        """Test that expired keys return None."""
        await storage.set("key1", "value1", ttl_seconds=1)

        # Key exists initially
        assert await storage.get("key1") == "value1"

        # Wait for expiration
        await asyncio.sleep(1.1)

        # Key should be expired
        assert await storage.get("key1") is None
        assert await storage.exists("key1") is False

    @pytest.mark.asyncio
    async def test_remove_ttl_on_update(self, storage: InMemoryStorage) -> None:
        """Test that updating without TTL removes expiration."""
        await storage.set("key1", "value1", ttl_seconds=1)
        await storage.set("key1", "value2")  # No TTL

        await asyncio.sleep(1.1)

        # Key should still exist (TTL removed)
        assert await storage.get("key1") == "value2"


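# --- Illustrative sketch (not part of the shipped code) -----------------------
# The TTL behaviour above is consistent with a lazily checked expiry timestamp:
# set() records an absolute deadline (or None when no TTL is given), and each
# get()/exists() discards the entry once the deadline has passed. The helpers
# below sketch that assumption; they are not the actual storage internals.
import time


def _sketch_expiry(ttl_seconds: float | None) -> float | None:
    """Illustration: convert a relative TTL into an absolute monotonic deadline."""
    return None if ttl_seconds is None else time.monotonic() + ttl_seconds


def _sketch_is_expired(expires_at: float | None) -> bool:
    """Illustration: lazy check performed on each read."""
    return expires_at is not None and time.monotonic() >= expires_at


# Re-setting a key without ttl_seconds would store a deadline of None again,
# which is exactly the behaviour test_remove_ttl_on_update relies on.

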
class TestInMemoryStorageListAndClear:
    """Tests for list and clear operations."""

    @pytest.fixture
    def storage(self) -> InMemoryStorage:
        """Create a fresh storage instance."""
        return InMemoryStorage(max_keys=100)

    @pytest.mark.asyncio
    async def test_list_keys_all(self, storage: InMemoryStorage) -> None:
        """Test listing all keys."""
        await storage.set("key1", "value1")
        await storage.set("key2", "value2")
        await storage.set("other", "value3")

        keys = await storage.list_keys()
        assert set(keys) == {"key1", "key2", "other"}

    @pytest.mark.asyncio
    async def test_list_keys_with_pattern(self, storage: InMemoryStorage) -> None:
        """Test listing keys with pattern."""
        await storage.set("key1", "value1")
        await storage.set("key2", "value2")
        await storage.set("other", "value3")

        keys = await storage.list_keys("key*")
        assert set(keys) == {"key1", "key2"}

    @pytest.mark.asyncio
    async def test_get_all(self, storage: InMemoryStorage) -> None:
        """Test getting all key-value pairs."""
        await storage.set("key1", "value1")
        await storage.set("key2", "value2")

        result = await storage.get_all()
        assert result == {"key1": "value1", "key2": "value2"}

    @pytest.mark.asyncio
    async def test_clear(self, storage: InMemoryStorage) -> None:
        """Test clearing all keys."""
        await storage.set("key1", "value1")
        await storage.set("key2", "value2")

        count = await storage.clear()
        assert count == 2
        assert await storage.get_all() == {}


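# --- Illustrative sketch (not part of the shipped code) -----------------------
# The "key*" pattern used in test_list_keys_with_pattern is a shell-style glob
# (the same style Redis MATCH patterns use). For the dict-backed store, one
# plausible implementation is fnmatch over the key set, as sketched here; the
# helper name is hypothetical.
import fnmatch


def _sketch_list_keys(data: dict[str, object], pattern: str | None = None) -> list[str]:
    """Illustration: glob-style filtering of in-memory keys."""
    if pattern is None:
        return list(data)
    return [key for key in data if fnmatch.fnmatch(key, pattern)]


# _sketch_list_keys({"key1": 1, "key2": 2, "other": 3}, "key*") -> ["key1", "key2"]

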
class TestInMemoryStorageCapacity:
    """Tests for capacity limits."""

    @pytest.mark.asyncio
    async def test_capacity_limit_exceeded(self) -> None:
        """Test that exceeding capacity raises an error."""
        storage = InMemoryStorage(max_keys=2)

        await storage.set("key1", "value1")
        await storage.set("key2", "value2")

        with pytest.raises(MemoryStorageError, match="capacity exceeded"):
            await storage.set("key3", "value3")

    @pytest.mark.asyncio
    async def test_update_existing_key_within_capacity(self) -> None:
        """Test that updating an existing key doesn't count against capacity."""
        storage = InMemoryStorage(max_keys=2)

        await storage.set("key1", "value1")
        await storage.set("key2", "value2")
        await storage.set("key1", "updated")  # Should succeed

        assert await storage.get("key1") == "updated"

    @pytest.mark.asyncio
    async def test_expired_keys_freed_for_capacity(self) -> None:
        """Test that expired keys are cleaned up to free capacity."""
        storage = InMemoryStorage(max_keys=2)

        await storage.set("key1", "value1", ttl_seconds=1)
        await storage.set("key2", "value2")

        await asyncio.sleep(1.1)

        # Should succeed because key1 is expired and will be cleaned up
        await storage.set("key3", "value3")
        assert await storage.get("key3") == "value3"


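# --- Illustrative sketch (not part of the shipped code) -----------------------
# Taken together, the capacity tests imply an ordering inside set(): purge
# expired entries first, let updates to existing keys through, and only then
# enforce max_keys for brand-new keys. The check below sketches that logic; the
# exact error message is an assumption (the tests only require a
# MemoryStorageError whose message matches "capacity exceeded").
def _sketch_capacity_check(data: dict[str, object], key: str, max_keys: int) -> None:
    """Illustration: reject new keys once max_keys live entries exist."""
    if key in data:
        # Overwriting an existing key never counts against capacity.
        return
    if len(data) >= max_keys:
        raise MemoryStorageError(f"Working memory capacity exceeded ({max_keys} keys)")

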
class TestInMemoryStorageDataTypes:
    """Tests for different data types."""

    @pytest.fixture
    def storage(self) -> InMemoryStorage:
        """Create a fresh storage instance."""
        return InMemoryStorage(max_keys=100)

    @pytest.mark.asyncio
    async def test_store_dict(self, storage: InMemoryStorage) -> None:
        """Test storing dict values."""
        data = {"nested": {"key": "value"}, "list": [1, 2, 3]}
        await storage.set("dict_key", data)
        result = await storage.get("dict_key")
        assert result == data

    @pytest.mark.asyncio
    async def test_store_list(self, storage: InMemoryStorage) -> None:
        """Test storing list values."""
        data = [1, 2, {"nested": "dict"}]
        await storage.set("list_key", data)
        result = await storage.get("list_key")
        assert result == data

    @pytest.mark.asyncio
    async def test_store_numbers(self, storage: InMemoryStorage) -> None:
        """Test storing numeric values."""
        await storage.set("int_key", 42)
        await storage.set("float_key", 3.14)

        assert await storage.get("int_key") == 42
        assert await storage.get("float_key") == 3.14

    @pytest.mark.asyncio
    async def test_store_boolean(self, storage: InMemoryStorage) -> None:
        """Test storing boolean values."""
        await storage.set("true_key", True)
        await storage.set("false_key", False)

        assert await storage.get("true_key") is True
        assert await storage.get("false_key") is False

    @pytest.mark.asyncio
    async def test_store_none(self, storage: InMemoryStorage) -> None:
        """Test storing None value."""
        await storage.set("none_key", None)
        # Note: None is stored, but get returns None for both missing and None values.
        # Use exists to distinguish.
        assert await storage.exists("none_key") is True


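# --- Illustrative sketch (not part of the shipped code) -----------------------
# The dict-backed store passes these tests trivially because values are kept as
# Python objects. A backend that serializes values to JSON (e.g. one backed by
# Redis) would round-trip the same types but with the usual JSON caveats; the
# helper below is only an assumption used to illustrate that boundary.
import json


def _sketch_json_roundtrip(value: object) -> object:
    """Illustration: show what a JSON-serializing backend would preserve.

    Dicts, lists, numbers, booleans and None survive the round trip; tuples
    come back as lists and custom objects need explicit handling.
    """
    return json.loads(json.dumps(value))

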
class TestInMemoryStorageHealth:
    """Tests for health and lifecycle."""

    @pytest.mark.asyncio
    async def test_is_healthy(self) -> None:
        """Test health check."""
        storage = InMemoryStorage()
        assert await storage.is_healthy() is True

    @pytest.mark.asyncio
    async def test_close(self) -> None:
        """Test that close is a no-op and doesn't raise."""
        storage = InMemoryStorage()
        await storage.close()  # Should not raise


class TestInMemoryStorageConcurrency:
    """Tests for concurrent access."""

    @pytest.mark.asyncio
    async def test_concurrent_writes(self) -> None:
        """Test concurrent write operations don't corrupt data."""
        storage = InMemoryStorage(max_keys=1000)

        async def write_batch(prefix: str, count: int) -> None:
            for i in range(count):
                await storage.set(f"{prefix}_{i}", f"value_{i}")

        # Run concurrent writes
        await asyncio.gather(
            write_batch("a", 100),
            write_batch("b", 100),
            write_batch("c", 100),
        )

        # Verify all writes succeeded
        keys = await storage.list_keys()
        assert len(keys) == 300

    @pytest.mark.asyncio
    async def test_concurrent_read_write(self) -> None:
        """Test concurrent read and write operations."""
        storage = InMemoryStorage()
        await storage.set("key", 0)

        async def increment() -> None:
            for _ in range(100):
                val = await storage.get("key") or 0
                await storage.set("key", val + 1)

        # Run concurrent increments
        await asyncio.gather(
            increment(),
            increment(),
        )

        # Final value depends on interleaving
        # Just verify we don't crash and value is positive
        result = await storage.get("key")
        assert result > 0


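# --- Illustrative sketch (not part of the shipped code) -----------------------
# test_concurrent_read_write deliberately tolerates lost updates: each storage
# call is protected on its own, so the get -> add -> set sequence is not atomic.
# If an atomic counter were ever needed, the whole read-modify-write would have
# to be serialized by one lock, roughly as below. The helper and its
# module-level lock are hypothetical and use only the public get/set API.
import asyncio

_sketch_counter_lock = asyncio.Lock()


async def _sketch_atomic_increment(storage: InMemoryStorage, key: str) -> int:
    """Illustration: hold a single lock across the read and the write."""
    async with _sketch_counter_lock:
        current = await storage.get(key) or 0
        await storage.set(key, current + 1)
        return current + 1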