feat(memory): implement memory scoping with hierarchy and access control (#93)

Add scope management system for hierarchical memory access:
- ScopeManager with hierarchy: Global → Project → Agent Type → Agent Instance → Session
- ScopePolicy for access control (read, write, inherit permissions)
- ScopeResolver for resolving queries across scope hierarchies with inheritance
- ScopeFilter for filtering scopes by type, project, or agent
- Access control enforcement with parent scope visibility
- Deduplication support during resolution across scopes

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
2026-01-05 02:39:22 +01:00
parent b818f17418
commit 48ecb40f18
6 changed files with 1892 additions and 1 deletions

View File

@@ -1,3 +1,4 @@
# app/services/memory/scoping/__init__.py
"""
Memory Scoping
@@ -5,4 +6,28 @@ Hierarchical scoping for memory with inheritance:
Global -> Project -> Agent Type -> Agent Instance -> Session
"""
# Will be populated in #93
from .resolver import (
ResolutionOptions,
ResolutionResult,
ScopeFilter,
ScopeResolver,
get_scope_resolver,
)
from .scope import (
ScopeInfo,
ScopeManager,
ScopePolicy,
get_scope_manager,
)
__all__ = [
"ResolutionOptions",
"ResolutionResult",
"ScopeFilter",
"ScopeInfo",
"ScopeManager",
"ScopePolicy",
"ScopeResolver",
"get_scope_manager",
"get_scope_resolver",
]

View File

@@ -0,0 +1,390 @@
# app/services/memory/scoping/resolver.py
"""
Scope Resolution.
Provides utilities for resolving memory queries across scope hierarchies,
implementing inheritance and aggregation of memories from parent scopes.
"""
import logging
from collections.abc import Callable
from dataclasses import dataclass, field
from typing import Any, TypeVar
from app.services.memory.types import ScopeContext, ScopeLevel
from .scope import ScopeManager, get_scope_manager
logger = logging.getLogger(__name__)
T = TypeVar("T")
@dataclass
class ResolutionResult[T]:
"""Result of a scope resolution."""
items: list[T]
sources: list[ScopeContext]
total_from_each: dict[str, int] = field(default_factory=dict)
inherited_count: int = 0
own_count: int = 0
@property
def total_count(self) -> int:
"""Get total items from all sources."""
return len(self.items)
@dataclass
class ResolutionOptions:
"""Options for scope resolution."""
include_inherited: bool = True
max_inheritance_depth: int = 5
limit_per_scope: int = 100
total_limit: int = 500
deduplicate: bool = True
deduplicate_key: str | None = None # Field to use for deduplication
class ScopeResolver:
"""
Resolves memory queries across scope hierarchies.
Features:
- Traverse scope hierarchy for inherited memories
- Aggregate results from multiple scope levels
- Apply access control policies
- Support deduplication across scopes
"""
def __init__(
self,
manager: ScopeManager | None = None,
) -> None:
"""
Initialize the resolver.
Args:
manager: Scope manager to use (defaults to singleton)
"""
self._manager = manager or get_scope_manager()
def resolve(
self,
scope: ScopeContext,
fetcher: Callable[[ScopeContext, int], list[T]],
options: ResolutionOptions | None = None,
) -> ResolutionResult[T]:
"""
Resolve memories for a scope, including inherited memories.
Args:
scope: Starting scope
fetcher: Function to fetch items for a scope (scope, limit) -> items
options: Resolution options
Returns:
Resolution result with items from all scopes
"""
opts = options or ResolutionOptions()
all_items: list[T] = []
sources: list[ScopeContext] = []
counts: dict[str, int] = {}
seen_keys: set[Any] = set()
# Collect scopes to query (starting from current, going up to ancestors)
scopes_to_query = self._collect_queryable_scopes(
scope=scope,
max_depth=opts.max_inheritance_depth if opts.include_inherited else 0,
)
own_count = 0
inherited_count = 0
remaining_limit = opts.total_limit
for i, query_scope in enumerate(scopes_to_query):
if remaining_limit <= 0:
break
# Check access policy
policy = self._manager.get_policy(query_scope)
if not policy.allows_read():
continue
if i > 0 and not policy.allows_inherit():
continue
# Fetch items for this scope
scope_limit = min(opts.limit_per_scope, remaining_limit)
items = fetcher(query_scope, scope_limit)
# Apply deduplication
if opts.deduplicate and opts.deduplicate_key:
items = self._deduplicate(items, opts.deduplicate_key, seen_keys)
if items:
all_items.extend(items)
sources.append(query_scope)
key = f"{query_scope.scope_type.value}:{query_scope.scope_id}"
counts[key] = len(items)
if i == 0:
own_count = len(items)
else:
inherited_count += len(items)
remaining_limit -= len(items)
logger.debug(
f"Resolved {len(all_items)} items from {len(sources)} scopes "
f"(own={own_count}, inherited={inherited_count})"
)
return ResolutionResult(
items=all_items[: opts.total_limit],
sources=sources,
total_from_each=counts,
own_count=own_count,
inherited_count=inherited_count,
)
def _collect_queryable_scopes(
self,
scope: ScopeContext,
max_depth: int,
) -> list[ScopeContext]:
"""Collect scopes to query, from current to ancestors."""
scopes: list[ScopeContext] = [scope]
if max_depth <= 0:
return scopes
current = scope.parent
depth = 0
while current is not None and depth < max_depth:
scopes.append(current)
current = current.parent
depth += 1
return scopes
def _deduplicate(
self,
items: list[T],
key_field: str,
seen_keys: set[Any],
) -> list[T]:
"""Remove duplicate items based on a key field."""
unique: list[T] = []
for item in items:
key = getattr(item, key_field, None)
if key is None:
# If no key, include the item
unique.append(item)
elif key not in seen_keys:
seen_keys.add(key)
unique.append(item)
return unique
def get_visible_scopes(
self,
scope: ScopeContext,
) -> list[ScopeContext]:
"""
Get all scopes visible from a given scope.
A scope can see itself and all its ancestors (if inheritance allowed).
Args:
scope: Starting scope
Returns:
List of visible scopes (from most specific to most general)
"""
visible = [scope]
current = scope.parent
while current is not None:
policy = self._manager.get_policy(current)
if policy.allows_inherit():
visible.append(current)
else:
break # Stop at first non-inheritable scope
current = current.parent
return visible
def find_write_scope(
self,
target_level: ScopeLevel,
scope: ScopeContext,
) -> ScopeContext | None:
"""
Find the appropriate scope for writing at a target level.
Walks up the hierarchy to find a scope at the target level
that allows writing.
Args:
target_level: Desired scope level
scope: Starting scope
Returns:
Scope to write to, or None if not found/not allowed
"""
# First check if current scope is at target level
if scope.scope_type == target_level:
policy = self._manager.get_policy(scope)
return scope if policy.allows_write() else None
# Check ancestors
current = scope.parent
while current is not None:
if current.scope_type == target_level:
policy = self._manager.get_policy(current)
return current if policy.allows_write() else None
current = current.parent
return None
def resolve_scope_from_memory(
self,
memory_type: str,
project_id: str | None = None,
agent_type_id: str | None = None,
agent_instance_id: str | None = None,
session_id: str | None = None,
) -> tuple[ScopeContext, ScopeLevel]:
"""
Resolve the appropriate scope for a memory operation.
Different memory types have different scope requirements:
- working: Session or Agent Instance
- episodic: Agent Instance or Project
- semantic: Project or Global
- procedural: Agent Type or Project
Args:
memory_type: Type of memory
project_id: Project ID
agent_type_id: Agent type ID
agent_instance_id: Agent instance ID
session_id: Session ID
Returns:
Tuple of (scope context, recommended level)
"""
# Build full scope chain
scope = self._manager.create_scope_from_ids(
project_id=project_id if project_id else None, # type: ignore[arg-type]
agent_type_id=agent_type_id if agent_type_id else None, # type: ignore[arg-type]
agent_instance_id=agent_instance_id if agent_instance_id else None, # type: ignore[arg-type]
session_id=session_id,
)
# Determine recommended level based on memory type
recommended = self._get_recommended_level(memory_type)
return scope, recommended
def _get_recommended_level(self, memory_type: str) -> ScopeLevel:
"""Get recommended scope level for a memory type."""
recommendations = {
"working": ScopeLevel.SESSION,
"episodic": ScopeLevel.AGENT_INSTANCE,
"semantic": ScopeLevel.PROJECT,
"procedural": ScopeLevel.AGENT_TYPE,
}
return recommendations.get(memory_type, ScopeLevel.PROJECT)
def validate_write_access(
self,
scope: ScopeContext,
memory_type: str,
) -> bool:
"""
Validate that writing is allowed for the given scope and memory type.
Args:
scope: Scope to validate
memory_type: Type of memory to write
Returns:
True if write is allowed
"""
policy = self._manager.get_policy(scope)
if not policy.allows_write():
return False
if not policy.allows_memory_type(memory_type):
return False
return True
def get_scope_chain(
self,
scope: ScopeContext,
) -> list[tuple[ScopeLevel, str]]:
"""
Get the scope chain as a list of (level, id) tuples.
Args:
scope: Scope to get chain for
Returns:
List of (level, id) tuples from root to leaf
"""
chain: list[tuple[ScopeLevel, str]] = []
# Get full hierarchy
hierarchy = scope.get_hierarchy()
for ctx in hierarchy:
chain.append((ctx.scope_type, ctx.scope_id))
return chain
@dataclass
class ScopeFilter:
"""Filter for querying across scopes."""
scope_types: list[ScopeLevel] | None = None
project_ids: list[str] | None = None
agent_type_ids: list[str] | None = None
include_global: bool = True
def matches(self, scope: ScopeContext) -> bool:
"""Check if a scope matches this filter."""
if self.scope_types and scope.scope_type not in self.scope_types:
return False
if scope.scope_type == ScopeLevel.GLOBAL:
return self.include_global
if scope.scope_type == ScopeLevel.PROJECT:
if self.project_ids and scope.scope_id not in self.project_ids:
return False
if scope.scope_type == ScopeLevel.AGENT_TYPE:
if self.agent_type_ids and scope.scope_id not in self.agent_type_ids:
return False
return True
# Singleton resolver instance
_resolver: ScopeResolver | None = None
def get_scope_resolver() -> ScopeResolver:
"""Get the singleton scope resolver instance."""
global _resolver
if _resolver is None:
_resolver = ScopeResolver()
return _resolver

View File

@@ -0,0 +1,460 @@
# app/services/memory/scoping/scope.py
"""
Scope Management.
Provides utilities for managing memory scopes with hierarchical inheritance:
Global -> Project -> Agent Type -> Agent Instance -> Session
"""
import logging
from dataclasses import dataclass, field
from typing import Any, ClassVar
from uuid import UUID
from app.services.memory.types import ScopeContext, ScopeLevel
logger = logging.getLogger(__name__)
@dataclass
class ScopePolicy:
"""Access control policy for a scope."""
scope_type: ScopeLevel
scope_id: str
can_read: bool = True
can_write: bool = True
can_inherit: bool = True
allowed_memory_types: list[str] = field(default_factory=lambda: ["all"])
metadata: dict[str, Any] = field(default_factory=dict)
def allows_read(self) -> bool:
"""Check if reading is allowed."""
return self.can_read
def allows_write(self) -> bool:
"""Check if writing is allowed."""
return self.can_write
def allows_inherit(self) -> bool:
"""Check if inheritance from parent is allowed."""
return self.can_inherit
def allows_memory_type(self, memory_type: str) -> bool:
"""Check if a specific memory type is allowed."""
return (
"all" in self.allowed_memory_types
or memory_type in self.allowed_memory_types
)
@dataclass
class ScopeInfo:
"""Information about a scope including its hierarchy."""
context: ScopeContext
policy: ScopePolicy
parent_info: "ScopeInfo | None" = None
child_count: int = 0
memory_count: int = 0
@property
def depth(self) -> int:
"""Get the depth of this scope in the hierarchy."""
count = 0
current = self.parent_info
while current is not None:
count += 1
current = current.parent_info
return count
class ScopeManager:
"""
Manages memory scopes and their hierarchies.
Provides:
- Scope creation and validation
- Hierarchy management
- Access control policy management
- Scope inheritance rules
"""
# Order of scope levels from root to leaf
SCOPE_ORDER: ClassVar[list[ScopeLevel]] = [
ScopeLevel.GLOBAL,
ScopeLevel.PROJECT,
ScopeLevel.AGENT_TYPE,
ScopeLevel.AGENT_INSTANCE,
ScopeLevel.SESSION,
]
def __init__(self) -> None:
"""Initialize the scope manager."""
# In-memory policy cache (would be backed by database in production)
self._policies: dict[str, ScopePolicy] = {}
self._default_policies = self._create_default_policies()
def _create_default_policies(self) -> dict[ScopeLevel, ScopePolicy]:
"""Create default policies for each scope level."""
return {
ScopeLevel.GLOBAL: ScopePolicy(
scope_type=ScopeLevel.GLOBAL,
scope_id="global",
can_read=True,
can_write=False, # Global writes require special permission
can_inherit=True,
),
ScopeLevel.PROJECT: ScopePolicy(
scope_type=ScopeLevel.PROJECT,
scope_id="default",
can_read=True,
can_write=True,
can_inherit=True,
),
ScopeLevel.AGENT_TYPE: ScopePolicy(
scope_type=ScopeLevel.AGENT_TYPE,
scope_id="default",
can_read=True,
can_write=True,
can_inherit=True,
),
ScopeLevel.AGENT_INSTANCE: ScopePolicy(
scope_type=ScopeLevel.AGENT_INSTANCE,
scope_id="default",
can_read=True,
can_write=True,
can_inherit=True,
),
ScopeLevel.SESSION: ScopePolicy(
scope_type=ScopeLevel.SESSION,
scope_id="default",
can_read=True,
can_write=True,
can_inherit=True,
allowed_memory_types=["working"], # Sessions only allow working memory
),
}
def create_scope(
self,
scope_type: ScopeLevel,
scope_id: str,
parent: ScopeContext | None = None,
) -> ScopeContext:
"""
Create a new scope context.
Args:
scope_type: Level of the scope
scope_id: Unique identifier within the level
parent: Optional parent scope
Returns:
Created scope context
Raises:
ValueError: If scope hierarchy is invalid
"""
# Validate hierarchy
if parent is not None:
self._validate_parent_child(parent.scope_type, scope_type)
# For non-global scopes without parent, auto-create parent chain
if parent is None and scope_type != ScopeLevel.GLOBAL:
parent = self._create_parent_chain(scope_type, scope_id)
context = ScopeContext(
scope_type=scope_type,
scope_id=scope_id,
parent=parent,
)
logger.debug(f"Created scope: {scope_type.value}:{scope_id}")
return context
def _validate_parent_child(
self,
parent_type: ScopeLevel,
child_type: ScopeLevel,
) -> None:
"""Validate that parent-child relationship is valid."""
parent_idx = self.SCOPE_ORDER.index(parent_type)
child_idx = self.SCOPE_ORDER.index(child_type)
if child_idx <= parent_idx:
raise ValueError(
f"Invalid scope hierarchy: {child_type.value} cannot be child of {parent_type.value}"
)
# Allow skipping levels (e.g., PROJECT -> SESSION is valid)
# This enables flexible scope structures
def _create_parent_chain(
self,
target_type: ScopeLevel,
scope_id: str,
) -> ScopeContext:
"""Create parent scope chain up to target type."""
target_idx = self.SCOPE_ORDER.index(target_type)
# Start from global and build chain
current: ScopeContext | None = None
for i in range(target_idx):
level = self.SCOPE_ORDER[i]
if level == ScopeLevel.GLOBAL:
level_id = "global"
else:
# Use a default ID for intermediate levels
level_id = f"default_{level.value}"
current = ScopeContext(
scope_type=level,
scope_id=level_id,
parent=current,
)
return current # type: ignore[return-value]
def create_scope_from_ids(
self,
project_id: UUID | None = None,
agent_type_id: UUID | None = None,
agent_instance_id: UUID | None = None,
session_id: str | None = None,
) -> ScopeContext:
"""
Create a scope context from individual IDs.
Automatically determines the most specific scope level
based on provided IDs.
Args:
project_id: Project UUID
agent_type_id: Agent type UUID
agent_instance_id: Agent instance UUID
session_id: Session identifier
Returns:
Scope context for the most specific level
"""
# Build scope chain from most general to most specific
current: ScopeContext = ScopeContext(
scope_type=ScopeLevel.GLOBAL,
scope_id="global",
parent=None,
)
if project_id is not None:
current = ScopeContext(
scope_type=ScopeLevel.PROJECT,
scope_id=str(project_id),
parent=current,
)
if agent_type_id is not None:
current = ScopeContext(
scope_type=ScopeLevel.AGENT_TYPE,
scope_id=str(agent_type_id),
parent=current,
)
if agent_instance_id is not None:
current = ScopeContext(
scope_type=ScopeLevel.AGENT_INSTANCE,
scope_id=str(agent_instance_id),
parent=current,
)
if session_id is not None:
current = ScopeContext(
scope_type=ScopeLevel.SESSION,
scope_id=session_id,
parent=current,
)
return current
def get_policy(
self,
scope: ScopeContext,
) -> ScopePolicy:
"""
Get the access policy for a scope.
Args:
scope: Scope to get policy for
Returns:
Policy for the scope
"""
key = self._scope_key(scope)
if key in self._policies:
return self._policies[key]
# Return default policy for the scope level
return self._default_policies.get(
scope.scope_type,
ScopePolicy(
scope_type=scope.scope_type,
scope_id=scope.scope_id,
),
)
def set_policy(
self,
scope: ScopeContext,
policy: ScopePolicy,
) -> None:
"""
Set the access policy for a scope.
Args:
scope: Scope to set policy for
policy: Policy to apply
"""
key = self._scope_key(scope)
self._policies[key] = policy
logger.info(f"Set policy for scope {key}")
def _scope_key(self, scope: ScopeContext) -> str:
"""Generate a unique key for a scope."""
return f"{scope.scope_type.value}:{scope.scope_id}"
def get_scope_depth(self, scope_type: ScopeLevel) -> int:
"""Get the depth of a scope level in the hierarchy."""
return self.SCOPE_ORDER.index(scope_type)
def get_parent_level(self, scope_type: ScopeLevel) -> ScopeLevel | None:
"""Get the parent scope level for a given level."""
idx = self.SCOPE_ORDER.index(scope_type)
if idx == 0:
return None
return self.SCOPE_ORDER[idx - 1]
def get_child_level(self, scope_type: ScopeLevel) -> ScopeLevel | None:
"""Get the child scope level for a given level."""
idx = self.SCOPE_ORDER.index(scope_type)
if idx >= len(self.SCOPE_ORDER) - 1:
return None
return self.SCOPE_ORDER[idx + 1]
def is_ancestor(
self,
potential_ancestor: ScopeContext,
descendant: ScopeContext,
) -> bool:
"""
Check if one scope is an ancestor of another.
Args:
potential_ancestor: Scope to check as ancestor
descendant: Scope to check as descendant
Returns:
True if ancestor relationship exists
"""
current = descendant.parent
while current is not None:
if (
current.scope_type == potential_ancestor.scope_type
and current.scope_id == potential_ancestor.scope_id
):
return True
current = current.parent
return False
def get_common_ancestor(
self,
scope_a: ScopeContext,
scope_b: ScopeContext,
) -> ScopeContext | None:
"""
Find the nearest common ancestor of two scopes.
Args:
scope_a: First scope
scope_b: Second scope
Returns:
Common ancestor or None if none exists
"""
# Get ancestors of scope_a
ancestors_a: set[str] = set()
current: ScopeContext | None = scope_a
while current is not None:
ancestors_a.add(self._scope_key(current))
current = current.parent
# Find first ancestor of scope_b that's in ancestors_a
current = scope_b
while current is not None:
if self._scope_key(current) in ancestors_a:
return current
current = current.parent
return None
def can_access(
self,
accessor_scope: ScopeContext,
target_scope: ScopeContext,
operation: str = "read",
) -> bool:
"""
Check if accessor scope can access target scope.
Access rules:
- A scope can always access itself
- A scope can access ancestors (if inheritance allowed)
- A scope CANNOT access descendants (privacy)
- Sibling scopes cannot access each other
Args:
accessor_scope: Scope attempting access
target_scope: Scope being accessed
operation: Type of operation (read/write)
Returns:
True if access is allowed
"""
# Same scope - always allowed
if (
accessor_scope.scope_type == target_scope.scope_type
and accessor_scope.scope_id == target_scope.scope_id
):
policy = self.get_policy(target_scope)
if operation == "write":
return policy.allows_write()
return policy.allows_read()
# Check if target is ancestor (inheritance)
if self.is_ancestor(target_scope, accessor_scope):
policy = self.get_policy(target_scope)
if not policy.allows_inherit():
return False
if operation == "write":
return policy.allows_write()
return policy.allows_read()
# Check if accessor is ancestor of target (downward access)
# This is NOT allowed - parents cannot access children's memories
if self.is_ancestor(accessor_scope, target_scope):
return False
# Sibling scopes cannot access each other
return False
# Singleton manager instance
_manager: ScopeManager | None = None
def get_scope_manager() -> ScopeManager:
"""Get the singleton scope manager instance."""
global _manager
if _manager is None:
_manager = ScopeManager()
return _manager

View File

@@ -0,0 +1,2 @@
# tests/unit/services/memory/scoping/__init__.py
"""Unit tests for memory scoping."""

View File

@@ -0,0 +1,653 @@
# tests/unit/services/memory/scoping/test_resolver.py
"""Unit tests for scope resolution."""
from dataclasses import dataclass
from uuid import uuid4
import pytest
from app.services.memory.scoping.resolver import (
ResolutionOptions,
ResolutionResult,
ScopeFilter,
ScopeResolver,
get_scope_resolver,
)
from app.services.memory.scoping.scope import ScopeManager, ScopePolicy
from app.services.memory.types import ScopeContext, ScopeLevel
@dataclass
class MockItem:
"""Mock item for testing resolution."""
id: str
name: str
scope_id: str
class TestResolutionResult:
"""Tests for ResolutionResult dataclass."""
def test_total_count(self) -> None:
"""Test total_count property."""
result = ResolutionResult[MockItem](
items=[
MockItem(id="1", name="a", scope_id="s1"),
MockItem(id="2", name="b", scope_id="s2"),
],
sources=[],
)
assert result.total_count == 2
def test_empty_result(self) -> None:
"""Test empty result."""
result = ResolutionResult[MockItem](
items=[],
sources=[],
)
assert result.total_count == 0
assert result.inherited_count == 0
assert result.own_count == 0
class TestResolutionOptions:
"""Tests for ResolutionOptions dataclass."""
def test_default_values(self) -> None:
"""Test default option values."""
options = ResolutionOptions()
assert options.include_inherited is True
assert options.max_inheritance_depth == 5
assert options.limit_per_scope == 100
assert options.total_limit == 500
assert options.deduplicate is True
assert options.deduplicate_key is None
def test_custom_values(self) -> None:
"""Test custom option values."""
options = ResolutionOptions(
include_inherited=False,
max_inheritance_depth=3,
limit_per_scope=50,
total_limit=200,
deduplicate=False,
deduplicate_key="id",
)
assert options.include_inherited is False
assert options.max_inheritance_depth == 3
assert options.limit_per_scope == 50
assert options.total_limit == 200
assert options.deduplicate is False
assert options.deduplicate_key == "id"
class TestScopeResolver:
"""Tests for ScopeResolver class."""
@pytest.fixture
def manager(self) -> ScopeManager:
"""Create a scope manager."""
return ScopeManager()
@pytest.fixture
def resolver(self, manager: ScopeManager) -> ScopeResolver:
"""Create a scope resolver."""
return ScopeResolver(manager=manager)
def test_resolve_single_scope(
self,
resolver: ScopeResolver,
manager: ScopeManager,
) -> None:
"""Test resolving from a single scope."""
scope = manager.create_scope(ScopeLevel.PROJECT, "project-1")
def fetcher(s: ScopeContext, limit: int) -> list[MockItem]:
if s.scope_id == "project-1":
return [MockItem(id="1", name="item1", scope_id="project-1")]
return []
result = resolver.resolve(
scope=scope,
fetcher=fetcher,
options=ResolutionOptions(include_inherited=False),
)
assert result.total_count == 1
assert result.own_count == 1
assert result.inherited_count == 0
def test_resolve_with_inheritance(
self,
resolver: ScopeResolver,
manager: ScopeManager,
) -> None:
"""Test resolving with scope inheritance."""
global_scope = manager.create_scope(ScopeLevel.GLOBAL, "global")
project_scope = manager.create_scope(
ScopeLevel.PROJECT, "project-1", parent=global_scope
)
def fetcher(s: ScopeContext, limit: int) -> list[MockItem]:
if s.scope_id == "project-1":
return [MockItem(id="1", name="project-item", scope_id="project-1")]
elif s.scope_id == "global":
return [MockItem(id="2", name="global-item", scope_id="global")]
return []
result = resolver.resolve(
scope=project_scope,
fetcher=fetcher,
options=ResolutionOptions(include_inherited=True),
)
assert result.total_count == 2
assert result.own_count == 1
assert result.inherited_count == 1
def test_resolve_respects_depth_limit(
self,
resolver: ScopeResolver,
manager: ScopeManager,
) -> None:
"""Test that resolution respects max inheritance depth."""
global_scope = manager.create_scope(ScopeLevel.GLOBAL, "global")
project_scope = manager.create_scope(
ScopeLevel.PROJECT, "project", parent=global_scope
)
agent_scope = manager.create_scope(
ScopeLevel.AGENT_TYPE, "agent", parent=project_scope
)
instance_scope = manager.create_scope(
ScopeLevel.AGENT_INSTANCE, "instance", parent=agent_scope
)
session_scope = manager.create_scope(
ScopeLevel.SESSION, "session", parent=instance_scope
)
items_per_scope = {
"session": [MockItem(id="1", name="s", scope_id="session")],
"instance": [MockItem(id="2", name="i", scope_id="instance")],
"agent": [MockItem(id="3", name="a", scope_id="agent")],
"project": [MockItem(id="4", name="p", scope_id="project")],
"global": [MockItem(id="5", name="g", scope_id="global")],
}
def fetcher(s: ScopeContext, limit: int) -> list[MockItem]:
return items_per_scope.get(s.scope_id, [])
# Depth 1 should get session + instance
result = resolver.resolve(
scope=session_scope,
fetcher=fetcher,
options=ResolutionOptions(max_inheritance_depth=1),
)
assert result.total_count == 2
def test_resolve_respects_total_limit(
self,
resolver: ScopeResolver,
manager: ScopeManager,
) -> None:
"""Test that resolution respects total limit."""
scope = manager.create_scope(ScopeLevel.PROJECT, "project")
def fetcher(s: ScopeContext, limit: int) -> list[MockItem]:
return [
MockItem(id=str(i), name=f"item-{i}", scope_id="project")
for i in range(10)
]
result = resolver.resolve(
scope=scope,
fetcher=fetcher,
options=ResolutionOptions(total_limit=5),
)
assert result.total_count == 5
def test_resolve_deduplicates_by_key(
self,
resolver: ScopeResolver,
manager: ScopeManager,
) -> None:
"""Test deduplication by key field."""
global_scope = manager.create_scope(ScopeLevel.GLOBAL, "global")
project_scope = manager.create_scope(
ScopeLevel.PROJECT, "project", parent=global_scope
)
def fetcher(s: ScopeContext, limit: int) -> list[MockItem]:
if s.scope_id == "project":
return [MockItem(id="1", name="project-ver", scope_id="project")]
elif s.scope_id == "global":
# Same ID, should be deduplicated
return [MockItem(id="1", name="global-ver", scope_id="global")]
return []
result = resolver.resolve(
scope=project_scope,
fetcher=fetcher,
options=ResolutionOptions(deduplicate=True, deduplicate_key="id"),
)
# Should only have the project version (encountered first)
assert result.total_count == 1
assert result.items[0].name == "project-ver"
def test_resolve_skips_non_readable_scopes(
self,
resolver: ScopeResolver,
manager: ScopeManager,
) -> None:
"""Test that non-readable scopes are skipped."""
global_scope = manager.create_scope(ScopeLevel.GLOBAL, "global")
project_scope = manager.create_scope(
ScopeLevel.PROJECT, "project", parent=global_scope
)
# Set global as non-readable
manager.set_policy(
global_scope,
ScopePolicy(
scope_type=ScopeLevel.GLOBAL,
scope_id="global",
can_read=False,
),
)
def fetcher(s: ScopeContext, limit: int) -> list[MockItem]:
if s.scope_id == "project":
return [MockItem(id="1", name="project-item", scope_id="project")]
elif s.scope_id == "global":
return [MockItem(id="2", name="global-item", scope_id="global")]
return []
result = resolver.resolve(
scope=project_scope,
fetcher=fetcher,
)
# Should only have project item
assert result.total_count == 1
assert result.items[0].scope_id == "project"
def test_resolve_skips_non_inheritable_scopes(
self,
resolver: ScopeResolver,
manager: ScopeManager,
) -> None:
"""Test that non-inheritable parent scopes stop inheritance."""
global_scope = manager.create_scope(ScopeLevel.GLOBAL, "global")
project_scope = manager.create_scope(
ScopeLevel.PROJECT, "project", parent=global_scope
)
# Set global as non-inheritable
manager.set_policy(
global_scope,
ScopePolicy(
scope_type=ScopeLevel.GLOBAL,
scope_id="global",
can_inherit=False,
),
)
def fetcher(s: ScopeContext, limit: int) -> list[MockItem]:
if s.scope_id == "project":
return [MockItem(id="1", name="project-item", scope_id="project")]
elif s.scope_id == "global":
return [MockItem(id="2", name="global-item", scope_id="global")]
return []
result = resolver.resolve(
scope=project_scope,
fetcher=fetcher,
)
# Should only have project item
assert result.total_count == 1
def test_get_visible_scopes(
self,
resolver: ScopeResolver,
manager: ScopeManager,
) -> None:
"""Test getting visible scopes."""
global_scope = manager.create_scope(ScopeLevel.GLOBAL, "global")
project_scope = manager.create_scope(
ScopeLevel.PROJECT, "project", parent=global_scope
)
agent_scope = manager.create_scope(
ScopeLevel.AGENT_TYPE, "agent", parent=project_scope
)
visible = resolver.get_visible_scopes(agent_scope)
assert len(visible) == 3
assert visible[0].scope_id == "agent"
assert visible[1].scope_id == "project"
assert visible[2].scope_id == "global"
def test_get_visible_scopes_stops_at_non_inheritable(
self,
resolver: ScopeResolver,
manager: ScopeManager,
) -> None:
"""Test that visible scopes stop at non-inheritable parent."""
global_scope = manager.create_scope(ScopeLevel.GLOBAL, "global")
project_scope = manager.create_scope(
ScopeLevel.PROJECT, "project", parent=global_scope
)
agent_scope = manager.create_scope(
ScopeLevel.AGENT_TYPE, "agent", parent=project_scope
)
# Make project non-inheritable
manager.set_policy(
project_scope,
ScopePolicy(
scope_type=ScopeLevel.PROJECT,
scope_id="project",
can_inherit=False,
),
)
visible = resolver.get_visible_scopes(agent_scope)
# Should stop at project (exclusive)
assert len(visible) == 1
assert visible[0].scope_id == "agent"
def test_find_write_scope_same_level(
self,
resolver: ScopeResolver,
manager: ScopeManager,
) -> None:
"""Test finding write scope at same level."""
scope = manager.create_scope(ScopeLevel.PROJECT, "project")
result = resolver.find_write_scope(ScopeLevel.PROJECT, scope)
assert result is not None
assert result.scope_id == "project"
def test_find_write_scope_ancestor(
self,
resolver: ScopeResolver,
manager: ScopeManager,
) -> None:
"""Test finding write scope in ancestors."""
global_scope = manager.create_scope(ScopeLevel.GLOBAL, "global")
project_scope = manager.create_scope(
ScopeLevel.PROJECT, "project", parent=global_scope
)
agent_scope = manager.create_scope(
ScopeLevel.AGENT_TYPE, "agent", parent=project_scope
)
result = resolver.find_write_scope(ScopeLevel.PROJECT, agent_scope)
assert result is not None
assert result.scope_id == "project"
def test_find_write_scope_not_found(
self,
resolver: ScopeResolver,
manager: ScopeManager,
) -> None:
"""Test finding write scope when not in hierarchy."""
scope = manager.create_scope(ScopeLevel.PROJECT, "project")
# Looking for session level, but we're at project
result = resolver.find_write_scope(ScopeLevel.SESSION, scope)
assert result is None
def test_find_write_scope_respects_write_policy(
self,
resolver: ScopeResolver,
manager: ScopeManager,
) -> None:
"""Test that find_write_scope respects write policy."""
global_scope = manager.create_scope(ScopeLevel.GLOBAL, "global")
project_scope = manager.create_scope(
ScopeLevel.PROJECT, "project", parent=global_scope
)
# Make project read-only
manager.set_policy(
project_scope,
ScopePolicy(
scope_type=ScopeLevel.PROJECT,
scope_id="project",
can_write=False,
),
)
result = resolver.find_write_scope(ScopeLevel.PROJECT, project_scope)
assert result is None
def test_resolve_scope_from_memory_working(
self,
resolver: ScopeResolver,
) -> None:
"""Test resolving scope for working memory."""
project_id = str(uuid4())
session_id = "session-123"
scope, level = resolver.resolve_scope_from_memory(
memory_type="working",
project_id=project_id,
session_id=session_id,
)
assert scope.scope_type == ScopeLevel.SESSION
assert level == ScopeLevel.SESSION
def test_resolve_scope_from_memory_episodic(
self,
resolver: ScopeResolver,
) -> None:
"""Test resolving scope for episodic memory."""
project_id = str(uuid4())
agent_instance_id = str(uuid4())
scope, level = resolver.resolve_scope_from_memory(
memory_type="episodic",
project_id=project_id,
agent_instance_id=agent_instance_id,
)
assert scope.scope_type == ScopeLevel.AGENT_INSTANCE
assert level == ScopeLevel.AGENT_INSTANCE
def test_resolve_scope_from_memory_semantic(
self,
resolver: ScopeResolver,
) -> None:
"""Test resolving scope for semantic memory."""
project_id = str(uuid4())
scope, level = resolver.resolve_scope_from_memory(
memory_type="semantic",
project_id=project_id,
)
assert scope.scope_type == ScopeLevel.PROJECT
assert level == ScopeLevel.PROJECT
def test_resolve_scope_from_memory_procedural(
self,
resolver: ScopeResolver,
) -> None:
"""Test resolving scope for procedural memory."""
project_id = str(uuid4())
agent_type_id = str(uuid4())
scope, level = resolver.resolve_scope_from_memory(
memory_type="procedural",
project_id=project_id,
agent_type_id=agent_type_id,
)
assert scope.scope_type == ScopeLevel.AGENT_TYPE
assert level == ScopeLevel.AGENT_TYPE
def test_validate_write_access_allowed(
self,
resolver: ScopeResolver,
manager: ScopeManager,
) -> None:
"""Test write access validation when allowed."""
scope = manager.create_scope(ScopeLevel.PROJECT, "project")
assert resolver.validate_write_access(scope, "semantic") is True
def test_validate_write_access_denied_by_policy(
self,
resolver: ScopeResolver,
manager: ScopeManager,
) -> None:
"""Test write access denied by policy."""
scope = manager.create_scope(ScopeLevel.PROJECT, "project")
manager.set_policy(
scope,
ScopePolicy(
scope_type=ScopeLevel.PROJECT,
scope_id="project",
can_write=False,
),
)
assert resolver.validate_write_access(scope, "semantic") is False
def test_validate_write_access_denied_by_memory_type(
self,
resolver: ScopeResolver,
manager: ScopeManager,
) -> None:
"""Test write access denied by memory type restriction."""
scope = manager.create_scope(ScopeLevel.PROJECT, "project")
manager.set_policy(
scope,
ScopePolicy(
scope_type=ScopeLevel.PROJECT,
scope_id="project",
allowed_memory_types=["episodic"], # Only episodic allowed
),
)
assert resolver.validate_write_access(scope, "semantic") is False
assert resolver.validate_write_access(scope, "episodic") is True
def test_get_scope_chain(
self,
resolver: ScopeResolver,
manager: ScopeManager,
) -> None:
"""Test getting scope chain."""
global_scope = manager.create_scope(ScopeLevel.GLOBAL, "global")
project_scope = manager.create_scope(
ScopeLevel.PROJECT, "project", parent=global_scope
)
agent_scope = manager.create_scope(
ScopeLevel.AGENT_TYPE, "agent", parent=project_scope
)
chain = resolver.get_scope_chain(agent_scope)
assert len(chain) == 3
assert chain[0] == (ScopeLevel.GLOBAL, "global")
assert chain[1] == (ScopeLevel.PROJECT, "project")
assert chain[2] == (ScopeLevel.AGENT_TYPE, "agent")
class TestScopeFilter:
"""Tests for ScopeFilter dataclass."""
def test_default_values(self) -> None:
"""Test default filter values."""
filter_ = ScopeFilter()
assert filter_.scope_types is None
assert filter_.project_ids is None
assert filter_.agent_type_ids is None
assert filter_.include_global is True
def test_matches_global_scope(self) -> None:
"""Test matching global scope."""
scope = ScopeContext(
scope_type=ScopeLevel.GLOBAL,
scope_id="global",
)
filter_ = ScopeFilter(include_global=True)
assert filter_.matches(scope) is True
filter_ = ScopeFilter(include_global=False)
assert filter_.matches(scope) is False
def test_matches_scope_type(self) -> None:
"""Test matching by scope type."""
scope = ScopeContext(
scope_type=ScopeLevel.PROJECT,
scope_id="project-1",
)
filter_ = ScopeFilter(scope_types=[ScopeLevel.PROJECT])
assert filter_.matches(scope) is True
filter_ = ScopeFilter(scope_types=[ScopeLevel.AGENT_TYPE])
assert filter_.matches(scope) is False
def test_matches_project_id(self) -> None:
"""Test matching by project ID."""
scope = ScopeContext(
scope_type=ScopeLevel.PROJECT,
scope_id="project-1",
)
filter_ = ScopeFilter(project_ids=["project-1", "project-2"])
assert filter_.matches(scope) is True
filter_ = ScopeFilter(project_ids=["project-3"])
assert filter_.matches(scope) is False
def test_matches_agent_type_id(self) -> None:
"""Test matching by agent type ID."""
scope = ScopeContext(
scope_type=ScopeLevel.AGENT_TYPE,
scope_id="agent-1",
)
filter_ = ScopeFilter(agent_type_ids=["agent-1"])
assert filter_.matches(scope) is True
filter_ = ScopeFilter(agent_type_ids=["agent-2"])
assert filter_.matches(scope) is False
class TestGetScopeResolver:
"""Tests for singleton getter."""
def test_returns_instance(self) -> None:
"""Test that getter returns instance."""
resolver = get_scope_resolver()
assert resolver is not None
assert isinstance(resolver, ScopeResolver)
def test_returns_same_instance(self) -> None:
"""Test that getter returns same instance."""
resolver1 = get_scope_resolver()
resolver2 = get_scope_resolver()
assert resolver1 is resolver2

View File

@@ -0,0 +1,361 @@
# tests/unit/services/memory/scoping/test_scope.py
"""Unit tests for scope management."""
from uuid import uuid4
import pytest
from app.services.memory.scoping.scope import (
ScopeManager,
ScopePolicy,
get_scope_manager,
)
from app.services.memory.types import ScopeLevel
class TestScopePolicy:
"""Tests for ScopePolicy dataclass."""
def test_default_values(self) -> None:
"""Test default policy values."""
policy = ScopePolicy(
scope_type=ScopeLevel.PROJECT,
scope_id="test-project",
)
assert policy.can_read is True
assert policy.can_write is True
assert policy.can_inherit is True
assert policy.allowed_memory_types == ["all"]
def test_allows_read(self) -> None:
"""Test allows_read method."""
policy = ScopePolicy(
scope_type=ScopeLevel.PROJECT,
scope_id="test",
can_read=True,
)
assert policy.allows_read() is True
policy.can_read = False
assert policy.allows_read() is False
def test_allows_write(self) -> None:
"""Test allows_write method."""
policy = ScopePolicy(
scope_type=ScopeLevel.PROJECT,
scope_id="test",
can_write=True,
)
assert policy.allows_write() is True
policy.can_write = False
assert policy.allows_write() is False
def test_allows_inherit(self) -> None:
"""Test allows_inherit method."""
policy = ScopePolicy(
scope_type=ScopeLevel.PROJECT,
scope_id="test",
can_inherit=True,
)
assert policy.allows_inherit() is True
policy.can_inherit = False
assert policy.allows_inherit() is False
def test_allows_memory_type(self) -> None:
"""Test allows_memory_type method."""
policy = ScopePolicy(
scope_type=ScopeLevel.PROJECT,
scope_id="test",
allowed_memory_types=["all"],
)
assert policy.allows_memory_type("working") is True
assert policy.allows_memory_type("episodic") is True
policy.allowed_memory_types = ["working", "episodic"]
assert policy.allows_memory_type("working") is True
assert policy.allows_memory_type("episodic") is True
assert policy.allows_memory_type("semantic") is False
class TestScopeManager:
"""Tests for ScopeManager class."""
@pytest.fixture
def manager(self) -> ScopeManager:
"""Create a scope manager."""
return ScopeManager()
def test_create_global_scope(
self,
manager: ScopeManager,
) -> None:
"""Test creating a global scope."""
scope = manager.create_scope(
scope_type=ScopeLevel.GLOBAL,
scope_id="global",
)
assert scope.scope_type == ScopeLevel.GLOBAL
assert scope.scope_id == "global"
assert scope.parent is None
def test_create_project_scope(
self,
manager: ScopeManager,
) -> None:
"""Test creating a project scope."""
global_scope = manager.create_scope(
scope_type=ScopeLevel.GLOBAL,
scope_id="global",
)
project_scope = manager.create_scope(
scope_type=ScopeLevel.PROJECT,
scope_id="project-1",
parent=global_scope,
)
assert project_scope.scope_type == ScopeLevel.PROJECT
assert project_scope.scope_id == "project-1"
assert project_scope.parent is global_scope
def test_create_scope_auto_parent(
self,
manager: ScopeManager,
) -> None:
"""Test that non-global scopes auto-create parent chain."""
scope = manager.create_scope(
scope_type=ScopeLevel.PROJECT,
scope_id="test-project",
)
assert scope.scope_type == ScopeLevel.PROJECT
assert scope.parent is not None
assert scope.parent.scope_type == ScopeLevel.GLOBAL
def test_create_scope_invalid_hierarchy(
self,
manager: ScopeManager,
) -> None:
"""Test that invalid hierarchy raises error."""
project_scope = manager.create_scope(
scope_type=ScopeLevel.PROJECT,
scope_id="project-1",
)
with pytest.raises(ValueError, match="Invalid scope hierarchy"):
manager.create_scope(
scope_type=ScopeLevel.GLOBAL,
scope_id="global",
parent=project_scope,
)
def test_create_scope_from_ids(
self,
manager: ScopeManager,
) -> None:
"""Test creating scope from individual IDs."""
project_id = uuid4()
agent_type_id = uuid4()
scope = manager.create_scope_from_ids(
project_id=project_id,
agent_type_id=agent_type_id,
)
assert scope.scope_type == ScopeLevel.AGENT_TYPE
assert scope.scope_id == str(agent_type_id)
assert scope.parent is not None
assert scope.parent.scope_type == ScopeLevel.PROJECT
def test_create_scope_from_ids_with_session(
self,
manager: ScopeManager,
) -> None:
"""Test creating scope with session ID."""
project_id = uuid4()
session_id = "session-123"
scope = manager.create_scope_from_ids(
project_id=project_id,
session_id=session_id,
)
assert scope.scope_type == ScopeLevel.SESSION
assert scope.scope_id == session_id
def test_get_default_policy(
self,
manager: ScopeManager,
) -> None:
"""Test getting default policy."""
scope = manager.create_scope(
scope_type=ScopeLevel.PROJECT,
scope_id="test-project",
)
policy = manager.get_policy(scope)
assert policy.can_read is True
assert policy.can_write is True
def test_set_and_get_policy(
self,
manager: ScopeManager,
) -> None:
"""Test setting and retrieving a policy."""
scope = manager.create_scope(
scope_type=ScopeLevel.PROJECT,
scope_id="test-project",
)
custom_policy = ScopePolicy(
scope_type=ScopeLevel.PROJECT,
scope_id="test-project",
can_write=False,
)
manager.set_policy(scope, custom_policy)
retrieved = manager.get_policy(scope)
assert retrieved.can_write is False
def test_get_scope_depth(
self,
manager: ScopeManager,
) -> None:
"""Test getting scope depth."""
assert manager.get_scope_depth(ScopeLevel.GLOBAL) == 0
assert manager.get_scope_depth(ScopeLevel.PROJECT) == 1
assert manager.get_scope_depth(ScopeLevel.AGENT_TYPE) == 2
assert manager.get_scope_depth(ScopeLevel.AGENT_INSTANCE) == 3
assert manager.get_scope_depth(ScopeLevel.SESSION) == 4
def test_get_parent_level(
self,
manager: ScopeManager,
) -> None:
"""Test getting parent level."""
assert manager.get_parent_level(ScopeLevel.GLOBAL) is None
assert manager.get_parent_level(ScopeLevel.PROJECT) == ScopeLevel.GLOBAL
assert manager.get_parent_level(ScopeLevel.AGENT_TYPE) == ScopeLevel.PROJECT
assert manager.get_parent_level(ScopeLevel.SESSION) == ScopeLevel.AGENT_INSTANCE
def test_get_child_level(
self,
manager: ScopeManager,
) -> None:
"""Test getting child level."""
assert manager.get_child_level(ScopeLevel.GLOBAL) == ScopeLevel.PROJECT
assert manager.get_child_level(ScopeLevel.PROJECT) == ScopeLevel.AGENT_TYPE
assert manager.get_child_level(ScopeLevel.SESSION) is None
def test_is_ancestor(
self,
manager: ScopeManager,
) -> None:
"""Test ancestor checking."""
global_scope = manager.create_scope(ScopeLevel.GLOBAL, "global")
project_scope = manager.create_scope(
ScopeLevel.PROJECT, "project", parent=global_scope
)
agent_scope = manager.create_scope(
ScopeLevel.AGENT_TYPE, "agent", parent=project_scope
)
assert manager.is_ancestor(global_scope, agent_scope) is True
assert manager.is_ancestor(project_scope, agent_scope) is True
assert manager.is_ancestor(agent_scope, global_scope) is False
assert manager.is_ancestor(agent_scope, project_scope) is False
def test_get_common_ancestor(
self,
manager: ScopeManager,
) -> None:
"""Test finding common ancestor."""
global_scope = manager.create_scope(ScopeLevel.GLOBAL, "global")
project_scope = manager.create_scope(
ScopeLevel.PROJECT, "project", parent=global_scope
)
agent1 = manager.create_scope(
ScopeLevel.AGENT_TYPE, "agent1", parent=project_scope
)
agent2 = manager.create_scope(
ScopeLevel.AGENT_TYPE, "agent2", parent=project_scope
)
common = manager.get_common_ancestor(agent1, agent2)
assert common is not None
assert common.scope_type == ScopeLevel.PROJECT
def test_can_access_same_scope(
self,
manager: ScopeManager,
) -> None:
"""Test access to same scope."""
scope = manager.create_scope(ScopeLevel.PROJECT, "project")
assert manager.can_access(scope, scope) is True
assert manager.can_access(scope, scope, "write") is True
def test_can_access_ancestor(
self,
manager: ScopeManager,
) -> None:
"""Test access to ancestor scope."""
global_scope = manager.create_scope(ScopeLevel.GLOBAL, "global")
project_scope = manager.create_scope(
ScopeLevel.PROJECT, "project", parent=global_scope
)
# Child can read from parent
assert manager.can_access(project_scope, global_scope, "read") is True
def test_cannot_access_descendant(
self,
manager: ScopeManager,
) -> None:
"""Test that parent cannot access child scope."""
global_scope = manager.create_scope(ScopeLevel.GLOBAL, "global")
project_scope = manager.create_scope(
ScopeLevel.PROJECT, "project", parent=global_scope
)
# Parent cannot access child
assert manager.can_access(global_scope, project_scope) is False
def test_cannot_access_sibling(
self,
manager: ScopeManager,
) -> None:
"""Test that sibling scopes cannot access each other."""
global_scope = manager.create_scope(ScopeLevel.GLOBAL, "global")
project1 = manager.create_scope(
ScopeLevel.PROJECT, "project1", parent=global_scope
)
project2 = manager.create_scope(
ScopeLevel.PROJECT, "project2", parent=global_scope
)
assert manager.can_access(project1, project2) is False
assert manager.can_access(project2, project1) is False
class TestGetScopeManager:
"""Tests for singleton getter."""
def test_returns_instance(self) -> None:
"""Test that getter returns instance."""
manager = get_scope_manager()
assert manager is not None
assert isinstance(manager, ScopeManager)
def test_returns_same_instance(self) -> None:
"""Test that getter returns same instance."""
manager1 = get_scope_manager()
manager2 = get_scope_manager()
assert manager1 is manager2