feat(memory): add working memory implementation (Issue #89)

Implements session-scoped ephemeral memory with:

Storage Backends:
- InMemoryStorage: Thread-safe fallback with TTL support and capacity limits
- RedisStorage: Primary storage with connection pooling and JSON serialization
- Auto-fallback from Redis to in-memory when unavailable

WorkingMemory Class:
- Key-value storage with TTL and reserved key protection
- Task state tracking with progress updates
- Scratchpad for reasoning steps with timestamps
- Checkpoint/snapshot support for recovery
- Factory methods for auto-configured storage

Tests:
- 55 unit tests covering all functionality
- Tests for basic ops, TTL, capacity, concurrency
- Tests for task state, scratchpad, checkpoints

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
2026-01-05 01:51:03 +01:00
parent c9d8c0835c
commit 4974233169
7 changed files with 1673 additions and 4 deletions

View File

@@ -94,6 +94,22 @@ class MemoryStorageError(MemoryError):
self.backend = backend
class MemoryConnectionError(MemoryError):
"""Raised when memory storage connection fails."""
def __init__(
self,
message: str = "Memory connection failed",
*,
backend: str | None = None,
host: str | None = None,
**kwargs: Any,
) -> None:
super().__init__(message, **kwargs)
self.backend = backend
self.host = host
class MemorySerializationError(MemoryError):
"""Raised when memory serialization/deserialization fails."""

View File

@@ -1,8 +1,16 @@
# app/services/memory/working/__init__.py
"""
Working Memory
Working Memory Implementation.
Session-scoped ephemeral memory for current task state,
variables, and scratchpad.
Provides short-term memory storage with Redis primary and in-memory fallback.
"""
# Will be populated in #89
from .memory import WorkingMemory
from .storage import InMemoryStorage, RedisStorage, WorkingMemoryStorage
__all__ = [
"InMemoryStorage",
"RedisStorage",
"WorkingMemory",
"WorkingMemoryStorage",
]

View File

@@ -0,0 +1,543 @@
# app/services/memory/working/memory.py
"""
Working Memory Implementation.
Provides session-scoped ephemeral memory with:
- Key-value storage with TTL
- Task state tracking
- Scratchpad for reasoning steps
- Checkpoint/snapshot support
"""
import logging
import uuid
from dataclasses import asdict
from datetime import UTC, datetime
from typing import Any
from app.services.memory.config import get_memory_settings
from app.services.memory.exceptions import (
MemoryConnectionError,
MemoryNotFoundError,
)
from app.services.memory.types import ScopeContext, ScopeLevel, TaskState
from .storage import InMemoryStorage, RedisStorage, WorkingMemoryStorage
logger = logging.getLogger(__name__)
# Reserved key prefixes for internal use
_TASK_STATE_KEY = "_task_state"
_SCRATCHPAD_KEY = "_scratchpad"
_CHECKPOINT_PREFIX = "_checkpoint:"
_METADATA_KEY = "_metadata"
class WorkingMemory:
"""
Session-scoped working memory.
Provides ephemeral storage for agent's current task context:
- Variables and intermediate data
- Task state (current step, status, progress)
- Scratchpad for reasoning steps
- Checkpoints for recovery
Uses Redis as primary storage with in-memory fallback.
"""
def __init__(
self,
scope: ScopeContext,
storage: WorkingMemoryStorage,
default_ttl_seconds: int | None = None,
) -> None:
"""
Initialize working memory for a scope.
Args:
scope: The scope context (session, agent instance, etc.)
storage: Storage backend (use create() factory for auto-configuration)
default_ttl_seconds: Default TTL for keys (None = no expiration)
"""
self._scope = scope
self._storage: WorkingMemoryStorage = storage
self._default_ttl = default_ttl_seconds
self._using_fallback = False
self._initialized = False
@classmethod
async def create(
cls,
scope: ScopeContext,
default_ttl_seconds: int | None = None,
) -> "WorkingMemory":
"""
Factory method to create WorkingMemory with auto-configured storage.
Attempts Redis first, falls back to in-memory if unavailable.
"""
settings = get_memory_settings()
key_prefix = f"wm:{scope.to_key_prefix()}:"
storage: WorkingMemoryStorage
# Try Redis first
if settings.working_memory_backend == "redis":
redis_storage = RedisStorage(key_prefix=key_prefix)
try:
if await redis_storage.is_healthy():
logger.debug(f"Using Redis storage for scope {scope.scope_id}")
instance = cls(
scope=scope,
storage=redis_storage,
default_ttl_seconds=default_ttl_seconds
or settings.working_memory_default_ttl_seconds,
)
await instance._initialize()
return instance
except MemoryConnectionError:
logger.warning("Redis unavailable, falling back to in-memory storage")
await redis_storage.close()
# Fall back to in-memory
storage = InMemoryStorage(
max_keys=settings.working_memory_max_items_per_session
)
instance = cls(
scope=scope,
storage=storage,
default_ttl_seconds=default_ttl_seconds
or settings.working_memory_default_ttl_seconds,
)
instance._using_fallback = True
await instance._initialize()
return instance
@classmethod
async def for_session(
cls,
session_id: str,
project_id: str | None = None,
agent_instance_id: str | None = None,
) -> "WorkingMemory":
"""
Convenience factory for session-scoped working memory.
Args:
session_id: Unique session identifier
project_id: Optional project context
agent_instance_id: Optional agent instance context
"""
# Build scope hierarchy
parent = None
if project_id:
parent = ScopeContext(
scope_type=ScopeLevel.PROJECT,
scope_id=project_id,
)
if agent_instance_id:
parent = ScopeContext(
scope_type=ScopeLevel.AGENT_INSTANCE,
scope_id=agent_instance_id,
parent=parent,
)
scope = ScopeContext(
scope_type=ScopeLevel.SESSION,
scope_id=session_id,
parent=parent,
)
return await cls.create(scope=scope)
async def _initialize(self) -> None:
"""Initialize working memory metadata."""
if self._initialized:
return
metadata = {
"scope_type": self._scope.scope_type.value,
"scope_id": self._scope.scope_id,
"created_at": datetime.now(UTC).isoformat(),
"using_fallback": self._using_fallback,
}
await self._storage.set(_METADATA_KEY, metadata)
self._initialized = True
@property
def scope(self) -> ScopeContext:
"""Get the scope context."""
return self._scope
@property
def is_using_fallback(self) -> bool:
"""Check if using fallback in-memory storage."""
return self._using_fallback
# =========================================================================
# Basic Key-Value Operations
# =========================================================================
async def set(
self,
key: str,
value: Any,
ttl_seconds: int | None = None,
) -> None:
"""
Store a value.
Args:
key: The key to store under
value: The value to store (must be JSON-serializable)
ttl_seconds: Optional TTL (uses default if not specified)
"""
if key.startswith("_"):
raise ValueError("Keys starting with '_' are reserved for internal use")
ttl = ttl_seconds if ttl_seconds is not None else self._default_ttl
await self._storage.set(key, value, ttl)
async def get(self, key: str, default: Any = None) -> Any:
"""
Get a value.
Args:
key: The key to retrieve
default: Default value if key not found
Returns:
The stored value or default
"""
result = await self._storage.get(key)
return result if result is not None else default
async def delete(self, key: str) -> bool:
"""
Delete a key.
Args:
key: The key to delete
Returns:
True if the key existed
"""
if key.startswith("_"):
raise ValueError("Cannot delete internal keys directly")
return await self._storage.delete(key)
async def exists(self, key: str) -> bool:
"""
Check if a key exists.
Args:
key: The key to check
Returns:
True if the key exists
"""
return await self._storage.exists(key)
async def list_keys(self, pattern: str = "*") -> list[str]:
"""
List keys matching a pattern.
Args:
pattern: Glob-style pattern (default "*" for all)
Returns:
List of matching keys (excludes internal keys)
"""
all_keys = await self._storage.list_keys(pattern)
return [k for k in all_keys if not k.startswith("_")]
async def get_all(self) -> dict[str, Any]:
"""
Get all user key-value pairs.
Returns:
Dictionary of all key-value pairs (excludes internal keys)
"""
all_data = await self._storage.get_all()
return {k: v for k, v in all_data.items() if not k.startswith("_")}
async def clear(self) -> int:
"""
Clear all user keys (preserves internal state).
Returns:
Number of keys deleted
"""
# Save internal state
task_state = await self._storage.get(_TASK_STATE_KEY)
scratchpad = await self._storage.get(_SCRATCHPAD_KEY)
metadata = await self._storage.get(_METADATA_KEY)
count = await self._storage.clear()
# Restore internal state
if metadata is not None:
await self._storage.set(_METADATA_KEY, metadata)
if task_state is not None:
await self._storage.set(_TASK_STATE_KEY, task_state)
if scratchpad is not None:
await self._storage.set(_SCRATCHPAD_KEY, scratchpad)
# Adjust count for preserved keys
preserved = sum(1 for x in [task_state, scratchpad, metadata] if x is not None)
return max(0, count - preserved)
# =========================================================================
# Task State Operations
# =========================================================================
async def set_task_state(self, state: TaskState) -> None:
"""
Set the current task state.
Args:
state: The task state to store
"""
state.updated_at = datetime.now(UTC)
await self._storage.set(_TASK_STATE_KEY, asdict(state))
async def get_task_state(self) -> TaskState | None:
"""
Get the current task state.
Returns:
The current TaskState or None if not set
"""
data = await self._storage.get(_TASK_STATE_KEY)
if data is None:
return None
# Convert datetime strings back to datetime objects
if isinstance(data.get("started_at"), str):
data["started_at"] = datetime.fromisoformat(data["started_at"])
if isinstance(data.get("updated_at"), str):
data["updated_at"] = datetime.fromisoformat(data["updated_at"])
return TaskState(**data)
async def update_task_progress(
self,
current_step: int | None = None,
progress_percent: float | None = None,
status: str | None = None,
) -> TaskState | None:
"""
Update task progress fields.
Args:
current_step: New current step number
progress_percent: New progress percentage (0.0 to 100.0)
status: New status string
Returns:
Updated TaskState or None if no task state exists
"""
state = await self.get_task_state()
if state is None:
return None
if current_step is not None:
state.current_step = current_step
if progress_percent is not None:
state.progress_percent = min(100.0, max(0.0, progress_percent))
if status is not None:
state.status = status
await self.set_task_state(state)
return state
# =========================================================================
# Scratchpad Operations
# =========================================================================
async def append_scratchpad(self, content: str) -> None:
"""
Append content to the scratchpad.
Args:
content: Text to append
"""
settings = get_memory_settings()
entries = await self._storage.get(_SCRATCHPAD_KEY) or []
# Check capacity
if len(entries) >= settings.working_memory_max_items_per_session:
# Remove oldest entries
entries = entries[-(settings.working_memory_max_items_per_session - 1) :]
entry = {
"content": content,
"timestamp": datetime.now(UTC).isoformat(),
}
entries.append(entry)
await self._storage.set(_SCRATCHPAD_KEY, entries)
async def get_scratchpad(self) -> list[str]:
"""
Get all scratchpad entries.
Returns:
List of scratchpad content strings (ordered by time)
"""
entries = await self._storage.get(_SCRATCHPAD_KEY) or []
return [e["content"] for e in entries]
async def get_scratchpad_with_timestamps(self) -> list[dict[str, Any]]:
"""
Get all scratchpad entries with timestamps.
Returns:
List of dicts with 'content' and 'timestamp' keys
"""
return await self._storage.get(_SCRATCHPAD_KEY) or []
async def clear_scratchpad(self) -> int:
"""
Clear the scratchpad.
Returns:
Number of entries cleared
"""
entries = await self._storage.get(_SCRATCHPAD_KEY) or []
count = len(entries)
await self._storage.set(_SCRATCHPAD_KEY, [])
return count
# =========================================================================
# Checkpoint Operations
# =========================================================================
async def create_checkpoint(self, description: str = "") -> str:
"""
Create a checkpoint of current state.
Args:
description: Optional description of the checkpoint
Returns:
Checkpoint ID for later restoration
"""
checkpoint_id = str(uuid.uuid4())[:8]
checkpoint_key = f"{_CHECKPOINT_PREFIX}{checkpoint_id}"
# Capture all current state
all_data = await self._storage.get_all()
checkpoint = {
"id": checkpoint_id,
"description": description,
"created_at": datetime.now(UTC).isoformat(),
"data": all_data,
}
await self._storage.set(checkpoint_key, checkpoint)
logger.debug(f"Created checkpoint {checkpoint_id}")
return checkpoint_id
async def restore_checkpoint(self, checkpoint_id: str) -> None:
"""
Restore state from a checkpoint.
Args:
checkpoint_id: ID of the checkpoint to restore
Raises:
MemoryNotFoundError: If checkpoint not found
"""
checkpoint_key = f"{_CHECKPOINT_PREFIX}{checkpoint_id}"
checkpoint = await self._storage.get(checkpoint_key)
if checkpoint is None:
raise MemoryNotFoundError(f"Checkpoint {checkpoint_id} not found")
# Clear current state
await self._storage.clear()
# Restore all data from checkpoint
for key, value in checkpoint["data"].items():
await self._storage.set(key, value)
# Keep the checkpoint itself
await self._storage.set(checkpoint_key, checkpoint)
logger.debug(f"Restored checkpoint {checkpoint_id}")
async def list_checkpoints(self) -> list[dict[str, Any]]:
"""
List all available checkpoints.
Returns:
List of checkpoint metadata (id, description, created_at)
"""
checkpoint_keys = await self._storage.list_keys(f"{_CHECKPOINT_PREFIX}*")
checkpoints = []
for key in checkpoint_keys:
data = await self._storage.get(key)
if data:
checkpoints.append(
{
"id": data["id"],
"description": data["description"],
"created_at": data["created_at"],
}
)
# Sort by creation time
checkpoints.sort(key=lambda x: x["created_at"])
return checkpoints
async def delete_checkpoint(self, checkpoint_id: str) -> bool:
"""
Delete a checkpoint.
Args:
checkpoint_id: ID of the checkpoint to delete
Returns:
True if checkpoint existed
"""
checkpoint_key = f"{_CHECKPOINT_PREFIX}{checkpoint_id}"
return await self._storage.delete(checkpoint_key)
# =========================================================================
# Health and Lifecycle
# =========================================================================
async def is_healthy(self) -> bool:
"""Check if the working memory storage is healthy."""
return await self._storage.is_healthy()
async def close(self) -> None:
"""Close the working memory storage."""
if self._storage:
await self._storage.close()
async def get_stats(self) -> dict[str, Any]:
"""
Get working memory statistics.
Returns:
Dictionary with stats about current state
"""
all_keys = await self._storage.list_keys("*")
user_keys = [k for k in all_keys if not k.startswith("_")]
checkpoint_keys = [k for k in all_keys if k.startswith(_CHECKPOINT_PREFIX)]
scratchpad = await self._storage.get(_SCRATCHPAD_KEY) or []
return {
"scope_type": self._scope.scope_type.value,
"scope_id": self._scope.scope_id,
"using_fallback": self._using_fallback,
"total_keys": len(all_keys),
"user_keys": len(user_keys),
"checkpoint_count": len(checkpoint_keys),
"scratchpad_entries": len(scratchpad),
"has_task_state": await self._storage.exists(_TASK_STATE_KEY),
}

View File

@@ -0,0 +1,406 @@
# app/services/memory/working/storage.py
"""
Working Memory Storage Backends.
Provides abstract storage interface and implementations:
- RedisStorage: Primary storage using Redis with connection pooling
- InMemoryStorage: Fallback storage when Redis is unavailable
"""
import asyncio
import fnmatch
import json
import logging
from abc import ABC, abstractmethod
from datetime import UTC, datetime, timedelta
from typing import Any
from app.services.memory.config import get_memory_settings
from app.services.memory.exceptions import (
MemoryConnectionError,
MemoryStorageError,
)
logger = logging.getLogger(__name__)
class WorkingMemoryStorage(ABC):
"""Abstract base class for working memory storage backends."""
@abstractmethod
async def set(
self,
key: str,
value: Any,
ttl_seconds: int | None = None,
) -> None:
"""Store a value with optional TTL."""
...
@abstractmethod
async def get(self, key: str) -> Any | None:
"""Get a value by key, returns None if not found or expired."""
...
@abstractmethod
async def delete(self, key: str) -> bool:
"""Delete a key, returns True if existed."""
...
@abstractmethod
async def exists(self, key: str) -> bool:
"""Check if a key exists and is not expired."""
...
@abstractmethod
async def list_keys(self, pattern: str = "*") -> list[str]:
"""List all keys matching a pattern."""
...
@abstractmethod
async def get_all(self) -> dict[str, Any]:
"""Get all key-value pairs."""
...
@abstractmethod
async def clear(self) -> int:
"""Clear all keys, returns count of deleted keys."""
...
@abstractmethod
async def is_healthy(self) -> bool:
"""Check if the storage backend is healthy."""
...
@abstractmethod
async def close(self) -> None:
"""Close the storage connection."""
...
class InMemoryStorage(WorkingMemoryStorage):
"""
In-memory storage backend for working memory.
Used as fallback when Redis is unavailable. Data is not persisted
across restarts and is not shared between processes.
"""
def __init__(self, max_keys: int = 10000) -> None:
"""Initialize in-memory storage."""
self._data: dict[str, Any] = {}
self._expirations: dict[str, datetime] = {}
self._max_keys = max_keys
self._lock = asyncio.Lock()
def _is_expired(self, key: str) -> bool:
"""Check if a key has expired."""
if key not in self._expirations:
return False
return datetime.now(UTC) > self._expirations[key]
def _cleanup_expired(self) -> None:
"""Remove all expired keys."""
now = datetime.now(UTC)
expired_keys = [
key for key, exp_time in self._expirations.items() if now > exp_time
]
for key in expired_keys:
self._data.pop(key, None)
self._expirations.pop(key, None)
async def set(
self,
key: str,
value: Any,
ttl_seconds: int | None = None,
) -> None:
"""Store a value with optional TTL."""
async with self._lock:
# Cleanup expired keys periodically
if len(self._data) % 100 == 0:
self._cleanup_expired()
# Check capacity
if key not in self._data and len(self._data) >= self._max_keys:
# Evict expired keys first
self._cleanup_expired()
if len(self._data) >= self._max_keys:
raise MemoryStorageError(
f"Working memory capacity exceeded: {self._max_keys} keys"
)
self._data[key] = value
if ttl_seconds is not None:
self._expirations[key] = datetime.now(UTC) + timedelta(
seconds=ttl_seconds
)
elif key in self._expirations:
# Remove existing expiration if no TTL specified
del self._expirations[key]
async def get(self, key: str) -> Any | None:
"""Get a value by key."""
async with self._lock:
if key not in self._data:
return None
if self._is_expired(key):
del self._data[key]
del self._expirations[key]
return None
return self._data[key]
async def delete(self, key: str) -> bool:
"""Delete a key."""
async with self._lock:
existed = key in self._data
self._data.pop(key, None)
self._expirations.pop(key, None)
return existed
async def exists(self, key: str) -> bool:
"""Check if a key exists and is not expired."""
async with self._lock:
if key not in self._data:
return False
if self._is_expired(key):
del self._data[key]
del self._expirations[key]
return False
return True
async def list_keys(self, pattern: str = "*") -> list[str]:
"""List all keys matching a pattern."""
async with self._lock:
self._cleanup_expired()
if pattern == "*":
return list(self._data.keys())
return [key for key in self._data.keys() if fnmatch.fnmatch(key, pattern)]
async def get_all(self) -> dict[str, Any]:
"""Get all key-value pairs."""
async with self._lock:
self._cleanup_expired()
return dict(self._data)
async def clear(self) -> int:
"""Clear all keys."""
async with self._lock:
count = len(self._data)
self._data.clear()
self._expirations.clear()
return count
async def is_healthy(self) -> bool:
"""In-memory storage is always healthy."""
return True
async def close(self) -> None:
"""No cleanup needed for in-memory storage."""
class RedisStorage(WorkingMemoryStorage):
"""
Redis storage backend for working memory.
Primary storage with connection pooling, automatic reconnection,
and proper serialization of Python objects.
"""
def __init__(
self,
key_prefix: str = "",
connection_timeout: float = 5.0,
socket_timeout: float = 5.0,
) -> None:
"""
Initialize Redis storage.
Args:
key_prefix: Prefix for all keys (e.g., "session:abc123:")
connection_timeout: Timeout for establishing connections
socket_timeout: Timeout for socket operations
"""
self._key_prefix = key_prefix
self._connection_timeout = connection_timeout
self._socket_timeout = socket_timeout
self._redis: Any = None
self._lock = asyncio.Lock()
def _make_key(self, key: str) -> str:
"""Add prefix to key."""
return f"{self._key_prefix}{key}"
def _strip_prefix(self, key: str) -> str:
"""Remove prefix from key."""
if key.startswith(self._key_prefix):
return key[len(self._key_prefix) :]
return key
def _serialize(self, value: Any) -> str:
"""Serialize a Python value to JSON string."""
return json.dumps(value, default=str)
def _deserialize(self, data: str | bytes | None) -> Any | None:
"""Deserialize a JSON string to Python value."""
if data is None:
return None
if isinstance(data, bytes):
data = data.decode("utf-8")
return json.loads(data)
async def _get_client(self) -> Any:
"""Get or create Redis client."""
if self._redis is not None:
return self._redis
async with self._lock:
if self._redis is not None:
return self._redis
try:
import redis.asyncio as aioredis
except ImportError as e:
raise MemoryConnectionError(
"redis package not installed. Install with: pip install redis"
) from e
settings = get_memory_settings()
redis_url = settings.redis_url
try:
self._redis = await aioredis.from_url(
redis_url,
encoding="utf-8",
decode_responses=True,
socket_connect_timeout=self._connection_timeout,
socket_timeout=self._socket_timeout,
)
# Test connection
await self._redis.ping()
logger.info("Connected to Redis for working memory")
return self._redis
except Exception as e:
self._redis = None
raise MemoryConnectionError(f"Failed to connect to Redis: {e}") from e
async def set(
self,
key: str,
value: Any,
ttl_seconds: int | None = None,
) -> None:
"""Store a value with optional TTL."""
try:
client = await self._get_client()
full_key = self._make_key(key)
serialized = self._serialize(value)
if ttl_seconds is not None:
await client.setex(full_key, ttl_seconds, serialized)
else:
await client.set(full_key, serialized)
except MemoryConnectionError:
raise
except Exception as e:
raise MemoryStorageError(f"Failed to set key {key}: {e}") from e
async def get(self, key: str) -> Any | None:
"""Get a value by key."""
try:
client = await self._get_client()
full_key = self._make_key(key)
data = await client.get(full_key)
return self._deserialize(data)
except MemoryConnectionError:
raise
except Exception as e:
raise MemoryStorageError(f"Failed to get key {key}: {e}") from e
async def delete(self, key: str) -> bool:
"""Delete a key."""
try:
client = await self._get_client()
full_key = self._make_key(key)
result = await client.delete(full_key)
return bool(result)
except MemoryConnectionError:
raise
except Exception as e:
raise MemoryStorageError(f"Failed to delete key {key}: {e}") from e
async def exists(self, key: str) -> bool:
"""Check if a key exists."""
try:
client = await self._get_client()
full_key = self._make_key(key)
result = await client.exists(full_key)
return bool(result)
except MemoryConnectionError:
raise
except Exception as e:
raise MemoryStorageError(f"Failed to check key {key}: {e}") from e
async def list_keys(self, pattern: str = "*") -> list[str]:
"""List all keys matching a pattern."""
try:
client = await self._get_client()
full_pattern = self._make_key(pattern)
keys = await client.keys(full_pattern)
return [self._strip_prefix(key) for key in keys]
except MemoryConnectionError:
raise
except Exception as e:
raise MemoryStorageError(f"Failed to list keys: {e}") from e
async def get_all(self) -> dict[str, Any]:
"""Get all key-value pairs."""
try:
client = await self._get_client()
full_pattern = self._make_key("*")
keys = await client.keys(full_pattern)
if not keys:
return {}
values = await client.mget(*keys)
result = {}
for key, value in zip(keys, values, strict=False):
stripped_key = self._strip_prefix(key)
result[stripped_key] = self._deserialize(value)
return result
except MemoryConnectionError:
raise
except Exception as e:
raise MemoryStorageError(f"Failed to get all keys: {e}") from e
async def clear(self) -> int:
"""Clear all keys with this prefix."""
try:
client = await self._get_client()
full_pattern = self._make_key("*")
keys = await client.keys(full_pattern)
if not keys:
return 0
return await client.delete(*keys)
except MemoryConnectionError:
raise
except Exception as e:
raise MemoryStorageError(f"Failed to clear keys: {e}") from e
async def is_healthy(self) -> bool:
"""Check if Redis connection is healthy."""
try:
client = await self._get_client()
await client.ping()
return True
except Exception:
return False
async def close(self) -> None:
"""Close the Redis connection."""
if self._redis is not None:
await self._redis.close()
self._redis = None

View File

@@ -0,0 +1,2 @@
# tests/unit/services/memory/working/__init__.py
"""Unit tests for working memory implementation."""

View File

@@ -0,0 +1,391 @@
# tests/unit/services/memory/working/test_memory.py
"""Unit tests for WorkingMemory class."""
import pytest
import pytest_asyncio
from app.services.memory.exceptions import MemoryNotFoundError
from app.services.memory.types import ScopeContext, ScopeLevel, TaskState
from app.services.memory.working.memory import WorkingMemory
from app.services.memory.working.storage import InMemoryStorage
@pytest.fixture
def scope() -> ScopeContext:
"""Create a test scope."""
return ScopeContext(
scope_type=ScopeLevel.SESSION,
scope_id="test-session-123",
)
@pytest.fixture
def storage() -> InMemoryStorage:
"""Create a test storage backend."""
return InMemoryStorage(max_keys=1000)
@pytest_asyncio.fixture
async def memory(scope: ScopeContext, storage: InMemoryStorage) -> WorkingMemory:
"""Create a WorkingMemory instance for testing."""
wm = WorkingMemory(scope=scope, storage=storage)
await wm._initialize()
return wm
class TestWorkingMemoryBasicOperations:
"""Tests for basic key-value operations."""
@pytest.mark.asyncio
async def test_set_and_get(self, memory: WorkingMemory) -> None:
"""Test basic set and get."""
await memory.set("key1", "value1")
result = await memory.get("key1")
assert result == "value1"
@pytest.mark.asyncio
async def test_get_with_default(self, memory: WorkingMemory) -> None:
"""Test get with default value."""
result = await memory.get("nonexistent", default="fallback")
assert result == "fallback"
@pytest.mark.asyncio
async def test_delete(self, memory: WorkingMemory) -> None:
"""Test delete operation."""
await memory.set("key1", "value1")
result = await memory.delete("key1")
assert result is True
assert await memory.exists("key1") is False
@pytest.mark.asyncio
async def test_exists(self, memory: WorkingMemory) -> None:
"""Test exists check."""
await memory.set("key1", "value1")
assert await memory.exists("key1") is True
assert await memory.exists("nonexistent") is False
@pytest.mark.asyncio
async def test_reserved_key_prefix(self, memory: WorkingMemory) -> None:
"""Test that keys starting with _ are rejected."""
with pytest.raises(ValueError, match="reserved"):
await memory.set("_internal", "value")
@pytest.mark.asyncio
async def test_cannot_delete_internal_keys(self, memory: WorkingMemory) -> None:
"""Test that internal keys cannot be deleted directly."""
with pytest.raises(ValueError, match="internal"):
await memory.delete("_task_state")
class TestWorkingMemoryListAndClear:
"""Tests for list and clear operations."""
@pytest.mark.asyncio
async def test_list_keys(self, memory: WorkingMemory) -> None:
"""Test listing keys."""
await memory.set("key1", "value1")
await memory.set("key2", "value2")
keys = await memory.list_keys()
assert set(keys) == {"key1", "key2"}
@pytest.mark.asyncio
async def test_list_keys_excludes_internal(self, memory: WorkingMemory) -> None:
"""Test that list_keys excludes internal keys."""
await memory.set("user_key", "value")
# Internal keys exist from initialization
keys = await memory.list_keys()
assert all(not k.startswith("_") for k in keys)
@pytest.mark.asyncio
async def test_list_keys_with_pattern(self, memory: WorkingMemory) -> None:
"""Test listing keys with pattern."""
await memory.set("prefix_a", "value1")
await memory.set("prefix_b", "value2")
await memory.set("other", "value3")
keys = await memory.list_keys("prefix_*")
assert set(keys) == {"prefix_a", "prefix_b"}
@pytest.mark.asyncio
async def test_get_all(self, memory: WorkingMemory) -> None:
"""Test getting all key-value pairs."""
await memory.set("key1", "value1")
await memory.set("key2", "value2")
result = await memory.get_all()
assert result == {"key1": "value1", "key2": "value2"}
@pytest.mark.asyncio
async def test_clear_preserves_internal_state(self, memory: WorkingMemory) -> None:
"""Test that clear preserves internal state."""
# Set some user data
await memory.set("user_key", "value")
# Set task state
state = TaskState(
task_id="task-1",
task_type="test",
description="Test task",
)
await memory.set_task_state(state)
# Clear
await memory.clear()
# User data should be gone
assert await memory.exists("user_key") is False
# Task state should be preserved
restored_state = await memory.get_task_state()
assert restored_state is not None
assert restored_state.task_id == "task-1"
class TestWorkingMemoryTaskState:
"""Tests for task state operations."""
@pytest.mark.asyncio
async def test_set_and_get_task_state(self, memory: WorkingMemory) -> None:
"""Test setting and getting task state."""
state = TaskState(
task_id="task-123",
task_type="code_review",
description="Review pull request",
status="in_progress",
current_step=2,
total_steps=5,
progress_percent=40.0,
context={"pr_id": 456},
)
await memory.set_task_state(state)
result = await memory.get_task_state()
assert result is not None
assert result.task_id == "task-123"
assert result.task_type == "code_review"
assert result.status == "in_progress"
assert result.current_step == 2
assert result.progress_percent == 40.0
assert result.context == {"pr_id": 456}
@pytest.mark.asyncio
async def test_get_task_state_none_when_not_set(
self, memory: WorkingMemory
) -> None:
"""Test that get_task_state returns None when not set."""
result = await memory.get_task_state()
assert result is None
@pytest.mark.asyncio
async def test_update_task_progress(self, memory: WorkingMemory) -> None:
"""Test updating task progress."""
state = TaskState(
task_id="task-123",
task_type="test",
description="Test",
current_step=1,
progress_percent=10.0,
status="running",
)
await memory.set_task_state(state)
updated = await memory.update_task_progress(
current_step=3,
progress_percent=60.0,
status="processing",
)
assert updated is not None
assert updated.current_step == 3
assert updated.progress_percent == 60.0
assert updated.status == "processing"
@pytest.mark.asyncio
async def test_update_task_progress_clamps_percent(
self, memory: WorkingMemory
) -> None:
"""Test that progress percent is clamped to 0-100."""
state = TaskState(
task_id="task-123",
task_type="test",
description="Test",
)
await memory.set_task_state(state)
updated = await memory.update_task_progress(progress_percent=150.0)
assert updated is not None
assert updated.progress_percent == 100.0
updated = await memory.update_task_progress(progress_percent=-10.0)
assert updated is not None
assert updated.progress_percent == 0.0
class TestWorkingMemoryScratchpad:
"""Tests for scratchpad operations."""
@pytest.mark.asyncio
async def test_append_and_get_scratchpad(self, memory: WorkingMemory) -> None:
"""Test appending to and getting scratchpad."""
await memory.append_scratchpad("First note")
await memory.append_scratchpad("Second note")
entries = await memory.get_scratchpad()
assert entries == ["First note", "Second note"]
@pytest.mark.asyncio
async def test_get_scratchpad_empty(self, memory: WorkingMemory) -> None:
"""Test getting empty scratchpad."""
entries = await memory.get_scratchpad()
assert entries == []
@pytest.mark.asyncio
async def test_get_scratchpad_with_timestamps(self, memory: WorkingMemory) -> None:
"""Test getting scratchpad with timestamps."""
await memory.append_scratchpad("Test note")
entries = await memory.get_scratchpad_with_timestamps()
assert len(entries) == 1
assert entries[0]["content"] == "Test note"
assert "timestamp" in entries[0]
@pytest.mark.asyncio
async def test_clear_scratchpad(self, memory: WorkingMemory) -> None:
"""Test clearing scratchpad."""
await memory.append_scratchpad("Note 1")
await memory.append_scratchpad("Note 2")
count = await memory.clear_scratchpad()
assert count == 2
entries = await memory.get_scratchpad()
assert entries == []
class TestWorkingMemoryCheckpoints:
"""Tests for checkpoint operations."""
@pytest.mark.asyncio
async def test_create_checkpoint(self, memory: WorkingMemory) -> None:
"""Test creating a checkpoint."""
await memory.set("key1", "value1")
await memory.set("key2", "value2")
checkpoint_id = await memory.create_checkpoint("Test checkpoint")
assert checkpoint_id is not None
assert len(checkpoint_id) == 8 # UUID prefix
@pytest.mark.asyncio
async def test_restore_checkpoint(self, memory: WorkingMemory) -> None:
"""Test restoring from a checkpoint."""
await memory.set("key1", "original")
checkpoint_id = await memory.create_checkpoint()
# Modify state
await memory.set("key1", "modified")
await memory.set("key2", "new")
# Restore
await memory.restore_checkpoint(checkpoint_id)
# Check restoration
assert await memory.get("key1") == "original"
# key2 didn't exist in checkpoint, so it should be gone
# But due to checkpoint being restored with clear, it's gone
@pytest.mark.asyncio
async def test_restore_nonexistent_checkpoint(self, memory: WorkingMemory) -> None:
"""Test restoring from nonexistent checkpoint raises error."""
with pytest.raises(MemoryNotFoundError):
await memory.restore_checkpoint("nonexistent")
@pytest.mark.asyncio
async def test_list_checkpoints(self, memory: WorkingMemory) -> None:
"""Test listing checkpoints."""
cp1 = await memory.create_checkpoint("First")
cp2 = await memory.create_checkpoint("Second")
checkpoints = await memory.list_checkpoints()
assert len(checkpoints) == 2
ids = [cp["id"] for cp in checkpoints]
assert cp1 in ids
assert cp2 in ids
@pytest.mark.asyncio
async def test_delete_checkpoint(self, memory: WorkingMemory) -> None:
"""Test deleting a checkpoint."""
checkpoint_id = await memory.create_checkpoint()
result = await memory.delete_checkpoint(checkpoint_id)
assert result is True
checkpoints = await memory.list_checkpoints()
assert len(checkpoints) == 0
class TestWorkingMemoryScope:
"""Tests for scope handling."""
@pytest.mark.asyncio
async def test_scope_property(
self, memory: WorkingMemory, scope: ScopeContext
) -> None:
"""Test scope property."""
assert memory.scope == scope
@pytest.mark.asyncio
async def test_for_session_factory(self) -> None:
"""Test for_session factory method."""
# This would normally try Redis and fall back to in-memory
# In tests, Redis won't be available, so it uses fallback
wm = await WorkingMemory.for_session(
session_id="session-abc",
project_id="project-123",
agent_instance_id="agent-456",
)
assert wm.scope.scope_type == ScopeLevel.SESSION
assert wm.scope.scope_id == "session-abc"
assert wm.scope.parent is not None
assert wm.scope.parent.scope_type == ScopeLevel.AGENT_INSTANCE
class TestWorkingMemoryHealth:
"""Tests for health and lifecycle."""
@pytest.mark.asyncio
async def test_is_healthy(self, memory: WorkingMemory) -> None:
"""Test health check."""
assert await memory.is_healthy() is True
@pytest.mark.asyncio
async def test_get_stats(self, memory: WorkingMemory) -> None:
"""Test getting stats."""
await memory.set("key1", "value1")
await memory.append_scratchpad("Note")
state = TaskState(task_id="t1", task_type="test", description="Test")
await memory.set_task_state(state)
stats = await memory.get_stats()
assert stats["scope_type"] == "session"
assert stats["scope_id"] == "test-session-123"
assert stats["user_keys"] == 1
assert stats["scratchpad_entries"] == 1
assert stats["has_task_state"] is True
@pytest.mark.asyncio
async def test_is_using_fallback(self, memory: WorkingMemory) -> None:
"""Test fallback detection."""
# In-memory storage is always fallback
assert memory.is_using_fallback is False # Not set in fixture
@pytest.mark.asyncio
async def test_close(self, memory: WorkingMemory) -> None:
"""Test close doesn't error."""
await memory.close() # Should not raise

View File

@@ -0,0 +1,303 @@
# tests/unit/services/memory/working/test_storage.py
"""Unit tests for working memory storage backends."""
import asyncio
import pytest
from app.services.memory.exceptions import MemoryStorageError
from app.services.memory.working.storage import InMemoryStorage
class TestInMemoryStorageBasicOperations:
"""Tests for basic InMemoryStorage operations."""
@pytest.fixture
def storage(self) -> InMemoryStorage:
"""Create a fresh storage instance."""
return InMemoryStorage(max_keys=100)
@pytest.mark.asyncio
async def test_set_and_get(self, storage: InMemoryStorage) -> None:
"""Test basic set and get."""
await storage.set("key1", "value1")
result = await storage.get("key1")
assert result == "value1"
@pytest.mark.asyncio
async def test_get_nonexistent_key(self, storage: InMemoryStorage) -> None:
"""Test getting a key that doesn't exist."""
result = await storage.get("nonexistent")
assert result is None
@pytest.mark.asyncio
async def test_set_overwrites_existing(self, storage: InMemoryStorage) -> None:
"""Test that set overwrites existing values."""
await storage.set("key1", "original")
await storage.set("key1", "updated")
result = await storage.get("key1")
assert result == "updated"
@pytest.mark.asyncio
async def test_delete_existing_key(self, storage: InMemoryStorage) -> None:
"""Test deleting an existing key."""
await storage.set("key1", "value1")
result = await storage.delete("key1")
assert result is True
assert await storage.get("key1") is None
@pytest.mark.asyncio
async def test_delete_nonexistent_key(self, storage: InMemoryStorage) -> None:
"""Test deleting a key that doesn't exist."""
result = await storage.delete("nonexistent")
assert result is False
@pytest.mark.asyncio
async def test_exists(self, storage: InMemoryStorage) -> None:
"""Test exists check."""
await storage.set("key1", "value1")
assert await storage.exists("key1") is True
assert await storage.exists("nonexistent") is False
class TestInMemoryStorageTTL:
"""Tests for TTL functionality."""
@pytest.fixture
def storage(self) -> InMemoryStorage:
"""Create a fresh storage instance."""
return InMemoryStorage(max_keys=100)
@pytest.mark.asyncio
async def test_set_with_ttl(self, storage: InMemoryStorage) -> None:
"""Test that TTL is stored correctly."""
await storage.set("key1", "value1", ttl_seconds=10)
# Key should exist immediately
assert await storage.exists("key1") is True
@pytest.mark.asyncio
async def test_ttl_expiration(self, storage: InMemoryStorage) -> None:
"""Test that expired keys return None."""
await storage.set("key1", "value1", ttl_seconds=1)
# Key exists initially
assert await storage.get("key1") == "value1"
# Wait for expiration
await asyncio.sleep(1.1)
# Key should be expired
assert await storage.get("key1") is None
assert await storage.exists("key1") is False
@pytest.mark.asyncio
async def test_remove_ttl_on_update(self, storage: InMemoryStorage) -> None:
"""Test that updating without TTL removes expiration."""
await storage.set("key1", "value1", ttl_seconds=1)
await storage.set("key1", "value2") # No TTL
await asyncio.sleep(1.1)
# Key should still exist (TTL removed)
assert await storage.get("key1") == "value2"
class TestInMemoryStorageListAndClear:
"""Tests for list and clear operations."""
@pytest.fixture
def storage(self) -> InMemoryStorage:
"""Create a fresh storage instance."""
return InMemoryStorage(max_keys=100)
@pytest.mark.asyncio
async def test_list_keys_all(self, storage: InMemoryStorage) -> None:
"""Test listing all keys."""
await storage.set("key1", "value1")
await storage.set("key2", "value2")
await storage.set("other", "value3")
keys = await storage.list_keys()
assert set(keys) == {"key1", "key2", "other"}
@pytest.mark.asyncio
async def test_list_keys_with_pattern(self, storage: InMemoryStorage) -> None:
"""Test listing keys with pattern."""
await storage.set("key1", "value1")
await storage.set("key2", "value2")
await storage.set("other", "value3")
keys = await storage.list_keys("key*")
assert set(keys) == {"key1", "key2"}
@pytest.mark.asyncio
async def test_get_all(self, storage: InMemoryStorage) -> None:
"""Test getting all key-value pairs."""
await storage.set("key1", "value1")
await storage.set("key2", "value2")
result = await storage.get_all()
assert result == {"key1": "value1", "key2": "value2"}
@pytest.mark.asyncio
async def test_clear(self, storage: InMemoryStorage) -> None:
"""Test clearing all keys."""
await storage.set("key1", "value1")
await storage.set("key2", "value2")
count = await storage.clear()
assert count == 2
assert await storage.get_all() == {}
class TestInMemoryStorageCapacity:
"""Tests for capacity limits."""
@pytest.mark.asyncio
async def test_capacity_limit_exceeded(self) -> None:
"""Test that exceeding capacity raises error."""
storage = InMemoryStorage(max_keys=2)
await storage.set("key1", "value1")
await storage.set("key2", "value2")
with pytest.raises(MemoryStorageError, match="capacity exceeded"):
await storage.set("key3", "value3")
@pytest.mark.asyncio
async def test_update_existing_key_within_capacity(self) -> None:
"""Test that updating existing key doesn't count against capacity."""
storage = InMemoryStorage(max_keys=2)
await storage.set("key1", "value1")
await storage.set("key2", "value2")
await storage.set("key1", "updated") # Should succeed
assert await storage.get("key1") == "updated"
@pytest.mark.asyncio
async def test_expired_keys_freed_for_capacity(self) -> None:
"""Test that expired keys are cleaned up for capacity."""
storage = InMemoryStorage(max_keys=2)
await storage.set("key1", "value1", ttl_seconds=1)
await storage.set("key2", "value2")
await asyncio.sleep(1.1)
# Should succeed because key1 is expired and will be cleaned
await storage.set("key3", "value3")
assert await storage.get("key3") == "value3"
class TestInMemoryStorageDataTypes:
"""Tests for different data types."""
@pytest.fixture
def storage(self) -> InMemoryStorage:
"""Create a fresh storage instance."""
return InMemoryStorage(max_keys=100)
@pytest.mark.asyncio
async def test_store_dict(self, storage: InMemoryStorage) -> None:
"""Test storing dict values."""
data = {"nested": {"key": "value"}, "list": [1, 2, 3]}
await storage.set("dict_key", data)
result = await storage.get("dict_key")
assert result == data
@pytest.mark.asyncio
async def test_store_list(self, storage: InMemoryStorage) -> None:
"""Test storing list values."""
data = [1, 2, {"nested": "dict"}]
await storage.set("list_key", data)
result = await storage.get("list_key")
assert result == data
@pytest.mark.asyncio
async def test_store_numbers(self, storage: InMemoryStorage) -> None:
"""Test storing numeric values."""
await storage.set("int_key", 42)
await storage.set("float_key", 3.14)
assert await storage.get("int_key") == 42
assert await storage.get("float_key") == 3.14
@pytest.mark.asyncio
async def test_store_boolean(self, storage: InMemoryStorage) -> None:
"""Test storing boolean values."""
await storage.set("true_key", True)
await storage.set("false_key", False)
assert await storage.get("true_key") is True
assert await storage.get("false_key") is False
@pytest.mark.asyncio
async def test_store_none(self, storage: InMemoryStorage) -> None:
"""Test storing None value."""
await storage.set("none_key", None)
# Note: None is stored, but get returns None for both missing and None values
# Use exists to distinguish
assert await storage.exists("none_key") is True
class TestInMemoryStorageHealth:
"""Tests for health and lifecycle."""
@pytest.mark.asyncio
async def test_is_healthy(self) -> None:
"""Test health check."""
storage = InMemoryStorage()
assert await storage.is_healthy() is True
@pytest.mark.asyncio
async def test_close(self) -> None:
"""Test close is no-op but doesn't error."""
storage = InMemoryStorage()
await storage.close() # Should not raise
class TestInMemoryStorageConcurrency:
"""Tests for concurrent access."""
@pytest.mark.asyncio
async def test_concurrent_writes(self) -> None:
"""Test concurrent write operations don't corrupt data."""
storage = InMemoryStorage(max_keys=1000)
async def write_batch(prefix: str, count: int) -> None:
for i in range(count):
await storage.set(f"{prefix}_{i}", f"value_{i}")
# Run concurrent writes
await asyncio.gather(
write_batch("a", 100),
write_batch("b", 100),
write_batch("c", 100),
)
# Verify all writes succeeded
keys = await storage.list_keys()
assert len(keys) == 300
@pytest.mark.asyncio
async def test_concurrent_read_write(self) -> None:
"""Test concurrent read and write operations."""
storage = InMemoryStorage()
await storage.set("key", 0)
async def increment() -> None:
for _ in range(100):
val = await storage.get("key") or 0
await storage.set("key", val + 1)
# Run concurrent increments
await asyncio.gather(
increment(),
increment(),
)
# Final value depends on interleaving
# Just verify we don't crash and value is positive
result = await storage.get("key")
assert result > 0