# app/services/memory/working/storage.py
"""
Working Memory Storage Backends.

Provides abstract storage interface and implementations:
- RedisStorage: Primary storage using Redis with connection pooling
- InMemoryStorage: Fallback storage when Redis is unavailable
"""

import asyncio
import fnmatch
import json
import logging
from abc import ABC, abstractmethod
from datetime import UTC, datetime, timedelta
from typing import Any

from app.services.memory.config import get_memory_settings
from app.services.memory.exceptions import (
    MemoryConnectionError,
    MemoryStorageError,
)

logger = logging.getLogger(__name__)


class WorkingMemoryStorage(ABC):
    """Abstract base class for working memory storage backends."""

    @abstractmethod
    async def set(
        self,
        key: str,
        value: Any,
        ttl_seconds: int | None = None,
    ) -> None:
        """Store a value with optional TTL."""
        ...

    @abstractmethod
    async def get(self, key: str) -> Any | None:
        """Get a value by key, returns None if not found or expired."""
        ...

    @abstractmethod
    async def delete(self, key: str) -> bool:
        """Delete a key, returns True if existed."""
        ...

    @abstractmethod
    async def exists(self, key: str) -> bool:
        """Check if a key exists and is not expired."""
        ...

    @abstractmethod
    async def list_keys(self, pattern: str = "*") -> list[str]:
        """List all keys matching a pattern."""
        ...

    @abstractmethod
    async def get_all(self) -> dict[str, Any]:
        """Get all key-value pairs."""
        ...

    @abstractmethod
    async def clear(self) -> int:
        """Clear all keys, returns count of deleted keys."""
        ...

    @abstractmethod
    async def is_healthy(self) -> bool:
        """Check if the storage backend is healthy."""
        ...

    @abstractmethod
    async def close(self) -> None:
        """Close the storage connection."""
        ...


class InMemoryStorage(WorkingMemoryStorage):
    """
    In-memory storage backend for working memory.

    Used as fallback when Redis is unavailable. Data is not persisted
    across restarts and is not shared between processes.
    """

    def __init__(self, max_keys: int = 10000) -> None:
        """
        Initialize in-memory storage.

        Args:
            max_keys: Maximum number of keys held at once; `set` raises
                MemoryStorageError when a new key would exceed it.
        """
        self._data: dict[str, Any] = {}
        # Absolute (timezone-aware, UTC) expiry time per key; keys without
        # an entry here never expire.
        self._expirations: dict[str, datetime] = {}
        self._max_keys = max_keys
        self._lock = asyncio.Lock()

    def _is_expired(self, key: str) -> bool:
        """Return True if the key has a TTL that has already elapsed."""
        expires_at = self._expirations.get(key)
        if expires_at is None:
            return False
        return datetime.now(UTC) > expires_at

    def _evict(self, key: str) -> None:
        """Drop a key's value and expiration entry, if present."""
        self._data.pop(key, None)
        self._expirations.pop(key, None)

    def _cleanup_expired(self) -> None:
        """Remove all expired keys."""
        now = datetime.now(UTC)
        # Materialize the list first: evicting mutates _expirations.
        expired_keys = [
            key for key, exp_time in self._expirations.items() if now > exp_time
        ]
        for key in expired_keys:
            self._evict(key)

    async def set(
        self,
        key: str,
        value: Any,
        ttl_seconds: int | None = None,
    ) -> None:
        """
        Store a value with optional TTL.

        Args:
            key: Key to store under.
            value: Value to store (kept by reference, not copied).
            ttl_seconds: Lifetime in seconds; None removes any existing
                expiration so the key persists.

        Raises:
            MemoryStorageError: If the key is new and the store is still at
                capacity after expired keys have been purged.
        """
        async with self._lock:
            # Cleanup expired keys periodically
            if len(self._data) % 100 == 0:
                self._cleanup_expired()

            # Check capacity (overwriting an existing key never grows the store)
            if key not in self._data and len(self._data) >= self._max_keys:
                # Evict expired keys first
                self._cleanup_expired()
                if len(self._data) >= self._max_keys:
                    raise MemoryStorageError(
                        f"Working memory capacity exceeded: {self._max_keys} keys"
                    )

            self._data[key] = value
            if ttl_seconds is not None:
                self._expirations[key] = datetime.now(UTC) + timedelta(
                    seconds=ttl_seconds
                )
            else:
                # Remove existing expiration if no TTL specified
                self._expirations.pop(key, None)

    async def get(self, key: str) -> Any | None:
        """Get a value by key, or None if the key is missing or expired."""
        async with self._lock:
            if key not in self._data:
                return None
            if self._is_expired(key):
                # Lazy expiry: purge on access.
                self._evict(key)
                return None
            return self._data[key]

    async def delete(self, key: str) -> bool:
        """
        Delete a key.

        Returns:
            True only if the key was present and not yet expired, so the
            result agrees with what `exists` would have reported.
        """
        async with self._lock:
            existed = key in self._data and not self._is_expired(key)
            self._evict(key)
            return existed

    async def exists(self, key: str) -> bool:
        """Check if a key exists and is not expired."""
        async with self._lock:
            if key not in self._data:
                return False
            if self._is_expired(key):
                self._evict(key)
                return False
            return True

    async def list_keys(self, pattern: str = "*") -> list[str]:
        """
        List all live keys matching a glob-style pattern.

        Matching is case-sensitive on every platform (fnmatchcase), rather
        than following the host OS's filename case rules.
        """
        async with self._lock:
            self._cleanup_expired()
            if pattern == "*":
                return list(self._data)
            return [key for key in self._data if fnmatch.fnmatchcase(key, pattern)]

    async def get_all(self) -> dict[str, Any]:
        """Get a shallow copy of all live key-value pairs."""
        async with self._lock:
            self._cleanup_expired()
            return dict(self._data)

    async def clear(self) -> int:
        """Clear all keys, returning how many entries were held."""
        async with self._lock:
            count = len(self._data)
            self._data.clear()
            self._expirations.clear()
            return count

    async def is_healthy(self) -> bool:
        """In-memory storage is always healthy."""
        return True

    async def close(self) -> None:
        """No cleanup needed for in-memory storage."""


