# app/services/memory/scoping/resolver.py """ Scope Resolution. Provides utilities for resolving memory queries across scope hierarchies, implementing inheritance and aggregation of memories from parent scopes. """ import logging from collections.abc import Callable from dataclasses import dataclass, field from typing import Any, TypeVar from app.services.memory.types import ScopeContext, ScopeLevel from .scope import ScopeManager, get_scope_manager logger = logging.getLogger(__name__) T = TypeVar("T") @dataclass class ResolutionResult[T]: """Result of a scope resolution.""" items: list[T] sources: list[ScopeContext] total_from_each: dict[str, int] = field(default_factory=dict) inherited_count: int = 0 own_count: int = 0 @property def total_count(self) -> int: """Get total items from all sources.""" return len(self.items) @dataclass class ResolutionOptions: """Options for scope resolution.""" include_inherited: bool = True max_inheritance_depth: int = 5 limit_per_scope: int = 100 total_limit: int = 500 deduplicate: bool = True deduplicate_key: str | None = None # Field to use for deduplication class ScopeResolver: """ Resolves memory queries across scope hierarchies. Features: - Traverse scope hierarchy for inherited memories - Aggregate results from multiple scope levels - Apply access control policies - Support deduplication across scopes """ def __init__( self, manager: ScopeManager | None = None, ) -> None: """ Initialize the resolver. Args: manager: Scope manager to use (defaults to singleton) """ self._manager = manager or get_scope_manager() def resolve( self, scope: ScopeContext, fetcher: Callable[[ScopeContext, int], list[T]], options: ResolutionOptions | None = None, ) -> ResolutionResult[T]: """ Resolve memories for a scope, including inherited memories. Args: scope: Starting scope fetcher: Function to fetch items for a scope (scope, limit) -> items options: Resolution options Returns: Resolution result with items from all scopes """ opts = options or ResolutionOptions() all_items: list[T] = [] sources: list[ScopeContext] = [] counts: dict[str, int] = {} seen_keys: set[Any] = set() # Collect scopes to query (starting from current, going up to ancestors) scopes_to_query = self._collect_queryable_scopes( scope=scope, max_depth=opts.max_inheritance_depth if opts.include_inherited else 0, ) own_count = 0 inherited_count = 0 remaining_limit = opts.total_limit for i, query_scope in enumerate(scopes_to_query): if remaining_limit <= 0: break # Check access policy policy = self._manager.get_policy(query_scope) if not policy.allows_read(): continue if i > 0 and not policy.allows_inherit(): continue # Fetch items for this scope scope_limit = min(opts.limit_per_scope, remaining_limit) items = fetcher(query_scope, scope_limit) # Apply deduplication if opts.deduplicate and opts.deduplicate_key: items = self._deduplicate(items, opts.deduplicate_key, seen_keys) if items: all_items.extend(items) sources.append(query_scope) key = f"{query_scope.scope_type.value}:{query_scope.scope_id}" counts[key] = len(items) if i == 0: own_count = len(items) else: inherited_count += len(items) remaining_limit -= len(items) logger.debug( f"Resolved {len(all_items)} items from {len(sources)} scopes " f"(own={own_count}, inherited={inherited_count})" ) return ResolutionResult( items=all_items[: opts.total_limit], sources=sources, total_from_each=counts, own_count=own_count, inherited_count=inherited_count, ) def _collect_queryable_scopes( self, scope: ScopeContext, max_depth: int, ) -> list[ScopeContext]: """Collect scopes to query, from current to ancestors.""" scopes: list[ScopeContext] = [scope] if max_depth <= 0: return scopes current = scope.parent depth = 0 while current is not None and depth < max_depth: scopes.append(current) current = current.parent depth += 1 return scopes def _deduplicate( self, items: list[T], key_field: str, seen_keys: set[Any], ) -> list[T]: """Remove duplicate items based on a key field.""" unique: list[T] = [] for item in items: key = getattr(item, key_field, None) if key is None: # If no key, include the item unique.append(item) elif key not in seen_keys: seen_keys.add(key) unique.append(item) return unique def get_visible_scopes( self, scope: ScopeContext, ) -> list[ScopeContext]: """ Get all scopes visible from a given scope. A scope can see itself and all its ancestors (if inheritance allowed). Args: scope: Starting scope Returns: List of visible scopes (from most specific to most general) """ visible = [scope] current = scope.parent while current is not None: policy = self._manager.get_policy(current) if policy.allows_inherit(): visible.append(current) else: break # Stop at first non-inheritable scope current = current.parent return visible def find_write_scope( self, target_level: ScopeLevel, scope: ScopeContext, ) -> ScopeContext | None: """ Find the appropriate scope for writing at a target level. Walks up the hierarchy to find a scope at the target level that allows writing. Args: target_level: Desired scope level scope: Starting scope Returns: Scope to write to, or None if not found/not allowed """ # First check if current scope is at target level if scope.scope_type == target_level: policy = self._manager.get_policy(scope) return scope if policy.allows_write() else None # Check ancestors current = scope.parent while current is not None: if current.scope_type == target_level: policy = self._manager.get_policy(current) return current if policy.allows_write() else None current = current.parent return None def resolve_scope_from_memory( self, memory_type: str, project_id: str | None = None, agent_type_id: str | None = None, agent_instance_id: str | None = None, session_id: str | None = None, ) -> tuple[ScopeContext, ScopeLevel]: """ Resolve the appropriate scope for a memory operation. Different memory types have different scope requirements: - working: Session or Agent Instance - episodic: Agent Instance or Project - semantic: Project or Global - procedural: Agent Type or Project Args: memory_type: Type of memory project_id: Project ID agent_type_id: Agent type ID agent_instance_id: Agent instance ID session_id: Session ID Returns: Tuple of (scope context, recommended level) """ # Build full scope chain scope = self._manager.create_scope_from_ids( project_id=project_id if project_id else None, # type: ignore[arg-type] agent_type_id=agent_type_id if agent_type_id else None, # type: ignore[arg-type] agent_instance_id=agent_instance_id if agent_instance_id else None, # type: ignore[arg-type] session_id=session_id, ) # Determine recommended level based on memory type recommended = self._get_recommended_level(memory_type) return scope, recommended def _get_recommended_level(self, memory_type: str) -> ScopeLevel: """Get recommended scope level for a memory type.""" recommendations = { "working": ScopeLevel.SESSION, "episodic": ScopeLevel.AGENT_INSTANCE, "semantic": ScopeLevel.PROJECT, "procedural": ScopeLevel.AGENT_TYPE, } return recommendations.get(memory_type, ScopeLevel.PROJECT) def validate_write_access( self, scope: ScopeContext, memory_type: str, ) -> bool: """ Validate that writing is allowed for the given scope and memory type. Args: scope: Scope to validate memory_type: Type of memory to write Returns: True if write is allowed """ policy = self._manager.get_policy(scope) if not policy.allows_write(): return False if not policy.allows_memory_type(memory_type): return False return True def get_scope_chain( self, scope: ScopeContext, ) -> list[tuple[ScopeLevel, str]]: """ Get the scope chain as a list of (level, id) tuples. Args: scope: Scope to get chain for Returns: List of (level, id) tuples from root to leaf """ chain: list[tuple[ScopeLevel, str]] = [] # Get full hierarchy hierarchy = scope.get_hierarchy() for ctx in hierarchy: chain.append((ctx.scope_type, ctx.scope_id)) return chain @dataclass class ScopeFilter: """Filter for querying across scopes.""" scope_types: list[ScopeLevel] | None = None project_ids: list[str] | None = None agent_type_ids: list[str] | None = None include_global: bool = True def matches(self, scope: ScopeContext) -> bool: """Check if a scope matches this filter.""" if self.scope_types and scope.scope_type not in self.scope_types: return False if scope.scope_type == ScopeLevel.GLOBAL: return self.include_global if scope.scope_type == ScopeLevel.PROJECT: if self.project_ids and scope.scope_id not in self.project_ids: return False if scope.scope_type == ScopeLevel.AGENT_TYPE: if self.agent_type_ids and scope.scope_id not in self.agent_type_ids: return False return True # Singleton resolver instance _resolver: ScopeResolver | None = None def get_scope_resolver() -> ScopeResolver: """Get the singleton scope resolver instance.""" global _resolver if _resolver is None: _resolver = ScopeResolver() return _resolver