# tests/unit/services/memory/scoping/test_resolver.py
"""Unit tests for scope resolution.

Covers the ``ResolutionResult`` / ``ResolutionOptions`` data holders, the
``ScopeResolver`` inheritance and permission logic, ``ScopeFilter`` matching,
and the ``get_scope_resolver`` singleton accessor.
"""

from dataclasses import dataclass
from uuid import uuid4

import pytest

from app.services.memory.scoping.resolver import (
    ResolutionOptions,
    ResolutionResult,
    ScopeFilter,
    ScopeResolver,
    get_scope_resolver,
)
from app.services.memory.scoping.scope import ScopeManager, ScopePolicy
from app.services.memory.types import ScopeContext, ScopeLevel


@dataclass
class MockItem:
    """Mock item for testing resolution."""

    # Item identifier; the dedup test uses it as ``deduplicate_key``.
    id: str
    # Label used to tell apart items that share an ``id``.
    name: str
    # ID of the scope the test fetcher returned this item from.
    scope_id: str


class TestResolutionResult:
    """Tests for ResolutionResult dataclass."""

    def test_total_count(self) -> None:
        """Test total_count property."""
        result = ResolutionResult[MockItem](
            items=[
                MockItem(id="1", name="a", scope_id="s1"),
                MockItem(id="2", name="b", scope_id="s2"),
            ],
            sources=[],
        )
        assert result.total_count == 2

    def test_empty_result(self) -> None:
        """Test empty result."""
        result = ResolutionResult[MockItem](
            items=[],
            sources=[],
        )
        assert result.total_count == 0
        assert result.inherited_count == 0
        assert result.own_count == 0


class TestResolutionOptions:
    """Tests for ResolutionOptions dataclass."""

    def test_default_values(self) -> None:
        """Test default option values."""
        options = ResolutionOptions()
        assert options.include_inherited is True
        assert options.max_inheritance_depth == 5
        assert options.limit_per_scope == 100
        assert options.total_limit == 500
        assert options.deduplicate is True
        assert options.deduplicate_key is None

    def test_custom_values(self) -> None:
        """Test custom option values."""
        options = ResolutionOptions(
            include_inherited=False,
            max_inheritance_depth=3,
            limit_per_scope=50,
            total_limit=200,
            deduplicate=False,
            deduplicate_key="id",
        )
        assert options.include_inherited is False
        assert options.max_inheritance_depth == 3
        assert options.limit_per_scope == 50
        assert options.total_limit == 200
        assert options.deduplicate is False
        assert options.deduplicate_key == "id"