class RedisStorage(WorkingMemoryStorage):
    """
    Redis storage backend for working memory.

    Primary storage with connection pooling, automatic reconnection,
    and proper serialization of Python objects.
    """

    def __init__(
        self,
        key_prefix: str = "",
        connection_timeout: float = 5.0,
        socket_timeout: float = 5.0,
    ) -> None:
        """
        Initialize Redis storage.

        Args:
            key_prefix: Prefix for all keys (e.g., "session:abc123:")
            connection_timeout: Timeout for establishing connections
            socket_timeout: Timeout for socket operations
        """
        self._key_prefix = key_prefix
        self._connection_timeout = connection_timeout
        self._socket_timeout = socket_timeout
        # Created lazily by _get_client on first use.
        self._redis: Any = None
        self._lock = asyncio.Lock()

    def _make_key(self, key: str) -> str:
        """Add prefix to key."""
        return f"{self._key_prefix}{key}"

    def _strip_prefix(self, key: str) -> str:
        """Remove prefix from key; keys without the prefix are returned as-is."""
        if key.startswith(self._key_prefix):
            return key[len(self._key_prefix) :]
        return key

    def _serialize(self, value: Any) -> str:
        """
        Serialize a Python value to JSON string.

        Values json cannot encode natively are converted with str(), so
        they come back from `get` as strings.
        """
        return json.dumps(value, default=str)

    def _deserialize(self, data: str | bytes | None) -> Any | None:
        """Deserialize a JSON string to Python value (None passes through)."""
        if data is None:
            return None
        if isinstance(data, bytes):
            data = data.decode("utf-8")
        return json.loads(data)

    async def _get_client(self) -> Any:
        """
        Get or create Redis client.

        Raises:
            MemoryConnectionError: If the redis package is missing or the
                connection attempt (including the initial ping) fails.
        """
        if self._redis is not None:
            return self._redis

        async with self._lock:
            # Re-check under the lock: another task may have connected
            # while this one was waiting.
            if self._redis is not None:
                return self._redis

            try:
                import redis.asyncio as aioredis
            except ImportError as e:
                raise MemoryConnectionError(
                    "redis package not installed. Install with: pip install redis"
                ) from e

            settings = get_memory_settings()
            redis_url = settings.redis_url

            client: Any = None
            try:
                client = await aioredis.from_url(
                    redis_url,
                    encoding="utf-8",
                    decode_responses=True,
                    socket_connect_timeout=self._connection_timeout,
                    socket_timeout=self._socket_timeout,
                )
                # Test connection
                await client.ping()
            except Exception as e:
                # Release the half-open client instead of leaking it; a
                # failure while closing must not mask the original error.
                if client is not None:
                    try:
                        await client.close()
                    except Exception:
                        logger.debug(
                            "Error closing Redis client after failed connect",
                            exc_info=True,
                        )
                raise MemoryConnectionError(f"Failed to connect to Redis: {e}") from e

            # Publish only after the ping succeeded so the lock-free fast
            # path above never hands out an unverified client.
            self._redis = client
            logger.info("Connected to Redis for working memory")
            return client

    async def set(
        self,
        key: str,
        value: Any,
        ttl_seconds: int | None = None,
    ) -> None:
        """
        Store a value with optional TTL.

        Raises:
            MemoryConnectionError: If Redis cannot be reached.
            MemoryStorageError: If serialization or the write fails.
        """
        try:
            client = await self._get_client()
            full_key = self._make_key(key)
            serialized = self._serialize(value)
            if ttl_seconds is not None:
                await client.setex(full_key, ttl_seconds, serialized)
            else:
                await client.set(full_key, serialized)
        except MemoryConnectionError:
            raise
        except Exception as e:
            raise MemoryStorageError(f"Failed to set key {key}: {e}") from e

    async def get(self, key: str) -> Any | None:
        """
        Get a value by key.

        Raises:
            MemoryConnectionError: If Redis cannot be reached.
            MemoryStorageError: If the read or JSON decoding fails.
        """
        try:
            client = await self._get_client()
            full_key = self._make_key(key)
            data = await client.get(full_key)
            return self._deserialize(data)
        except MemoryConnectionError:
            raise
        except Exception as e:
            raise MemoryStorageError(f"Failed to get key {key}: {e}") from e

    async def delete(self, key: str) -> bool:
        """
        Delete a key.

        Returns:
            True if Redis reported removing the key.

        Raises:
            MemoryConnectionError: If Redis cannot be reached.
            MemoryStorageError: If the delete fails.
        """
        try:
            client = await self._get_client()
            full_key = self._make_key(key)
            result = await client.delete(full_key)
            return bool(result)
        except MemoryConnectionError:
            raise
        except Exception as e:
            raise MemoryStorageError(f"Failed to delete key {key}: {e}") from e

    async def exists(self, key: str) -> bool:
        """
        Check if a key exists.

        Raises:
            MemoryConnectionError: If Redis cannot be reached.
            MemoryStorageError: If the check fails.
        """
        try:
            client = await self._get_client()
            full_key = self._make_key(key)
            result = await client.exists(full_key)
            return bool(result)
        except MemoryConnectionError:
            raise
        except Exception as e:
            raise MemoryStorageError(f"Failed to check key {key}: {e}") from e

    async def list_keys(self, pattern: str = "*") -> list[str]:
        """
        List all keys matching a pattern, with the prefix stripped.

        Raises:
            MemoryConnectionError: If Redis cannot be reached.
            MemoryStorageError: If the listing fails.
        """
        try:
            client = await self._get_client()
            full_pattern = self._make_key(pattern)
            # NOTE(review): KEYS walks the whole keyspace on the server;
            # consider SCAN if the keyspace can be large - confirm against
            # deployment.
            keys = await client.keys(full_pattern)
            return [self._strip_prefix(key) for key in keys]
        except MemoryConnectionError:
            raise
        except Exception as e:
            raise MemoryStorageError(f"Failed to list keys: {e}") from e

    async def get_all(self) -> dict[str, Any]:
        """
        Get all key-value pairs under this prefix.

        Keys that vanish (expire or are deleted) between the key listing
        and the bulk read are omitted rather than reported as None.

        Raises:
            MemoryConnectionError: If Redis cannot be reached.
            MemoryStorageError: If the read or JSON decoding fails.
        """
        try:
            client = await self._get_client()
            full_pattern = self._make_key("*")
            keys = await client.keys(full_pattern)
            if not keys:
                return {}

            values = await client.mget(*keys)
            # A raw None means the key no longer exists; a stored JSON null
            # arrives as the string "null" and is kept.
            return {
                self._strip_prefix(key): self._deserialize(value)
                for key, value in zip(keys, values, strict=False)
                if value is not None
            }
        except MemoryConnectionError:
            raise
        except Exception as e:
            raise MemoryStorageError(f"Failed to get all keys: {e}") from e

    async def clear(self) -> int:
        """
        Clear all keys with this prefix.

        Returns:
            The delete count reported by Redis (0 when nothing matched).

        Raises:
            MemoryConnectionError: If Redis cannot be reached.
            MemoryStorageError: If the listing or delete fails.
        """
        try:
            client = await self._get_client()
            full_pattern = self._make_key("*")
            keys = await client.keys(full_pattern)
            if not keys:
                return 0
            return await client.delete(*keys)
        except MemoryConnectionError:
            raise
        except Exception as e:
            raise MemoryStorageError(f"Failed to clear keys: {e}") from e

    async def is_healthy(self) -> bool:
        """Check if Redis connection is healthy (never raises)."""
        try:
            client = await self._get_client()
            await client.ping()
            return True
        except Exception:
            # Deliberate best-effort: any failure simply means "unhealthy".
            return False

    async def close(self) -> None:
        """Close the Redis connection."""
        client = self._redis
        if client is None:
            return
        # Detach first so the handle is dropped even if close() raises and
        # no caller is handed a client that is being shut down.
        self._redis = None
        await client.close()