diff --git a/backend/app/services/memory/scoping/__init__.py b/backend/app/services/memory/scoping/__init__.py index 9684d60..9dffb6f 100644 --- a/backend/app/services/memory/scoping/__init__.py +++ b/backend/app/services/memory/scoping/__init__.py @@ -1,3 +1,4 @@ +# app/services/memory/scoping/__init__.py """ Memory Scoping @@ -5,4 +6,28 @@ Hierarchical scoping for memory with inheritance: Global -> Project -> Agent Type -> Agent Instance -> Session """ -# Will be populated in #93 +from .resolver import ( + ResolutionOptions, + ResolutionResult, + ScopeFilter, + ScopeResolver, + get_scope_resolver, +) +from .scope import ( + ScopeInfo, + ScopeManager, + ScopePolicy, + get_scope_manager, +) + +__all__ = [ + "ResolutionOptions", + "ResolutionResult", + "ScopeFilter", + "ScopeInfo", + "ScopeManager", + "ScopePolicy", + "ScopeResolver", + "get_scope_manager", + "get_scope_resolver", +] diff --git a/backend/app/services/memory/scoping/resolver.py b/backend/app/services/memory/scoping/resolver.py new file mode 100644 index 0000000..b130e7c --- /dev/null +++ b/backend/app/services/memory/scoping/resolver.py @@ -0,0 +1,390 @@ +# app/services/memory/scoping/resolver.py +""" +Scope Resolution. + +Provides utilities for resolving memory queries across scope hierarchies, +implementing inheritance and aggregation of memories from parent scopes. +""" + +import logging +from collections.abc import Callable +from dataclasses import dataclass, field +from typing import Any, TypeVar + +from app.services.memory.types import ScopeContext, ScopeLevel + +from .scope import ScopeManager, get_scope_manager + +logger = logging.getLogger(__name__) + +T = TypeVar("T") + + +@dataclass +class ResolutionResult[T]: + """Result of a scope resolution.""" + + items: list[T] + sources: list[ScopeContext] + total_from_each: dict[str, int] = field(default_factory=dict) + inherited_count: int = 0 + own_count: int = 0 + + @property + def total_count(self) -> int: + """Get total items from all sources.""" + return len(self.items) + + +@dataclass +class ResolutionOptions: + """Options for scope resolution.""" + + include_inherited: bool = True + max_inheritance_depth: int = 5 + limit_per_scope: int = 100 + total_limit: int = 500 + deduplicate: bool = True + deduplicate_key: str | None = None # Field to use for deduplication + + +class ScopeResolver: + """ + Resolves memory queries across scope hierarchies. + + Features: + - Traverse scope hierarchy for inherited memories + - Aggregate results from multiple scope levels + - Apply access control policies + - Support deduplication across scopes + """ + + def __init__( + self, + manager: ScopeManager | None = None, + ) -> None: + """ + Initialize the resolver. + + Args: + manager: Scope manager to use (defaults to singleton) + """ + self._manager = manager or get_scope_manager() + + def resolve( + self, + scope: ScopeContext, + fetcher: Callable[[ScopeContext, int], list[T]], + options: ResolutionOptions | None = None, + ) -> ResolutionResult[T]: + """ + Resolve memories for a scope, including inherited memories. + + Args: + scope: Starting scope + fetcher: Function to fetch items for a scope (scope, limit) -> items + options: Resolution options + + Returns: + Resolution result with items from all scopes + """ + opts = options or ResolutionOptions() + + all_items: list[T] = [] + sources: list[ScopeContext] = [] + counts: dict[str, int] = {} + seen_keys: set[Any] = set() + + # Collect scopes to query (starting from current, going up to ancestors) + scopes_to_query = self._collect_queryable_scopes( + scope=scope, + max_depth=opts.max_inheritance_depth if opts.include_inherited else 0, + ) + + own_count = 0 + inherited_count = 0 + remaining_limit = opts.total_limit + + for i, query_scope in enumerate(scopes_to_query): + if remaining_limit <= 0: + break + + # Check access policy + policy = self._manager.get_policy(query_scope) + if not policy.allows_read(): + continue + if i > 0 and not policy.allows_inherit(): + continue + + # Fetch items for this scope + scope_limit = min(opts.limit_per_scope, remaining_limit) + items = fetcher(query_scope, scope_limit) + + # Apply deduplication + if opts.deduplicate and opts.deduplicate_key: + items = self._deduplicate(items, opts.deduplicate_key, seen_keys) + + if items: + all_items.extend(items) + sources.append(query_scope) + key = f"{query_scope.scope_type.value}:{query_scope.scope_id}" + counts[key] = len(items) + + if i == 0: + own_count = len(items) + else: + inherited_count += len(items) + + remaining_limit -= len(items) + + logger.debug( + f"Resolved {len(all_items)} items from {len(sources)} scopes " + f"(own={own_count}, inherited={inherited_count})" + ) + + return ResolutionResult( + items=all_items[: opts.total_limit], + sources=sources, + total_from_each=counts, + own_count=own_count, + inherited_count=inherited_count, + ) + + def _collect_queryable_scopes( + self, + scope: ScopeContext, + max_depth: int, + ) -> list[ScopeContext]: + """Collect scopes to query, from current to ancestors.""" + scopes: list[ScopeContext] = [scope] + + if max_depth <= 0: + return scopes + + current = scope.parent + depth = 0 + + while current is not None and depth < max_depth: + scopes.append(current) + current = current.parent + depth += 1 + + return scopes + + def _deduplicate( + self, + items: list[T], + key_field: str, + seen_keys: set[Any], + ) -> list[T]: + """Remove duplicate items based on a key field.""" + unique: list[T] = [] + + for item in items: + key = getattr(item, key_field, None) + if key is None: + # If no key, include the item + unique.append(item) + elif key not in seen_keys: + seen_keys.add(key) + unique.append(item) + + return unique + + def get_visible_scopes( + self, + scope: ScopeContext, + ) -> list[ScopeContext]: + """ + Get all scopes visible from a given scope. + + A scope can see itself and all its ancestors (if inheritance allowed). + + Args: + scope: Starting scope + + Returns: + List of visible scopes (from most specific to most general) + """ + visible = [scope] + + current = scope.parent + while current is not None: + policy = self._manager.get_policy(current) + if policy.allows_inherit(): + visible.append(current) + else: + break # Stop at first non-inheritable scope + current = current.parent + + return visible + + def find_write_scope( + self, + target_level: ScopeLevel, + scope: ScopeContext, + ) -> ScopeContext | None: + """ + Find the appropriate scope for writing at a target level. + + Walks up the hierarchy to find a scope at the target level + that allows writing. + + Args: + target_level: Desired scope level + scope: Starting scope + + Returns: + Scope to write to, or None if not found/not allowed + """ + # First check if current scope is at target level + if scope.scope_type == target_level: + policy = self._manager.get_policy(scope) + return scope if policy.allows_write() else None + + # Check ancestors + current = scope.parent + while current is not None: + if current.scope_type == target_level: + policy = self._manager.get_policy(current) + return current if policy.allows_write() else None + current = current.parent + + return None + + def resolve_scope_from_memory( + self, + memory_type: str, + project_id: str | None = None, + agent_type_id: str | None = None, + agent_instance_id: str | None = None, + session_id: str | None = None, + ) -> tuple[ScopeContext, ScopeLevel]: + """ + Resolve the appropriate scope for a memory operation. + + Different memory types have different scope requirements: + - working: Session or Agent Instance + - episodic: Agent Instance or Project + - semantic: Project or Global + - procedural: Agent Type or Project + + Args: + memory_type: Type of memory + project_id: Project ID + agent_type_id: Agent type ID + agent_instance_id: Agent instance ID + session_id: Session ID + + Returns: + Tuple of (scope context, recommended level) + """ + # Build full scope chain + scope = self._manager.create_scope_from_ids( + project_id=project_id if project_id else None, # type: ignore[arg-type] + agent_type_id=agent_type_id if agent_type_id else None, # type: ignore[arg-type] + agent_instance_id=agent_instance_id if agent_instance_id else None, # type: ignore[arg-type] + session_id=session_id, + ) + + # Determine recommended level based on memory type + recommended = self._get_recommended_level(memory_type) + + return scope, recommended + + def _get_recommended_level(self, memory_type: str) -> ScopeLevel: + """Get recommended scope level for a memory type.""" + recommendations = { + "working": ScopeLevel.SESSION, + "episodic": ScopeLevel.AGENT_INSTANCE, + "semantic": ScopeLevel.PROJECT, + "procedural": ScopeLevel.AGENT_TYPE, + } + return recommendations.get(memory_type, ScopeLevel.PROJECT) + + def validate_write_access( + self, + scope: ScopeContext, + memory_type: str, + ) -> bool: + """ + Validate that writing is allowed for the given scope and memory type. + + Args: + scope: Scope to validate + memory_type: Type of memory to write + + Returns: + True if write is allowed + """ + policy = self._manager.get_policy(scope) + + if not policy.allows_write(): + return False + + if not policy.allows_memory_type(memory_type): + return False + + return True + + def get_scope_chain( + self, + scope: ScopeContext, + ) -> list[tuple[ScopeLevel, str]]: + """ + Get the scope chain as a list of (level, id) tuples. + + Args: + scope: Scope to get chain for + + Returns: + List of (level, id) tuples from root to leaf + """ + chain: list[tuple[ScopeLevel, str]] = [] + + # Get full hierarchy + hierarchy = scope.get_hierarchy() + for ctx in hierarchy: + chain.append((ctx.scope_type, ctx.scope_id)) + + return chain + + +@dataclass +class ScopeFilter: + """Filter for querying across scopes.""" + + scope_types: list[ScopeLevel] | None = None + project_ids: list[str] | None = None + agent_type_ids: list[str] | None = None + include_global: bool = True + + def matches(self, scope: ScopeContext) -> bool: + """Check if a scope matches this filter.""" + if self.scope_types and scope.scope_type not in self.scope_types: + return False + + if scope.scope_type == ScopeLevel.GLOBAL: + return self.include_global + + if scope.scope_type == ScopeLevel.PROJECT: + if self.project_ids and scope.scope_id not in self.project_ids: + return False + + if scope.scope_type == ScopeLevel.AGENT_TYPE: + if self.agent_type_ids and scope.scope_id not in self.agent_type_ids: + return False + + return True + + +# Singleton resolver instance +_resolver: ScopeResolver | None = None + + +def get_scope_resolver() -> ScopeResolver: + """Get the singleton scope resolver instance.""" + global _resolver + if _resolver is None: + _resolver = ScopeResolver() + return _resolver diff --git a/backend/app/services/memory/scoping/scope.py b/backend/app/services/memory/scoping/scope.py new file mode 100644 index 0000000..d6325f6 --- /dev/null +++ b/backend/app/services/memory/scoping/scope.py @@ -0,0 +1,460 @@ +# app/services/memory/scoping/scope.py +""" +Scope Management. + +Provides utilities for managing memory scopes with hierarchical inheritance: +Global -> Project -> Agent Type -> Agent Instance -> Session +""" + +import logging +from dataclasses import dataclass, field +from typing import Any, ClassVar +from uuid import UUID + +from app.services.memory.types import ScopeContext, ScopeLevel + +logger = logging.getLogger(__name__) + + +@dataclass +class ScopePolicy: + """Access control policy for a scope.""" + + scope_type: ScopeLevel + scope_id: str + can_read: bool = True + can_write: bool = True + can_inherit: bool = True + allowed_memory_types: list[str] = field(default_factory=lambda: ["all"]) + metadata: dict[str, Any] = field(default_factory=dict) + + def allows_read(self) -> bool: + """Check if reading is allowed.""" + return self.can_read + + def allows_write(self) -> bool: + """Check if writing is allowed.""" + return self.can_write + + def allows_inherit(self) -> bool: + """Check if inheritance from parent is allowed.""" + return self.can_inherit + + def allows_memory_type(self, memory_type: str) -> bool: + """Check if a specific memory type is allowed.""" + return ( + "all" in self.allowed_memory_types + or memory_type in self.allowed_memory_types + ) + + +@dataclass +class ScopeInfo: + """Information about a scope including its hierarchy.""" + + context: ScopeContext + policy: ScopePolicy + parent_info: "ScopeInfo | None" = None + child_count: int = 0 + memory_count: int = 0 + + @property + def depth(self) -> int: + """Get the depth of this scope in the hierarchy.""" + count = 0 + current = self.parent_info + while current is not None: + count += 1 + current = current.parent_info + return count + + +class ScopeManager: + """ + Manages memory scopes and their hierarchies. + + Provides: + - Scope creation and validation + - Hierarchy management + - Access control policy management + - Scope inheritance rules + """ + + # Order of scope levels from root to leaf + SCOPE_ORDER: ClassVar[list[ScopeLevel]] = [ + ScopeLevel.GLOBAL, + ScopeLevel.PROJECT, + ScopeLevel.AGENT_TYPE, + ScopeLevel.AGENT_INSTANCE, + ScopeLevel.SESSION, + ] + + def __init__(self) -> None: + """Initialize the scope manager.""" + # In-memory policy cache (would be backed by database in production) + self._policies: dict[str, ScopePolicy] = {} + self._default_policies = self._create_default_policies() + + def _create_default_policies(self) -> dict[ScopeLevel, ScopePolicy]: + """Create default policies for each scope level.""" + return { + ScopeLevel.GLOBAL: ScopePolicy( + scope_type=ScopeLevel.GLOBAL, + scope_id="global", + can_read=True, + can_write=False, # Global writes require special permission + can_inherit=True, + ), + ScopeLevel.PROJECT: ScopePolicy( + scope_type=ScopeLevel.PROJECT, + scope_id="default", + can_read=True, + can_write=True, + can_inherit=True, + ), + ScopeLevel.AGENT_TYPE: ScopePolicy( + scope_type=ScopeLevel.AGENT_TYPE, + scope_id="default", + can_read=True, + can_write=True, + can_inherit=True, + ), + ScopeLevel.AGENT_INSTANCE: ScopePolicy( + scope_type=ScopeLevel.AGENT_INSTANCE, + scope_id="default", + can_read=True, + can_write=True, + can_inherit=True, + ), + ScopeLevel.SESSION: ScopePolicy( + scope_type=ScopeLevel.SESSION, + scope_id="default", + can_read=True, + can_write=True, + can_inherit=True, + allowed_memory_types=["working"], # Sessions only allow working memory + ), + } + + def create_scope( + self, + scope_type: ScopeLevel, + scope_id: str, + parent: ScopeContext | None = None, + ) -> ScopeContext: + """ + Create a new scope context. + + Args: + scope_type: Level of the scope + scope_id: Unique identifier within the level + parent: Optional parent scope + + Returns: + Created scope context + + Raises: + ValueError: If scope hierarchy is invalid + """ + # Validate hierarchy + if parent is not None: + self._validate_parent_child(parent.scope_type, scope_type) + + # For non-global scopes without parent, auto-create parent chain + if parent is None and scope_type != ScopeLevel.GLOBAL: + parent = self._create_parent_chain(scope_type, scope_id) + + context = ScopeContext( + scope_type=scope_type, + scope_id=scope_id, + parent=parent, + ) + + logger.debug(f"Created scope: {scope_type.value}:{scope_id}") + return context + + def _validate_parent_child( + self, + parent_type: ScopeLevel, + child_type: ScopeLevel, + ) -> None: + """Validate that parent-child relationship is valid.""" + parent_idx = self.SCOPE_ORDER.index(parent_type) + child_idx = self.SCOPE_ORDER.index(child_type) + + if child_idx <= parent_idx: + raise ValueError( + f"Invalid scope hierarchy: {child_type.value} cannot be child of {parent_type.value}" + ) + + # Allow skipping levels (e.g., PROJECT -> SESSION is valid) + # This enables flexible scope structures + + def _create_parent_chain( + self, + target_type: ScopeLevel, + scope_id: str, + ) -> ScopeContext: + """Create parent scope chain up to target type.""" + target_idx = self.SCOPE_ORDER.index(target_type) + + # Start from global and build chain + current: ScopeContext | None = None + + for i in range(target_idx): + level = self.SCOPE_ORDER[i] + if level == ScopeLevel.GLOBAL: + level_id = "global" + else: + # Use a default ID for intermediate levels + level_id = f"default_{level.value}" + + current = ScopeContext( + scope_type=level, + scope_id=level_id, + parent=current, + ) + + return current # type: ignore[return-value] + + def create_scope_from_ids( + self, + project_id: UUID | None = None, + agent_type_id: UUID | None = None, + agent_instance_id: UUID | None = None, + session_id: str | None = None, + ) -> ScopeContext: + """ + Create a scope context from individual IDs. + + Automatically determines the most specific scope level + based on provided IDs. + + Args: + project_id: Project UUID + agent_type_id: Agent type UUID + agent_instance_id: Agent instance UUID + session_id: Session identifier + + Returns: + Scope context for the most specific level + """ + # Build scope chain from most general to most specific + current: ScopeContext = ScopeContext( + scope_type=ScopeLevel.GLOBAL, + scope_id="global", + parent=None, + ) + + if project_id is not None: + current = ScopeContext( + scope_type=ScopeLevel.PROJECT, + scope_id=str(project_id), + parent=current, + ) + + if agent_type_id is not None: + current = ScopeContext( + scope_type=ScopeLevel.AGENT_TYPE, + scope_id=str(agent_type_id), + parent=current, + ) + + if agent_instance_id is not None: + current = ScopeContext( + scope_type=ScopeLevel.AGENT_INSTANCE, + scope_id=str(agent_instance_id), + parent=current, + ) + + if session_id is not None: + current = ScopeContext( + scope_type=ScopeLevel.SESSION, + scope_id=session_id, + parent=current, + ) + + return current + + def get_policy( + self, + scope: ScopeContext, + ) -> ScopePolicy: + """ + Get the access policy for a scope. + + Args: + scope: Scope to get policy for + + Returns: + Policy for the scope + """ + key = self._scope_key(scope) + + if key in self._policies: + return self._policies[key] + + # Return default policy for the scope level + return self._default_policies.get( + scope.scope_type, + ScopePolicy( + scope_type=scope.scope_type, + scope_id=scope.scope_id, + ), + ) + + def set_policy( + self, + scope: ScopeContext, + policy: ScopePolicy, + ) -> None: + """ + Set the access policy for a scope. + + Args: + scope: Scope to set policy for + policy: Policy to apply + """ + key = self._scope_key(scope) + self._policies[key] = policy + logger.info(f"Set policy for scope {key}") + + def _scope_key(self, scope: ScopeContext) -> str: + """Generate a unique key for a scope.""" + return f"{scope.scope_type.value}:{scope.scope_id}" + + def get_scope_depth(self, scope_type: ScopeLevel) -> int: + """Get the depth of a scope level in the hierarchy.""" + return self.SCOPE_ORDER.index(scope_type) + + def get_parent_level(self, scope_type: ScopeLevel) -> ScopeLevel | None: + """Get the parent scope level for a given level.""" + idx = self.SCOPE_ORDER.index(scope_type) + if idx == 0: + return None + return self.SCOPE_ORDER[idx - 1] + + def get_child_level(self, scope_type: ScopeLevel) -> ScopeLevel | None: + """Get the child scope level for a given level.""" + idx = self.SCOPE_ORDER.index(scope_type) + if idx >= len(self.SCOPE_ORDER) - 1: + return None + return self.SCOPE_ORDER[idx + 1] + + def is_ancestor( + self, + potential_ancestor: ScopeContext, + descendant: ScopeContext, + ) -> bool: + """ + Check if one scope is an ancestor of another. + + Args: + potential_ancestor: Scope to check as ancestor + descendant: Scope to check as descendant + + Returns: + True if ancestor relationship exists + """ + current = descendant.parent + while current is not None: + if ( + current.scope_type == potential_ancestor.scope_type + and current.scope_id == potential_ancestor.scope_id + ): + return True + current = current.parent + return False + + def get_common_ancestor( + self, + scope_a: ScopeContext, + scope_b: ScopeContext, + ) -> ScopeContext | None: + """ + Find the nearest common ancestor of two scopes. + + Args: + scope_a: First scope + scope_b: Second scope + + Returns: + Common ancestor or None if none exists + """ + # Get ancestors of scope_a + ancestors_a: set[str] = set() + current: ScopeContext | None = scope_a + while current is not None: + ancestors_a.add(self._scope_key(current)) + current = current.parent + + # Find first ancestor of scope_b that's in ancestors_a + current = scope_b + while current is not None: + if self._scope_key(current) in ancestors_a: + return current + current = current.parent + + return None + + def can_access( + self, + accessor_scope: ScopeContext, + target_scope: ScopeContext, + operation: str = "read", + ) -> bool: + """ + Check if accessor scope can access target scope. + + Access rules: + - A scope can always access itself + - A scope can access ancestors (if inheritance allowed) + - A scope CANNOT access descendants (privacy) + - Sibling scopes cannot access each other + + Args: + accessor_scope: Scope attempting access + target_scope: Scope being accessed + operation: Type of operation (read/write) + + Returns: + True if access is allowed + """ + # Same scope - always allowed + if ( + accessor_scope.scope_type == target_scope.scope_type + and accessor_scope.scope_id == target_scope.scope_id + ): + policy = self.get_policy(target_scope) + if operation == "write": + return policy.allows_write() + return policy.allows_read() + + # Check if target is ancestor (inheritance) + if self.is_ancestor(target_scope, accessor_scope): + policy = self.get_policy(target_scope) + if not policy.allows_inherit(): + return False + if operation == "write": + return policy.allows_write() + return policy.allows_read() + + # Check if accessor is ancestor of target (downward access) + # This is NOT allowed - parents cannot access children's memories + if self.is_ancestor(accessor_scope, target_scope): + return False + + # Sibling scopes cannot access each other + return False + + +# Singleton manager instance +_manager: ScopeManager | None = None + + +def get_scope_manager() -> ScopeManager: + """Get the singleton scope manager instance.""" + global _manager + if _manager is None: + _manager = ScopeManager() + return _manager diff --git a/backend/tests/unit/services/memory/scoping/__init__.py b/backend/tests/unit/services/memory/scoping/__init__.py new file mode 100644 index 0000000..1744b4a --- /dev/null +++ b/backend/tests/unit/services/memory/scoping/__init__.py @@ -0,0 +1,2 @@ +# tests/unit/services/memory/scoping/__init__.py +"""Unit tests for memory scoping.""" diff --git a/backend/tests/unit/services/memory/scoping/test_resolver.py b/backend/tests/unit/services/memory/scoping/test_resolver.py new file mode 100644 index 0000000..a476170 --- /dev/null +++ b/backend/tests/unit/services/memory/scoping/test_resolver.py @@ -0,0 +1,653 @@ +# tests/unit/services/memory/scoping/test_resolver.py +"""Unit tests for scope resolution.""" + +from dataclasses import dataclass +from uuid import uuid4 + +import pytest + +from app.services.memory.scoping.resolver import ( + ResolutionOptions, + ResolutionResult, + ScopeFilter, + ScopeResolver, + get_scope_resolver, +) +from app.services.memory.scoping.scope import ScopeManager, ScopePolicy +from app.services.memory.types import ScopeContext, ScopeLevel + + +@dataclass +class MockItem: + """Mock item for testing resolution.""" + + id: str + name: str + scope_id: str + + +class TestResolutionResult: + """Tests for ResolutionResult dataclass.""" + + def test_total_count(self) -> None: + """Test total_count property.""" + result = ResolutionResult[MockItem]( + items=[ + MockItem(id="1", name="a", scope_id="s1"), + MockItem(id="2", name="b", scope_id="s2"), + ], + sources=[], + ) + + assert result.total_count == 2 + + def test_empty_result(self) -> None: + """Test empty result.""" + result = ResolutionResult[MockItem]( + items=[], + sources=[], + ) + + assert result.total_count == 0 + assert result.inherited_count == 0 + assert result.own_count == 0 + + +class TestResolutionOptions: + """Tests for ResolutionOptions dataclass.""" + + def test_default_values(self) -> None: + """Test default option values.""" + options = ResolutionOptions() + + assert options.include_inherited is True + assert options.max_inheritance_depth == 5 + assert options.limit_per_scope == 100 + assert options.total_limit == 500 + assert options.deduplicate is True + assert options.deduplicate_key is None + + def test_custom_values(self) -> None: + """Test custom option values.""" + options = ResolutionOptions( + include_inherited=False, + max_inheritance_depth=3, + limit_per_scope=50, + total_limit=200, + deduplicate=False, + deduplicate_key="id", + ) + + assert options.include_inherited is False + assert options.max_inheritance_depth == 3 + assert options.limit_per_scope == 50 + assert options.total_limit == 200 + assert options.deduplicate is False + assert options.deduplicate_key == "id" + + +class TestScopeResolver: + """Tests for ScopeResolver class.""" + + @pytest.fixture + def manager(self) -> ScopeManager: + """Create a scope manager.""" + return ScopeManager() + + @pytest.fixture + def resolver(self, manager: ScopeManager) -> ScopeResolver: + """Create a scope resolver.""" + return ScopeResolver(manager=manager) + + def test_resolve_single_scope( + self, + resolver: ScopeResolver, + manager: ScopeManager, + ) -> None: + """Test resolving from a single scope.""" + scope = manager.create_scope(ScopeLevel.PROJECT, "project-1") + + def fetcher(s: ScopeContext, limit: int) -> list[MockItem]: + if s.scope_id == "project-1": + return [MockItem(id="1", name="item1", scope_id="project-1")] + return [] + + result = resolver.resolve( + scope=scope, + fetcher=fetcher, + options=ResolutionOptions(include_inherited=False), + ) + + assert result.total_count == 1 + assert result.own_count == 1 + assert result.inherited_count == 0 + + def test_resolve_with_inheritance( + self, + resolver: ScopeResolver, + manager: ScopeManager, + ) -> None: + """Test resolving with scope inheritance.""" + global_scope = manager.create_scope(ScopeLevel.GLOBAL, "global") + project_scope = manager.create_scope( + ScopeLevel.PROJECT, "project-1", parent=global_scope + ) + + def fetcher(s: ScopeContext, limit: int) -> list[MockItem]: + if s.scope_id == "project-1": + return [MockItem(id="1", name="project-item", scope_id="project-1")] + elif s.scope_id == "global": + return [MockItem(id="2", name="global-item", scope_id="global")] + return [] + + result = resolver.resolve( + scope=project_scope, + fetcher=fetcher, + options=ResolutionOptions(include_inherited=True), + ) + + assert result.total_count == 2 + assert result.own_count == 1 + assert result.inherited_count == 1 + + def test_resolve_respects_depth_limit( + self, + resolver: ScopeResolver, + manager: ScopeManager, + ) -> None: + """Test that resolution respects max inheritance depth.""" + global_scope = manager.create_scope(ScopeLevel.GLOBAL, "global") + project_scope = manager.create_scope( + ScopeLevel.PROJECT, "project", parent=global_scope + ) + agent_scope = manager.create_scope( + ScopeLevel.AGENT_TYPE, "agent", parent=project_scope + ) + instance_scope = manager.create_scope( + ScopeLevel.AGENT_INSTANCE, "instance", parent=agent_scope + ) + session_scope = manager.create_scope( + ScopeLevel.SESSION, "session", parent=instance_scope + ) + + items_per_scope = { + "session": [MockItem(id="1", name="s", scope_id="session")], + "instance": [MockItem(id="2", name="i", scope_id="instance")], + "agent": [MockItem(id="3", name="a", scope_id="agent")], + "project": [MockItem(id="4", name="p", scope_id="project")], + "global": [MockItem(id="5", name="g", scope_id="global")], + } + + def fetcher(s: ScopeContext, limit: int) -> list[MockItem]: + return items_per_scope.get(s.scope_id, []) + + # Depth 1 should get session + instance + result = resolver.resolve( + scope=session_scope, + fetcher=fetcher, + options=ResolutionOptions(max_inheritance_depth=1), + ) + + assert result.total_count == 2 + + def test_resolve_respects_total_limit( + self, + resolver: ScopeResolver, + manager: ScopeManager, + ) -> None: + """Test that resolution respects total limit.""" + scope = manager.create_scope(ScopeLevel.PROJECT, "project") + + def fetcher(s: ScopeContext, limit: int) -> list[MockItem]: + return [ + MockItem(id=str(i), name=f"item-{i}", scope_id="project") + for i in range(10) + ] + + result = resolver.resolve( + scope=scope, + fetcher=fetcher, + options=ResolutionOptions(total_limit=5), + ) + + assert result.total_count == 5 + + def test_resolve_deduplicates_by_key( + self, + resolver: ScopeResolver, + manager: ScopeManager, + ) -> None: + """Test deduplication by key field.""" + global_scope = manager.create_scope(ScopeLevel.GLOBAL, "global") + project_scope = manager.create_scope( + ScopeLevel.PROJECT, "project", parent=global_scope + ) + + def fetcher(s: ScopeContext, limit: int) -> list[MockItem]: + if s.scope_id == "project": + return [MockItem(id="1", name="project-ver", scope_id="project")] + elif s.scope_id == "global": + # Same ID, should be deduplicated + return [MockItem(id="1", name="global-ver", scope_id="global")] + return [] + + result = resolver.resolve( + scope=project_scope, + fetcher=fetcher, + options=ResolutionOptions(deduplicate=True, deduplicate_key="id"), + ) + + # Should only have the project version (encountered first) + assert result.total_count == 1 + assert result.items[0].name == "project-ver" + + def test_resolve_skips_non_readable_scopes( + self, + resolver: ScopeResolver, + manager: ScopeManager, + ) -> None: + """Test that non-readable scopes are skipped.""" + global_scope = manager.create_scope(ScopeLevel.GLOBAL, "global") + project_scope = manager.create_scope( + ScopeLevel.PROJECT, "project", parent=global_scope + ) + + # Set global as non-readable + manager.set_policy( + global_scope, + ScopePolicy( + scope_type=ScopeLevel.GLOBAL, + scope_id="global", + can_read=False, + ), + ) + + def fetcher(s: ScopeContext, limit: int) -> list[MockItem]: + if s.scope_id == "project": + return [MockItem(id="1", name="project-item", scope_id="project")] + elif s.scope_id == "global": + return [MockItem(id="2", name="global-item", scope_id="global")] + return [] + + result = resolver.resolve( + scope=project_scope, + fetcher=fetcher, + ) + + # Should only have project item + assert result.total_count == 1 + assert result.items[0].scope_id == "project" + + def test_resolve_skips_non_inheritable_scopes( + self, + resolver: ScopeResolver, + manager: ScopeManager, + ) -> None: + """Test that non-inheritable parent scopes stop inheritance.""" + global_scope = manager.create_scope(ScopeLevel.GLOBAL, "global") + project_scope = manager.create_scope( + ScopeLevel.PROJECT, "project", parent=global_scope + ) + + # Set global as non-inheritable + manager.set_policy( + global_scope, + ScopePolicy( + scope_type=ScopeLevel.GLOBAL, + scope_id="global", + can_inherit=False, + ), + ) + + def fetcher(s: ScopeContext, limit: int) -> list[MockItem]: + if s.scope_id == "project": + return [MockItem(id="1", name="project-item", scope_id="project")] + elif s.scope_id == "global": + return [MockItem(id="2", name="global-item", scope_id="global")] + return [] + + result = resolver.resolve( + scope=project_scope, + fetcher=fetcher, + ) + + # Should only have project item + assert result.total_count == 1 + + def test_get_visible_scopes( + self, + resolver: ScopeResolver, + manager: ScopeManager, + ) -> None: + """Test getting visible scopes.""" + global_scope = manager.create_scope(ScopeLevel.GLOBAL, "global") + project_scope = manager.create_scope( + ScopeLevel.PROJECT, "project", parent=global_scope + ) + agent_scope = manager.create_scope( + ScopeLevel.AGENT_TYPE, "agent", parent=project_scope + ) + + visible = resolver.get_visible_scopes(agent_scope) + + assert len(visible) == 3 + assert visible[0].scope_id == "agent" + assert visible[1].scope_id == "project" + assert visible[2].scope_id == "global" + + def test_get_visible_scopes_stops_at_non_inheritable( + self, + resolver: ScopeResolver, + manager: ScopeManager, + ) -> None: + """Test that visible scopes stop at non-inheritable parent.""" + global_scope = manager.create_scope(ScopeLevel.GLOBAL, "global") + project_scope = manager.create_scope( + ScopeLevel.PROJECT, "project", parent=global_scope + ) + agent_scope = manager.create_scope( + ScopeLevel.AGENT_TYPE, "agent", parent=project_scope + ) + + # Make project non-inheritable + manager.set_policy( + project_scope, + ScopePolicy( + scope_type=ScopeLevel.PROJECT, + scope_id="project", + can_inherit=False, + ), + ) + + visible = resolver.get_visible_scopes(agent_scope) + + # Should stop at project (exclusive) + assert len(visible) == 1 + assert visible[0].scope_id == "agent" + + def test_find_write_scope_same_level( + self, + resolver: ScopeResolver, + manager: ScopeManager, + ) -> None: + """Test finding write scope at same level.""" + scope = manager.create_scope(ScopeLevel.PROJECT, "project") + + result = resolver.find_write_scope(ScopeLevel.PROJECT, scope) + + assert result is not None + assert result.scope_id == "project" + + def test_find_write_scope_ancestor( + self, + resolver: ScopeResolver, + manager: ScopeManager, + ) -> None: + """Test finding write scope in ancestors.""" + global_scope = manager.create_scope(ScopeLevel.GLOBAL, "global") + project_scope = manager.create_scope( + ScopeLevel.PROJECT, "project", parent=global_scope + ) + agent_scope = manager.create_scope( + ScopeLevel.AGENT_TYPE, "agent", parent=project_scope + ) + + result = resolver.find_write_scope(ScopeLevel.PROJECT, agent_scope) + + assert result is not None + assert result.scope_id == "project" + + def test_find_write_scope_not_found( + self, + resolver: ScopeResolver, + manager: ScopeManager, + ) -> None: + """Test finding write scope when not in hierarchy.""" + scope = manager.create_scope(ScopeLevel.PROJECT, "project") + + # Looking for session level, but we're at project + result = resolver.find_write_scope(ScopeLevel.SESSION, scope) + + assert result is None + + def test_find_write_scope_respects_write_policy( + self, + resolver: ScopeResolver, + manager: ScopeManager, + ) -> None: + """Test that find_write_scope respects write policy.""" + global_scope = manager.create_scope(ScopeLevel.GLOBAL, "global") + project_scope = manager.create_scope( + ScopeLevel.PROJECT, "project", parent=global_scope + ) + + # Make project read-only + manager.set_policy( + project_scope, + ScopePolicy( + scope_type=ScopeLevel.PROJECT, + scope_id="project", + can_write=False, + ), + ) + + result = resolver.find_write_scope(ScopeLevel.PROJECT, project_scope) + + assert result is None + + def test_resolve_scope_from_memory_working( + self, + resolver: ScopeResolver, + ) -> None: + """Test resolving scope for working memory.""" + project_id = str(uuid4()) + session_id = "session-123" + + scope, level = resolver.resolve_scope_from_memory( + memory_type="working", + project_id=project_id, + session_id=session_id, + ) + + assert scope.scope_type == ScopeLevel.SESSION + assert level == ScopeLevel.SESSION + + def test_resolve_scope_from_memory_episodic( + self, + resolver: ScopeResolver, + ) -> None: + """Test resolving scope for episodic memory.""" + project_id = str(uuid4()) + agent_instance_id = str(uuid4()) + + scope, level = resolver.resolve_scope_from_memory( + memory_type="episodic", + project_id=project_id, + agent_instance_id=agent_instance_id, + ) + + assert scope.scope_type == ScopeLevel.AGENT_INSTANCE + assert level == ScopeLevel.AGENT_INSTANCE + + def test_resolve_scope_from_memory_semantic( + self, + resolver: ScopeResolver, + ) -> None: + """Test resolving scope for semantic memory.""" + project_id = str(uuid4()) + + scope, level = resolver.resolve_scope_from_memory( + memory_type="semantic", + project_id=project_id, + ) + + assert scope.scope_type == ScopeLevel.PROJECT + assert level == ScopeLevel.PROJECT + + def test_resolve_scope_from_memory_procedural( + self, + resolver: ScopeResolver, + ) -> None: + """Test resolving scope for procedural memory.""" + project_id = str(uuid4()) + agent_type_id = str(uuid4()) + + scope, level = resolver.resolve_scope_from_memory( + memory_type="procedural", + project_id=project_id, + agent_type_id=agent_type_id, + ) + + assert scope.scope_type == ScopeLevel.AGENT_TYPE + assert level == ScopeLevel.AGENT_TYPE + + def test_validate_write_access_allowed( + self, + resolver: ScopeResolver, + manager: ScopeManager, + ) -> None: + """Test write access validation when allowed.""" + scope = manager.create_scope(ScopeLevel.PROJECT, "project") + + assert resolver.validate_write_access(scope, "semantic") is True + + def test_validate_write_access_denied_by_policy( + self, + resolver: ScopeResolver, + manager: ScopeManager, + ) -> None: + """Test write access denied by policy.""" + scope = manager.create_scope(ScopeLevel.PROJECT, "project") + + manager.set_policy( + scope, + ScopePolicy( + scope_type=ScopeLevel.PROJECT, + scope_id="project", + can_write=False, + ), + ) + + assert resolver.validate_write_access(scope, "semantic") is False + + def test_validate_write_access_denied_by_memory_type( + self, + resolver: ScopeResolver, + manager: ScopeManager, + ) -> None: + """Test write access denied by memory type restriction.""" + scope = manager.create_scope(ScopeLevel.PROJECT, "project") + + manager.set_policy( + scope, + ScopePolicy( + scope_type=ScopeLevel.PROJECT, + scope_id="project", + allowed_memory_types=["episodic"], # Only episodic allowed + ), + ) + + assert resolver.validate_write_access(scope, "semantic") is False + assert resolver.validate_write_access(scope, "episodic") is True + + def test_get_scope_chain( + self, + resolver: ScopeResolver, + manager: ScopeManager, + ) -> None: + """Test getting scope chain.""" + global_scope = manager.create_scope(ScopeLevel.GLOBAL, "global") + project_scope = manager.create_scope( + ScopeLevel.PROJECT, "project", parent=global_scope + ) + agent_scope = manager.create_scope( + ScopeLevel.AGENT_TYPE, "agent", parent=project_scope + ) + + chain = resolver.get_scope_chain(agent_scope) + + assert len(chain) == 3 + assert chain[0] == (ScopeLevel.GLOBAL, "global") + assert chain[1] == (ScopeLevel.PROJECT, "project") + assert chain[2] == (ScopeLevel.AGENT_TYPE, "agent") + + +class TestScopeFilter: + """Tests for ScopeFilter dataclass.""" + + def test_default_values(self) -> None: + """Test default filter values.""" + filter_ = ScopeFilter() + + assert filter_.scope_types is None + assert filter_.project_ids is None + assert filter_.agent_type_ids is None + assert filter_.include_global is True + + def test_matches_global_scope(self) -> None: + """Test matching global scope.""" + scope = ScopeContext( + scope_type=ScopeLevel.GLOBAL, + scope_id="global", + ) + + filter_ = ScopeFilter(include_global=True) + assert filter_.matches(scope) is True + + filter_ = ScopeFilter(include_global=False) + assert filter_.matches(scope) is False + + def test_matches_scope_type(self) -> None: + """Test matching by scope type.""" + scope = ScopeContext( + scope_type=ScopeLevel.PROJECT, + scope_id="project-1", + ) + + filter_ = ScopeFilter(scope_types=[ScopeLevel.PROJECT]) + assert filter_.matches(scope) is True + + filter_ = ScopeFilter(scope_types=[ScopeLevel.AGENT_TYPE]) + assert filter_.matches(scope) is False + + def test_matches_project_id(self) -> None: + """Test matching by project ID.""" + scope = ScopeContext( + scope_type=ScopeLevel.PROJECT, + scope_id="project-1", + ) + + filter_ = ScopeFilter(project_ids=["project-1", "project-2"]) + assert filter_.matches(scope) is True + + filter_ = ScopeFilter(project_ids=["project-3"]) + assert filter_.matches(scope) is False + + def test_matches_agent_type_id(self) -> None: + """Test matching by agent type ID.""" + scope = ScopeContext( + scope_type=ScopeLevel.AGENT_TYPE, + scope_id="agent-1", + ) + + filter_ = ScopeFilter(agent_type_ids=["agent-1"]) + assert filter_.matches(scope) is True + + filter_ = ScopeFilter(agent_type_ids=["agent-2"]) + assert filter_.matches(scope) is False + + +class TestGetScopeResolver: + """Tests for singleton getter.""" + + def test_returns_instance(self) -> None: + """Test that getter returns instance.""" + resolver = get_scope_resolver() + assert resolver is not None + assert isinstance(resolver, ScopeResolver) + + def test_returns_same_instance(self) -> None: + """Test that getter returns same instance.""" + resolver1 = get_scope_resolver() + resolver2 = get_scope_resolver() + assert resolver1 is resolver2 diff --git a/backend/tests/unit/services/memory/scoping/test_scope.py b/backend/tests/unit/services/memory/scoping/test_scope.py new file mode 100644 index 0000000..aee8a59 --- /dev/null +++ b/backend/tests/unit/services/memory/scoping/test_scope.py @@ -0,0 +1,361 @@ +# tests/unit/services/memory/scoping/test_scope.py +"""Unit tests for scope management.""" + +from uuid import uuid4 + +import pytest + +from app.services.memory.scoping.scope import ( + ScopeManager, + ScopePolicy, + get_scope_manager, +) +from app.services.memory.types import ScopeLevel + + +class TestScopePolicy: + """Tests for ScopePolicy dataclass.""" + + def test_default_values(self) -> None: + """Test default policy values.""" + policy = ScopePolicy( + scope_type=ScopeLevel.PROJECT, + scope_id="test-project", + ) + + assert policy.can_read is True + assert policy.can_write is True + assert policy.can_inherit is True + assert policy.allowed_memory_types == ["all"] + + def test_allows_read(self) -> None: + """Test allows_read method.""" + policy = ScopePolicy( + scope_type=ScopeLevel.PROJECT, + scope_id="test", + can_read=True, + ) + assert policy.allows_read() is True + + policy.can_read = False + assert policy.allows_read() is False + + def test_allows_write(self) -> None: + """Test allows_write method.""" + policy = ScopePolicy( + scope_type=ScopeLevel.PROJECT, + scope_id="test", + can_write=True, + ) + assert policy.allows_write() is True + + policy.can_write = False + assert policy.allows_write() is False + + def test_allows_inherit(self) -> None: + """Test allows_inherit method.""" + policy = ScopePolicy( + scope_type=ScopeLevel.PROJECT, + scope_id="test", + can_inherit=True, + ) + assert policy.allows_inherit() is True + + policy.can_inherit = False + assert policy.allows_inherit() is False + + def test_allows_memory_type(self) -> None: + """Test allows_memory_type method.""" + policy = ScopePolicy( + scope_type=ScopeLevel.PROJECT, + scope_id="test", + allowed_memory_types=["all"], + ) + assert policy.allows_memory_type("working") is True + assert policy.allows_memory_type("episodic") is True + + policy.allowed_memory_types = ["working", "episodic"] + assert policy.allows_memory_type("working") is True + assert policy.allows_memory_type("episodic") is True + assert policy.allows_memory_type("semantic") is False + + +class TestScopeManager: + """Tests for ScopeManager class.""" + + @pytest.fixture + def manager(self) -> ScopeManager: + """Create a scope manager.""" + return ScopeManager() + + def test_create_global_scope( + self, + manager: ScopeManager, + ) -> None: + """Test creating a global scope.""" + scope = manager.create_scope( + scope_type=ScopeLevel.GLOBAL, + scope_id="global", + ) + + assert scope.scope_type == ScopeLevel.GLOBAL + assert scope.scope_id == "global" + assert scope.parent is None + + def test_create_project_scope( + self, + manager: ScopeManager, + ) -> None: + """Test creating a project scope.""" + global_scope = manager.create_scope( + scope_type=ScopeLevel.GLOBAL, + scope_id="global", + ) + + project_scope = manager.create_scope( + scope_type=ScopeLevel.PROJECT, + scope_id="project-1", + parent=global_scope, + ) + + assert project_scope.scope_type == ScopeLevel.PROJECT + assert project_scope.scope_id == "project-1" + assert project_scope.parent is global_scope + + def test_create_scope_auto_parent( + self, + manager: ScopeManager, + ) -> None: + """Test that non-global scopes auto-create parent chain.""" + scope = manager.create_scope( + scope_type=ScopeLevel.PROJECT, + scope_id="test-project", + ) + + assert scope.scope_type == ScopeLevel.PROJECT + assert scope.parent is not None + assert scope.parent.scope_type == ScopeLevel.GLOBAL + + def test_create_scope_invalid_hierarchy( + self, + manager: ScopeManager, + ) -> None: + """Test that invalid hierarchy raises error.""" + project_scope = manager.create_scope( + scope_type=ScopeLevel.PROJECT, + scope_id="project-1", + ) + + with pytest.raises(ValueError, match="Invalid scope hierarchy"): + manager.create_scope( + scope_type=ScopeLevel.GLOBAL, + scope_id="global", + parent=project_scope, + ) + + def test_create_scope_from_ids( + self, + manager: ScopeManager, + ) -> None: + """Test creating scope from individual IDs.""" + project_id = uuid4() + agent_type_id = uuid4() + + scope = manager.create_scope_from_ids( + project_id=project_id, + agent_type_id=agent_type_id, + ) + + assert scope.scope_type == ScopeLevel.AGENT_TYPE + assert scope.scope_id == str(agent_type_id) + assert scope.parent is not None + assert scope.parent.scope_type == ScopeLevel.PROJECT + + def test_create_scope_from_ids_with_session( + self, + manager: ScopeManager, + ) -> None: + """Test creating scope with session ID.""" + project_id = uuid4() + session_id = "session-123" + + scope = manager.create_scope_from_ids( + project_id=project_id, + session_id=session_id, + ) + + assert scope.scope_type == ScopeLevel.SESSION + assert scope.scope_id == session_id + + def test_get_default_policy( + self, + manager: ScopeManager, + ) -> None: + """Test getting default policy.""" + scope = manager.create_scope( + scope_type=ScopeLevel.PROJECT, + scope_id="test-project", + ) + + policy = manager.get_policy(scope) + + assert policy.can_read is True + assert policy.can_write is True + + def test_set_and_get_policy( + self, + manager: ScopeManager, + ) -> None: + """Test setting and retrieving a policy.""" + scope = manager.create_scope( + scope_type=ScopeLevel.PROJECT, + scope_id="test-project", + ) + + custom_policy = ScopePolicy( + scope_type=ScopeLevel.PROJECT, + scope_id="test-project", + can_write=False, + ) + + manager.set_policy(scope, custom_policy) + retrieved = manager.get_policy(scope) + + assert retrieved.can_write is False + + def test_get_scope_depth( + self, + manager: ScopeManager, + ) -> None: + """Test getting scope depth.""" + assert manager.get_scope_depth(ScopeLevel.GLOBAL) == 0 + assert manager.get_scope_depth(ScopeLevel.PROJECT) == 1 + assert manager.get_scope_depth(ScopeLevel.AGENT_TYPE) == 2 + assert manager.get_scope_depth(ScopeLevel.AGENT_INSTANCE) == 3 + assert manager.get_scope_depth(ScopeLevel.SESSION) == 4 + + def test_get_parent_level( + self, + manager: ScopeManager, + ) -> None: + """Test getting parent level.""" + assert manager.get_parent_level(ScopeLevel.GLOBAL) is None + assert manager.get_parent_level(ScopeLevel.PROJECT) == ScopeLevel.GLOBAL + assert manager.get_parent_level(ScopeLevel.AGENT_TYPE) == ScopeLevel.PROJECT + assert manager.get_parent_level(ScopeLevel.SESSION) == ScopeLevel.AGENT_INSTANCE + + def test_get_child_level( + self, + manager: ScopeManager, + ) -> None: + """Test getting child level.""" + assert manager.get_child_level(ScopeLevel.GLOBAL) == ScopeLevel.PROJECT + assert manager.get_child_level(ScopeLevel.PROJECT) == ScopeLevel.AGENT_TYPE + assert manager.get_child_level(ScopeLevel.SESSION) is None + + def test_is_ancestor( + self, + manager: ScopeManager, + ) -> None: + """Test ancestor checking.""" + global_scope = manager.create_scope(ScopeLevel.GLOBAL, "global") + project_scope = manager.create_scope( + ScopeLevel.PROJECT, "project", parent=global_scope + ) + agent_scope = manager.create_scope( + ScopeLevel.AGENT_TYPE, "agent", parent=project_scope + ) + + assert manager.is_ancestor(global_scope, agent_scope) is True + assert manager.is_ancestor(project_scope, agent_scope) is True + assert manager.is_ancestor(agent_scope, global_scope) is False + assert manager.is_ancestor(agent_scope, project_scope) is False + + def test_get_common_ancestor( + self, + manager: ScopeManager, + ) -> None: + """Test finding common ancestor.""" + global_scope = manager.create_scope(ScopeLevel.GLOBAL, "global") + project_scope = manager.create_scope( + ScopeLevel.PROJECT, "project", parent=global_scope + ) + agent1 = manager.create_scope( + ScopeLevel.AGENT_TYPE, "agent1", parent=project_scope + ) + agent2 = manager.create_scope( + ScopeLevel.AGENT_TYPE, "agent2", parent=project_scope + ) + + common = manager.get_common_ancestor(agent1, agent2) + + assert common is not None + assert common.scope_type == ScopeLevel.PROJECT + + def test_can_access_same_scope( + self, + manager: ScopeManager, + ) -> None: + """Test access to same scope.""" + scope = manager.create_scope(ScopeLevel.PROJECT, "project") + + assert manager.can_access(scope, scope) is True + assert manager.can_access(scope, scope, "write") is True + + def test_can_access_ancestor( + self, + manager: ScopeManager, + ) -> None: + """Test access to ancestor scope.""" + global_scope = manager.create_scope(ScopeLevel.GLOBAL, "global") + project_scope = manager.create_scope( + ScopeLevel.PROJECT, "project", parent=global_scope + ) + + # Child can read from parent + assert manager.can_access(project_scope, global_scope, "read") is True + + def test_cannot_access_descendant( + self, + manager: ScopeManager, + ) -> None: + """Test that parent cannot access child scope.""" + global_scope = manager.create_scope(ScopeLevel.GLOBAL, "global") + project_scope = manager.create_scope( + ScopeLevel.PROJECT, "project", parent=global_scope + ) + + # Parent cannot access child + assert manager.can_access(global_scope, project_scope) is False + + def test_cannot_access_sibling( + self, + manager: ScopeManager, + ) -> None: + """Test that sibling scopes cannot access each other.""" + global_scope = manager.create_scope(ScopeLevel.GLOBAL, "global") + project1 = manager.create_scope( + ScopeLevel.PROJECT, "project1", parent=global_scope + ) + project2 = manager.create_scope( + ScopeLevel.PROJECT, "project2", parent=global_scope + ) + + assert manager.can_access(project1, project2) is False + assert manager.can_access(project2, project1) is False + + +class TestGetScopeManager: + """Tests for singleton getter.""" + + def test_returns_instance(self) -> None: + """Test that getter returns instance.""" + manager = get_scope_manager() + assert manager is not None + assert isinstance(manager, ScopeManager) + + def test_returns_same_instance(self) -> None: + """Test that getter returns same instance.""" + manager1 = get_scope_manager() + manager2 = get_scope_manager() + assert manager1 is manager2