class TestScopeResolver:
    """Tests for ScopeResolver class."""

    @pytest.fixture
    def manager(self) -> ScopeManager:
        """Create a scope manager."""
        return ScopeManager()

    @pytest.fixture
    def resolver(self, manager: ScopeManager) -> ScopeResolver:
        """Create a scope resolver."""
        return ScopeResolver(manager=manager)

    def test_resolve_single_scope(
        self,
        resolver: ScopeResolver,
        manager: ScopeManager,
    ) -> None:
        """Test resolving from a single scope."""
        scope = manager.create_scope(ScopeLevel.PROJECT, "project-1")

        def fetcher(s: ScopeContext, limit: int) -> list[MockItem]:
            if s.scope_id == "project-1":
                return [MockItem(id="1", name="item1", scope_id="project-1")]
            return []

        result = resolver.resolve(
            scope=scope,
            fetcher=fetcher,
            options=ResolutionOptions(include_inherited=False),
        )

        assert result.total_count == 1
        assert result.own_count == 1
        assert result.inherited_count == 0

    def test_resolve_with_inheritance(
        self,
        resolver: ScopeResolver,
        manager: ScopeManager,
    ) -> None:
        """Test resolving with scope inheritance."""
        global_scope = manager.create_scope(ScopeLevel.GLOBAL, "global")
        project_scope = manager.create_scope(
            ScopeLevel.PROJECT, "project-1", parent=global_scope
        )

        def fetcher(s: ScopeContext, limit: int) -> list[MockItem]:
            if s.scope_id == "project-1":
                return [MockItem(id="1", name="project-item", scope_id="project-1")]
            elif s.scope_id == "global":
                return [MockItem(id="2", name="global-item", scope_id="global")]
            return []

        result = resolver.resolve(
            scope=project_scope,
            fetcher=fetcher,
            options=ResolutionOptions(include_inherited=True),
        )

        assert result.total_count == 2
        assert result.own_count == 1
        assert result.inherited_count == 1

    def test_resolve_respects_depth_limit(
        self,
        resolver: ScopeResolver,
        manager: ScopeManager,
    ) -> None:
        """Test that resolution respects max inheritance depth."""
        global_scope = manager.create_scope(ScopeLevel.GLOBAL, "global")
        project_scope = manager.create_scope(
            ScopeLevel.PROJECT, "project", parent=global_scope
        )
        agent_scope = manager.create_scope(
            ScopeLevel.AGENT_TYPE, "agent", parent=project_scope
        )
        instance_scope = manager.create_scope(
            ScopeLevel.AGENT_INSTANCE, "instance", parent=agent_scope
        )
        session_scope = manager.create_scope(
            ScopeLevel.SESSION, "session", parent=instance_scope
        )

        # One distinct item per level, so every returned item identifies
        # exactly which scope it came from.
        items_per_scope = {
            "session": [MockItem(id="1", name="s", scope_id="session")],
            "instance": [MockItem(id="2", name="i", scope_id="instance")],
            "agent": [MockItem(id="3", name="a", scope_id="agent")],
            "project": [MockItem(id="4", name="p", scope_id="project")],
            "global": [MockItem(id="5", name="g", scope_id="global")],
        }

        def fetcher(s: ScopeContext, limit: int) -> list[MockItem]:
            return items_per_scope.get(s.scope_id, [])

        # Depth 1 should get session + instance
        result = resolver.resolve(
            scope=session_scope,
            fetcher=fetcher,
            options=ResolutionOptions(max_inheritance_depth=1),
        )

        assert result.total_count == 2
        # Pin down *which* two scopes were visited, not just how many.
        assert {item.scope_id for item in result.items} == {"session", "instance"}

    def test_resolve_respects_total_limit(
        self,
        resolver: ScopeResolver,
        manager: ScopeManager,
    ) -> None:
        """Test that resolution respects total limit."""
        scope = manager.create_scope(ScopeLevel.PROJECT, "project")

        # The fetcher ignores ``limit`` and always returns ten items, so the
        # assertion below checks that total_limit alone caps the result.
        def fetcher(s: ScopeContext, limit: int) -> list[MockItem]:
            return [
                MockItem(id=str(i), name=f"item-{i}", scope_id="project")
                for i in range(10)
            ]

        result = resolver.resolve(
            scope=scope,
            fetcher=fetcher,
            options=ResolutionOptions(total_limit=5),
        )

        assert result.total_count == 5

    def test_resolve_deduplicates_by_key(
        self,
        resolver: ScopeResolver,
        manager: ScopeManager,
    ) -> None:
        """Test deduplication by key field."""
        global_scope = manager.create_scope(ScopeLevel.GLOBAL, "global")
        project_scope = manager.create_scope(
            ScopeLevel.PROJECT, "project", parent=global_scope
        )

        def fetcher(s: ScopeContext, limit: int) -> list[MockItem]:
            if s.scope_id == "project":
                return [MockItem(id="1", name="project-ver", scope_id="project")]
            elif s.scope_id == "global":
                # Same ID, should be deduplicated
                return [MockItem(id="1", name="global-ver", scope_id="global")]
            return []

        result = resolver.resolve(
            scope=project_scope,
            fetcher=fetcher,
            options=ResolutionOptions(deduplicate=True, deduplicate_key="id"),
        )

        # Should only have the project version (encountered first)
        assert result.total_count == 1
        assert result.items[0].name == "project-ver"

    def test_resolve_skips_non_readable_scopes(
        self,
        resolver: ScopeResolver,
        manager: ScopeManager,
    ) -> None:
        """Test that non-readable scopes are skipped."""
        global_scope = manager.create_scope(ScopeLevel.GLOBAL, "global")
        project_scope = manager.create_scope(
            ScopeLevel.PROJECT, "project", parent=global_scope
        )

        # Set global as non-readable
        manager.set_policy(
            global_scope,
            ScopePolicy(
                scope_type=ScopeLevel.GLOBAL,
                scope_id="global",
                can_read=False,
            ),
        )

        def fetcher(s: ScopeContext, limit: int) -> list[MockItem]:
            if s.scope_id == "project":
                return [MockItem(id="1", name="project-item", scope_id="project")]
            elif s.scope_id == "global":
                return [MockItem(id="2", name="global-item", scope_id="global")]
            return []

        result = resolver.resolve(
            scope=project_scope,
            fetcher=fetcher,
        )

        # Should only have project item
        assert result.total_count == 1
        assert result.items[0].scope_id == "project"

    def test_resolve_skips_non_inheritable_scopes(
        self,
        resolver: ScopeResolver,
        manager: ScopeManager,
    ) -> None:
        """Test that non-inheritable parent scopes stop inheritance."""
        global_scope = manager.create_scope(ScopeLevel.GLOBAL, "global")
        project_scope = manager.create_scope(
            ScopeLevel.PROJECT, "project", parent=global_scope
        )

        # Set global as non-inheritable
        manager.set_policy(
            global_scope,
            ScopePolicy(
                scope_type=ScopeLevel.GLOBAL,
                scope_id="global",
                can_inherit=False,
            ),
        )

        def fetcher(s: ScopeContext, limit: int) -> list[MockItem]:
            if s.scope_id == "project":
                return [MockItem(id="1", name="project-item", scope_id="project")]
            elif s.scope_id == "global":
                return [MockItem(id="2", name="global-item", scope_id="global")]
            return []

        result = resolver.resolve(
            scope=project_scope,
            fetcher=fetcher,
        )

        # Should only have project item
        assert result.total_count == 1
        assert result.items[0].scope_id == "project"

    def test_get_visible_scopes(
        self,
        resolver: ScopeResolver,
        manager: ScopeManager,
    ) -> None:
        """Test getting visible scopes."""
        global_scope = manager.create_scope(ScopeLevel.GLOBAL, "global")
        project_scope = manager.create_scope(
            ScopeLevel.PROJECT, "project", parent=global_scope
        )
        agent_scope = manager.create_scope(
            ScopeLevel.AGENT_TYPE, "agent", parent=project_scope
        )

        visible = resolver.get_visible_scopes(agent_scope)

        # Ordered from the starting scope up to the root.
        assert len(visible) == 3
        assert visible[0].scope_id == "agent"
        assert visible[1].scope_id == "project"
        assert visible[2].scope_id == "global"

    def test_get_visible_scopes_stops_at_non_inheritable(
        self,
        resolver: ScopeResolver,
        manager: ScopeManager,
    ) -> None:
        """Test that visible scopes stop at non-inheritable parent."""
        global_scope = manager.create_scope(ScopeLevel.GLOBAL, "global")
        project_scope = manager.create_scope(
            ScopeLevel.PROJECT, "project", parent=global_scope
        )
        agent_scope = manager.create_scope(
            ScopeLevel.AGENT_TYPE, "agent", parent=project_scope
        )

        # Make project non-inheritable
        manager.set_policy(
            project_scope,
            ScopePolicy(
                scope_type=ScopeLevel.PROJECT,
                scope_id="project",
                can_inherit=False,
            ),
        )

        visible = resolver.get_visible_scopes(agent_scope)

        # Should stop at project (exclusive)
        assert len(visible) == 1
        assert visible[0].scope_id == "agent"

    def test_find_write_scope_same_level(
        self,
        resolver: ScopeResolver,
        manager: ScopeManager,
    ) -> None:
        """Test finding write scope at same level."""
        scope = manager.create_scope(ScopeLevel.PROJECT, "project")

        result = resolver.find_write_scope(ScopeLevel.PROJECT, scope)

        assert result is not None
        assert result.scope_id == "project"

    def test_find_write_scope_ancestor(
        self,
        resolver: ScopeResolver,
        manager: ScopeManager,
    ) -> None:
        """Test finding write scope in ancestors."""
        global_scope = manager.create_scope(ScopeLevel.GLOBAL, "global")
        project_scope = manager.create_scope(
            ScopeLevel.PROJECT, "project", parent=global_scope
        )
        agent_scope = manager.create_scope(
            ScopeLevel.AGENT_TYPE, "agent", parent=project_scope
        )

        result = resolver.find_write_scope(ScopeLevel.PROJECT, agent_scope)

        assert result is not None
        assert result.scope_id == "project"

    def test_find_write_scope_not_found(
        self,
        resolver: ScopeResolver,
        manager: ScopeManager,
    ) -> None:
        """Test finding write scope when not in hierarchy."""
        scope = manager.create_scope(ScopeLevel.PROJECT, "project")

        # Looking for session level, but we're at project
        result = resolver.find_write_scope(ScopeLevel.SESSION, scope)

        assert result is None

    def test_find_write_scope_respects_write_policy(
        self,
        resolver: ScopeResolver,
        manager: ScopeManager,
    ) -> None:
        """Test that find_write_scope respects write policy."""
        global_scope = manager.create_scope(ScopeLevel.GLOBAL, "global")
        project_scope = manager.create_scope(
            ScopeLevel.PROJECT, "project", parent=global_scope
        )

        # Make project read-only
        manager.set_policy(
            project_scope,
            ScopePolicy(
                scope_type=ScopeLevel.PROJECT,
                scope_id="project",
                can_write=False,
            ),
        )

        result = resolver.find_write_scope(ScopeLevel.PROJECT, project_scope)

        assert result is None

    def test_resolve_scope_from_memory_working(
        self,
        resolver: ScopeResolver,
    ) -> None:
        """Test resolving scope for working memory."""
        project_id = str(uuid4())
        session_id = "session-123"

        scope, level = resolver.resolve_scope_from_memory(
            memory_type="working",
            project_id=project_id,
            session_id=session_id,
        )

        assert scope.scope_type == ScopeLevel.SESSION
        assert level == ScopeLevel.SESSION

    def test_resolve_scope_from_memory_episodic(
        self,
        resolver: ScopeResolver,
    ) -> None:
        """Test resolving scope for episodic memory."""
        project_id = str(uuid4())
        agent_instance_id = str(uuid4())

        scope, level = resolver.resolve_scope_from_memory(
            memory_type="episodic",
            project_id=project_id,
            agent_instance_id=agent_instance_id,
        )

        assert scope.scope_type == ScopeLevel.AGENT_INSTANCE
        assert level == ScopeLevel.AGENT_INSTANCE

    def test_resolve_scope_from_memory_semantic(
        self,
        resolver: ScopeResolver,
    ) -> None:
        """Test resolving scope for semantic memory."""
        project_id = str(uuid4())

        scope, level = resolver.resolve_scope_from_memory(
            memory_type="semantic",
            project_id=project_id,
        )

        assert scope.scope_type == ScopeLevel.PROJECT
        assert level == ScopeLevel.PROJECT

    def test_resolve_scope_from_memory_procedural(
        self,
        resolver: ScopeResolver,
    ) -> None:
        """Test resolving scope for procedural memory."""
        project_id = str(uuid4())
        agent_type_id = str(uuid4())

        scope, level = resolver.resolve_scope_from_memory(
            memory_type="procedural",
            project_id=project_id,
            agent_type_id=agent_type_id,
        )

        assert scope.scope_type == ScopeLevel.AGENT_TYPE
        assert level == ScopeLevel.AGENT_TYPE

    def test_validate_write_access_allowed(
        self,
        resolver: ScopeResolver,
        manager: ScopeManager,
    ) -> None:
        """Test write access validation when allowed."""
        scope = manager.create_scope(ScopeLevel.PROJECT, "project")

        assert resolver.validate_write_access(scope, "semantic") is True

    def test_validate_write_access_denied_by_policy(
        self,
        resolver: ScopeResolver,
        manager: ScopeManager,
    ) -> None:
        """Test write access denied by policy."""
        scope = manager.create_scope(ScopeLevel.PROJECT, "project")
        manager.set_policy(
            scope,
            ScopePolicy(
                scope_type=ScopeLevel.PROJECT,
                scope_id="project",
                can_write=False,
            ),
        )

        assert resolver.validate_write_access(scope, "semantic") is False

    def test_validate_write_access_denied_by_memory_type(
        self,
        resolver: ScopeResolver,
        manager: ScopeManager,
    ) -> None:
        """Test write access denied by memory type restriction."""
        scope = manager.create_scope(ScopeLevel.PROJECT, "project")
        manager.set_policy(
            scope,
            ScopePolicy(
                scope_type=ScopeLevel.PROJECT,
                scope_id="project",
                allowed_memory_types=["episodic"],  # Only episodic allowed
            ),
        )

        assert resolver.validate_write_access(scope, "semantic") is False
        assert resolver.validate_write_access(scope, "episodic") is True

    def test_get_scope_chain(
        self,
        resolver: ScopeResolver,
        manager: ScopeManager,
    ) -> None:
        """Test getting scope chain."""
        global_scope = manager.create_scope(ScopeLevel.GLOBAL, "global")
        project_scope = manager.create_scope(
            ScopeLevel.PROJECT, "project", parent=global_scope
        )
        agent_scope = manager.create_scope(
            ScopeLevel.AGENT_TYPE, "agent", parent=project_scope
        )

        chain = resolver.get_scope_chain(agent_scope)

        # Unlike get_visible_scopes, the chain is ordered root-first.
        assert len(chain) == 3
        assert chain[0] == (ScopeLevel.GLOBAL, "global")
        assert chain[1] == (ScopeLevel.PROJECT, "project")
        assert chain[2] == (ScopeLevel.AGENT_TYPE, "agent")


