Files
syndarix/backend/app/services/memory/working/storage.py
Felipe Cardoso 4974233169 feat(memory): add working memory implementation (Issue #89)
Implements session-scoped ephemeral memory with:

Storage Backends:
- InMemoryStorage: Coroutine-safe fallback (guarded by an asyncio.Lock) with TTL support and capacity limits
- RedisStorage: Primary storage with connection pooling and JSON serialization
- Auto-fallback from Redis to in-memory when unavailable

WorkingMemory Class:
- Key-value storage with TTL and reserved key protection
- Task state tracking with progress updates
- Scratchpad for reasoning steps with timestamps
- Checkpoint/snapshot support for recovery
- Factory methods for auto-configured storage

Tests:
- 55 unit tests covering all functionality
- Tests for basic ops, TTL, capacity, concurrency
- Tests for task state, scratchpad, checkpoints

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-05 01:51:03 +01:00

407 lines
13 KiB
Python

# app/services/memory/working/storage.py
"""
Working Memory Storage Backends.
Provides abstract storage interface and implementations:
- RedisStorage: Primary storage using Redis with connection pooling
- InMemoryStorage: Fallback storage when Redis is unavailable
"""
import asyncio
import fnmatch
import json
import logging
from abc import ABC, abstractmethod
from datetime import UTC, datetime, timedelta
from typing import Any
from app.services.memory.config import get_memory_settings
from app.services.memory.exceptions import (
MemoryConnectionError,
MemoryStorageError,
)
# Module-level logger; RedisStorage reports connection events through it.
logger: logging.Logger = logging.getLogger(__name__)
class WorkingMemoryStorage(ABC):
"""Abstract base class for working memory storage backends."""
@abstractmethod
async def set(
self,
key: str,
value: Any,
ttl_seconds: int | None = None,
) -> None:
"""Store a value with optional TTL."""
...
@abstractmethod
async def get(self, key: str) -> Any | None:
"""Get a value by key, returns None if not found or expired."""
...
@abstractmethod
async def delete(self, key: str) -> bool:
"""Delete a key, returns True if existed."""
...
@abstractmethod
async def exists(self, key: str) -> bool:
"""Check if a key exists and is not expired."""
...
@abstractmethod
async def list_keys(self, pattern: str = "*") -> list[str]:
"""List all keys matching a pattern."""
...
@abstractmethod
async def get_all(self) -> dict[str, Any]:
"""Get all key-value pairs."""
...
@abstractmethod
async def clear(self) -> int:
"""Clear all keys, returns count of deleted keys."""
...
@abstractmethod
async def is_healthy(self) -> bool:
"""Check if the storage backend is healthy."""
...
@abstractmethod
async def close(self) -> None:
"""Close the storage connection."""
...
class InMemoryStorage(WorkingMemoryStorage):
"""
In-memory storage backend for working memory.
Used as fallback when Redis is unavailable. Data is not persisted
across restarts and is not shared between processes.
"""
def __init__(self, max_keys: int = 10000) -> None:
"""Initialize in-memory storage."""
self._data: dict[str, Any] = {}
self._expirations: dict[str, datetime] = {}
self._max_keys = max_keys
self._lock = asyncio.Lock()
def _is_expired(self, key: str) -> bool:
"""Check if a key has expired."""
if key not in self._expirations:
return False
return datetime.now(UTC) > self._expirations[key]
def _cleanup_expired(self) -> None:
"""Remove all expired keys."""
now = datetime.now(UTC)
expired_keys = [
key for key, exp_time in self._expirations.items() if now > exp_time
]
for key in expired_keys:
self._data.pop(key, None)
self._expirations.pop(key, None)
async def set(
self,
key: str,
value: Any,
ttl_seconds: int | None = None,
) -> None:
"""Store a value with optional TTL."""
async with self._lock:
# Cleanup expired keys periodically
if len(self._data) % 100 == 0:
self._cleanup_expired()
# Check capacity
if key not in self._data and len(self._data) >= self._max_keys:
# Evict expired keys first
self._cleanup_expired()
if len(self._data) >= self._max_keys:
raise MemoryStorageError(
f"Working memory capacity exceeded: {self._max_keys} keys"
)
self._data[key] = value
if ttl_seconds is not None:
self._expirations[key] = datetime.now(UTC) + timedelta(
seconds=ttl_seconds
)
elif key in self._expirations:
# Remove existing expiration if no TTL specified
del self._expirations[key]
async def get(self, key: str) -> Any | None:
"""Get a value by key."""
async with self._lock:
if key not in self._data:
return None
if self._is_expired(key):
del self._data[key]
del self._expirations[key]
return None
return self._data[key]
async def delete(self, key: str) -> bool:
"""Delete a key."""
async with self._lock:
existed = key in self._data
self._data.pop(key, None)
self._expirations.pop(key, None)
return existed
async def exists(self, key: str) -> bool:
"""Check if a key exists and is not expired."""
async with self._lock:
if key not in self._data:
return False
if self._is_expired(key):
del self._data[key]
del self._expirations[key]
return False
return True
async def list_keys(self, pattern: str = "*") -> list[str]:
"""List all keys matching a pattern."""
async with self._lock:
self._cleanup_expired()
if pattern == "*":
return list(self._data.keys())
return [key for key in self._data.keys() if fnmatch.fnmatch(key, pattern)]
async def get_all(self) -> dict[str, Any]:
"""Get all key-value pairs."""
async with self._lock:
self._cleanup_expired()
return dict(self._data)
async def clear(self) -> int:
"""Clear all keys."""
async with self._lock:
count = len(self._data)
self._data.clear()
self._expirations.clear()
return count
async def is_healthy(self) -> bool:
"""In-memory storage is always healthy."""
return True
async def close(self) -> None:
"""No cleanup needed for in-memory storage."""
class RedisStorage(WorkingMemoryStorage):
"""
Redis storage backend for working memory.
Primary storage with connection pooling, automatic reconnection,
and proper serialization of Python objects.
"""
def __init__(
self,
key_prefix: str = "",
connection_timeout: float = 5.0,
socket_timeout: float = 5.0,
) -> None:
"""
Initialize Redis storage.
Args:
key_prefix: Prefix for all keys (e.g., "session:abc123:")
connection_timeout: Timeout for establishing connections
socket_timeout: Timeout for socket operations
"""
self._key_prefix = key_prefix
self._connection_timeout = connection_timeout
self._socket_timeout = socket_timeout
self._redis: Any = None
self._lock = asyncio.Lock()
def _make_key(self, key: str) -> str:
"""Add prefix to key."""
return f"{self._key_prefix}{key}"
def _strip_prefix(self, key: str) -> str:
"""Remove prefix from key."""
if key.startswith(self._key_prefix):
return key[len(self._key_prefix) :]
return key
def _serialize(self, value: Any) -> str:
"""Serialize a Python value to JSON string."""
return json.dumps(value, default=str)
def _deserialize(self, data: str | bytes | None) -> Any | None:
"""Deserialize a JSON string to Python value."""
if data is None:
return None
if isinstance(data, bytes):
data = data.decode("utf-8")
return json.loads(data)
async def _get_client(self) -> Any:
"""Get or create Redis client."""
if self._redis is not None:
return self._redis
async with self._lock:
if self._redis is not None:
return self._redis
try:
import redis.asyncio as aioredis
except ImportError as e:
raise MemoryConnectionError(
"redis package not installed. Install with: pip install redis"
) from e
settings = get_memory_settings()
redis_url = settings.redis_url
try:
self._redis = await aioredis.from_url(
redis_url,
encoding="utf-8",
decode_responses=True,
socket_connect_timeout=self._connection_timeout,
socket_timeout=self._socket_timeout,
)
# Test connection
await self._redis.ping()
logger.info("Connected to Redis for working memory")
return self._redis
except Exception as e:
self._redis = None
raise MemoryConnectionError(f"Failed to connect to Redis: {e}") from e
async def set(
self,
key: str,
value: Any,
ttl_seconds: int | None = None,
) -> None:
"""Store a value with optional TTL."""
try:
client = await self._get_client()
full_key = self._make_key(key)
serialized = self._serialize(value)
if ttl_seconds is not None:
await client.setex(full_key, ttl_seconds, serialized)
else:
await client.set(full_key, serialized)
except MemoryConnectionError:
raise
except Exception as e:
raise MemoryStorageError(f"Failed to set key {key}: {e}") from e
async def get(self, key: str) -> Any | None:
"""Get a value by key."""
try:
client = await self._get_client()
full_key = self._make_key(key)
data = await client.get(full_key)
return self._deserialize(data)
except MemoryConnectionError:
raise
except Exception as e:
raise MemoryStorageError(f"Failed to get key {key}: {e}") from e
async def delete(self, key: str) -> bool:
"""Delete a key."""
try:
client = await self._get_client()
full_key = self._make_key(key)
result = await client.delete(full_key)
return bool(result)
except MemoryConnectionError:
raise
except Exception as e:
raise MemoryStorageError(f"Failed to delete key {key}: {e}") from e
async def exists(self, key: str) -> bool:
"""Check if a key exists."""
try:
client = await self._get_client()
full_key = self._make_key(key)
result = await client.exists(full_key)
return bool(result)
except MemoryConnectionError:
raise
except Exception as e:
raise MemoryStorageError(f"Failed to check key {key}: {e}") from e
async def list_keys(self, pattern: str = "*") -> list[str]:
"""List all keys matching a pattern."""
try:
client = await self._get_client()
full_pattern = self._make_key(pattern)
keys = await client.keys(full_pattern)
return [self._strip_prefix(key) for key in keys]
except MemoryConnectionError:
raise
except Exception as e:
raise MemoryStorageError(f"Failed to list keys: {e}") from e
async def get_all(self) -> dict[str, Any]:
"""Get all key-value pairs."""
try:
client = await self._get_client()
full_pattern = self._make_key("*")
keys = await client.keys(full_pattern)
if not keys:
return {}
values = await client.mget(*keys)
result = {}
for key, value in zip(keys, values, strict=False):
stripped_key = self._strip_prefix(key)
result[stripped_key] = self._deserialize(value)
return result
except MemoryConnectionError:
raise
except Exception as e:
raise MemoryStorageError(f"Failed to get all keys: {e}") from e
async def clear(self) -> int:
"""Clear all keys with this prefix."""
try:
client = await self._get_client()
full_pattern = self._make_key("*")
keys = await client.keys(full_pattern)
if not keys:
return 0
return await client.delete(*keys)
except MemoryConnectionError:
raise
except Exception as e:
raise MemoryStorageError(f"Failed to clear keys: {e}") from e
async def is_healthy(self) -> bool:
"""Check if Redis connection is healthy."""
try:
client = await self._get_client()
await client.ping()
return True
except Exception:
return False
async def close(self) -> None:
"""Close the Redis connection."""
if self._redis is not None:
await self._redis.close()
self._redis = None