class TestScopeFilter:
    """Tests for ScopeFilter dataclass."""

    def test_default_values(self) -> None:
        """Test default filter values."""
        filter_ = ScopeFilter()
        assert filter_.scope_types is None
        assert filter_.project_ids is None
        assert filter_.agent_type_ids is None
        assert filter_.include_global is True

    def test_matches_global_scope(self) -> None:
        """Test matching global scope."""
        scope = ScopeContext(
            scope_type=ScopeLevel.GLOBAL,
            scope_id="global",
        )

        filter_ = ScopeFilter(include_global=True)
        assert filter_.matches(scope) is True

        filter_ = ScopeFilter(include_global=False)
        assert filter_.matches(scope) is False

    def test_matches_scope_type(self) -> None:
        """Test matching by scope type."""
        scope = ScopeContext(
            scope_type=ScopeLevel.PROJECT,
            scope_id="project-1",
        )

        filter_ = ScopeFilter(scope_types=[ScopeLevel.PROJECT])
        assert filter_.matches(scope) is True

        filter_ = ScopeFilter(scope_types=[ScopeLevel.AGENT_TYPE])
        assert filter_.matches(scope) is False

    def test_matches_project_id(self) -> None:
        """Test matching by project ID."""
        scope = ScopeContext(
            scope_type=ScopeLevel.PROJECT,
            scope_id="project-1",
        )

        filter_ = ScopeFilter(project_ids=["project-1", "project-2"])
        assert filter_.matches(scope) is True

        filter_ = ScopeFilter(project_ids=["project-3"])
        assert filter_.matches(scope) is False

    def test_matches_agent_type_id(self) -> None:
        """Test matching by agent type ID."""
        scope = ScopeContext(
            scope_type=ScopeLevel.AGENT_TYPE,
            scope_id="agent-1",
        )

        filter_ = ScopeFilter(agent_type_ids=["agent-1"])
        assert filter_.matches(scope) is True

        filter_ = ScopeFilter(agent_type_ids=["agent-2"])
        assert filter_.matches(scope) is False


class TestGetScopeResolver:
    """Tests for singleton getter."""

    def test_returns_instance(self) -> None:
        """Test that getter returns instance."""
        resolver = get_scope_resolver()
        assert resolver is not None
        assert isinstance(resolver, ScopeResolver)

    def test_returns_same_instance(self) -> None:
        """Test that getter returns same instance."""
        resolver1 = get_scope_resolver()
        resolver2 = get_scope_resolver()
        assert resolver1 is resolver2