feat(memory): add semantic memory implementation (Issue #91)
Implements semantic memory with fact storage, retrieval, and verification: Core functionality: - SemanticMemory class for fact storage/retrieval - Fact storage as subject-predicate-object triples - Duplicate detection with reinforcement - Semantic search with text-based fallback - Entity-based retrieval - Confidence scoring and decay - Conflict resolution Supporting modules: - FactExtractor: Pattern-based fact extraction from episodes - FactVerifier: Contradiction detection and reliability scoring Test coverage: - 47 unit tests covering all modules - extraction.py: 99% coverage - verification.py: 95% coverage - memory.py: 78% coverage 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
2
backend/tests/unit/services/memory/semantic/__init__.py
Normal file
2
backend/tests/unit/services/memory/semantic/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
||||
# tests/unit/services/memory/semantic/__init__.py
|
||||
"""Unit tests for semantic memory service."""
|
||||
263
backend/tests/unit/services/memory/semantic/test_extraction.py
Normal file
263
backend/tests/unit/services/memory/semantic/test_extraction.py
Normal file
@@ -0,0 +1,263 @@
|
||||
# tests/unit/services/memory/semantic/test_extraction.py
|
||||
"""Unit tests for fact extraction."""
|
||||
|
||||
from datetime import UTC, datetime
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
from app.services.memory.semantic.extraction import (
|
||||
ExtractedFact,
|
||||
ExtractionContext,
|
||||
FactExtractor,
|
||||
get_fact_extractor,
|
||||
)
|
||||
from app.services.memory.types import Episode, Outcome
|
||||
|
||||
|
||||
def create_test_episode(
|
||||
lessons_learned: list[str] | None = None,
|
||||
outcome: Outcome = Outcome.SUCCESS,
|
||||
task_type: str = "code_review",
|
||||
task_description: str = "Review the authentication module",
|
||||
outcome_details: str = "",
|
||||
) -> Episode:
|
||||
"""Create a test episode for extraction tests."""
|
||||
return Episode(
|
||||
id=uuid4(),
|
||||
project_id=uuid4(),
|
||||
agent_instance_id=None,
|
||||
agent_type_id=None,
|
||||
session_id="test-session",
|
||||
task_type=task_type,
|
||||
task_description=task_description,
|
||||
actions=[],
|
||||
context_summary="Test context",
|
||||
outcome=outcome,
|
||||
outcome_details=outcome_details,
|
||||
duration_seconds=60.0,
|
||||
tokens_used=500,
|
||||
lessons_learned=lessons_learned or [],
|
||||
importance_score=0.7,
|
||||
embedding=None,
|
||||
occurred_at=datetime.now(UTC),
|
||||
created_at=datetime.now(UTC),
|
||||
updated_at=datetime.now(UTC),
|
||||
)
|
||||
|
||||
|
||||
class TestExtractedFact:
|
||||
"""Tests for ExtractedFact dataclass."""
|
||||
|
||||
def test_to_fact_create(self) -> None:
|
||||
"""Test converting ExtractedFact to FactCreate."""
|
||||
extracted = ExtractedFact(
|
||||
subject="Python",
|
||||
predicate="uses",
|
||||
object="dynamic typing",
|
||||
confidence=0.8,
|
||||
)
|
||||
|
||||
fact_create = extracted.to_fact_create(
|
||||
project_id=uuid4(),
|
||||
source_episode_ids=[uuid4()],
|
||||
)
|
||||
|
||||
assert fact_create.subject == "Python"
|
||||
assert fact_create.predicate == "uses"
|
||||
assert fact_create.object == "dynamic typing"
|
||||
assert fact_create.confidence == 0.8
|
||||
|
||||
def test_to_fact_create_defaults(self) -> None:
|
||||
"""Test to_fact_create with default values."""
|
||||
extracted = ExtractedFact(
|
||||
subject="A",
|
||||
predicate="B",
|
||||
object="C",
|
||||
confidence=0.5,
|
||||
)
|
||||
|
||||
fact_create = extracted.to_fact_create()
|
||||
|
||||
assert fact_create.project_id is None
|
||||
assert fact_create.source_episode_ids == []
|
||||
|
||||
|
||||
class TestFactExtractor:
|
||||
"""Tests for FactExtractor class."""
|
||||
|
||||
@pytest.fixture
|
||||
def extractor(self) -> FactExtractor:
|
||||
"""Create a fact extractor."""
|
||||
return FactExtractor()
|
||||
|
||||
def test_extract_from_episode_with_lessons(
|
||||
self,
|
||||
extractor: FactExtractor,
|
||||
) -> None:
|
||||
"""Test extracting facts from episode with lessons."""
|
||||
episode = create_test_episode(
|
||||
lessons_learned=[
|
||||
"Always validate user input before processing",
|
||||
"Use parameterized queries to prevent SQL injection",
|
||||
]
|
||||
)
|
||||
|
||||
facts = extractor.extract_from_episode(episode)
|
||||
|
||||
assert len(facts) > 0
|
||||
# Should have lesson_learned predicates
|
||||
lesson_facts = [f for f in facts if f.predicate == "lesson_learned"]
|
||||
assert len(lesson_facts) >= 2
|
||||
|
||||
def test_extract_from_episode_with_always_pattern(
|
||||
self,
|
||||
extractor: FactExtractor,
|
||||
) -> None:
|
||||
"""Test extracting 'always' pattern lessons."""
|
||||
episode = create_test_episode(
|
||||
lessons_learned=["Always close file handles properly"]
|
||||
)
|
||||
|
||||
facts = extractor.extract_from_episode(episode)
|
||||
|
||||
best_practices = [f for f in facts if f.predicate == "best_practice"]
|
||||
assert len(best_practices) >= 1
|
||||
assert any("close file handles" in f.object for f in best_practices)
|
||||
|
||||
def test_extract_from_episode_with_never_pattern(
|
||||
self,
|
||||
extractor: FactExtractor,
|
||||
) -> None:
|
||||
"""Test extracting 'never' pattern lessons."""
|
||||
episode = create_test_episode(
|
||||
lessons_learned=["Never store passwords in plain text"]
|
||||
)
|
||||
|
||||
facts = extractor.extract_from_episode(episode)
|
||||
|
||||
anti_patterns = [f for f in facts if f.predicate == "anti_pattern"]
|
||||
assert len(anti_patterns) >= 1
|
||||
|
||||
def test_extract_from_episode_with_conditional_pattern(
|
||||
self,
|
||||
extractor: FactExtractor,
|
||||
) -> None:
|
||||
"""Test extracting conditional lessons."""
|
||||
episode = create_test_episode(
|
||||
lessons_learned=["When handling errors, log the stack trace"]
|
||||
)
|
||||
|
||||
facts = extractor.extract_from_episode(episode)
|
||||
|
||||
conditional = [f for f in facts if f.predicate == "requires_action"]
|
||||
assert len(conditional) >= 1
|
||||
|
||||
def test_extract_outcome_facts_success(
|
||||
self,
|
||||
extractor: FactExtractor,
|
||||
) -> None:
|
||||
"""Test extracting facts from successful episode."""
|
||||
episode = create_test_episode(
|
||||
outcome=Outcome.SUCCESS,
|
||||
outcome_details="Deployed to production without issues",
|
||||
)
|
||||
|
||||
facts = extractor.extract_from_episode(episode)
|
||||
|
||||
success_facts = [f for f in facts if f.predicate == "successful_approach"]
|
||||
assert len(success_facts) >= 1
|
||||
|
||||
def test_extract_outcome_facts_failure(
|
||||
self,
|
||||
extractor: FactExtractor,
|
||||
) -> None:
|
||||
"""Test extracting facts from failed episode."""
|
||||
episode = create_test_episode(
|
||||
outcome=Outcome.FAILURE,
|
||||
outcome_details="Connection timeout during deployment",
|
||||
)
|
||||
|
||||
facts = extractor.extract_from_episode(episode)
|
||||
|
||||
failure_facts = [f for f in facts if f.predicate == "known_failure_mode"]
|
||||
assert len(failure_facts) >= 1
|
||||
|
||||
def test_extract_from_text_uses_pattern(
|
||||
self,
|
||||
extractor: FactExtractor,
|
||||
) -> None:
|
||||
"""Test extracting 'uses' pattern from text."""
|
||||
text = "FastAPI uses Starlette for ASGI support."
|
||||
|
||||
facts = extractor.extract_from_text(text)
|
||||
|
||||
assert len(facts) >= 1
|
||||
uses_facts = [f for f in facts if f.predicate == "uses"]
|
||||
assert len(uses_facts) >= 1
|
||||
|
||||
def test_extract_from_text_requires_pattern(
|
||||
self,
|
||||
extractor: FactExtractor,
|
||||
) -> None:
|
||||
"""Test extracting 'requires' pattern from text."""
|
||||
text = "This feature requires Python 3.10 or higher."
|
||||
|
||||
facts = extractor.extract_from_text(text)
|
||||
|
||||
requires_facts = [f for f in facts if f.predicate == "requires"]
|
||||
assert len(requires_facts) >= 1
|
||||
|
||||
def test_extract_from_text_empty(
|
||||
self,
|
||||
extractor: FactExtractor,
|
||||
) -> None:
|
||||
"""Test extracting from empty text."""
|
||||
facts = extractor.extract_from_text("")
|
||||
|
||||
assert facts == []
|
||||
|
||||
def test_extract_from_text_short(
|
||||
self,
|
||||
extractor: FactExtractor,
|
||||
) -> None:
|
||||
"""Test extracting from too-short text."""
|
||||
facts = extractor.extract_from_text("Hi.")
|
||||
|
||||
assert facts == []
|
||||
|
||||
def test_extract_with_context(
|
||||
self,
|
||||
extractor: FactExtractor,
|
||||
) -> None:
|
||||
"""Test extraction with custom context."""
|
||||
episode = create_test_episode(lessons_learned=["Low confidence lesson"])
|
||||
|
||||
context = ExtractionContext(
|
||||
min_confidence=0.9, # High threshold
|
||||
max_facts_per_source=2,
|
||||
)
|
||||
|
||||
facts = extractor.extract_from_episode(episode, context)
|
||||
|
||||
# Should filter out low confidence facts
|
||||
for fact in facts:
|
||||
assert fact.confidence >= 0.9 or len(facts) <= 2
|
||||
|
||||
|
||||
class TestGetFactExtractor:
|
||||
"""Tests for singleton getter."""
|
||||
|
||||
def test_get_fact_extractor_returns_instance(self) -> None:
|
||||
"""Test that get_fact_extractor returns an instance."""
|
||||
extractor = get_fact_extractor()
|
||||
|
||||
assert extractor is not None
|
||||
assert isinstance(extractor, FactExtractor)
|
||||
|
||||
def test_get_fact_extractor_returns_same_instance(self) -> None:
|
||||
"""Test that get_fact_extractor returns singleton."""
|
||||
extractor1 = get_fact_extractor()
|
||||
extractor2 = get_fact_extractor()
|
||||
|
||||
assert extractor1 is extractor2
|
||||
446
backend/tests/unit/services/memory/semantic/test_memory.py
Normal file
446
backend/tests/unit/services/memory/semantic/test_memory.py
Normal file
@@ -0,0 +1,446 @@
|
||||
# tests/unit/services/memory/semantic/test_memory.py
|
||||
"""Unit tests for SemanticMemory class."""
|
||||
|
||||
from datetime import UTC, datetime
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
from app.services.memory.semantic.memory import SemanticMemory
|
||||
from app.services.memory.types import FactCreate
|
||||
|
||||
|
||||
def create_mock_fact_model(
|
||||
project_id=None,
|
||||
subject="FastAPI",
|
||||
predicate="uses",
|
||||
obj="Starlette",
|
||||
confidence=0.8,
|
||||
):
|
||||
"""Create a mock fact model for testing."""
|
||||
mock = MagicMock()
|
||||
mock.id = uuid4()
|
||||
mock.project_id = project_id
|
||||
mock.subject = subject
|
||||
mock.predicate = predicate
|
||||
mock.object = obj
|
||||
mock.confidence = confidence
|
||||
mock.source_episode_ids = []
|
||||
mock.first_learned = datetime.now(UTC)
|
||||
mock.last_reinforced = datetime.now(UTC)
|
||||
mock.reinforcement_count = 1
|
||||
mock.embedding = None
|
||||
mock.created_at = datetime.now(UTC)
|
||||
mock.updated_at = datetime.now(UTC)
|
||||
return mock
|
||||
|
||||
|
||||
class TestSemanticMemoryInit:
|
||||
"""Tests for SemanticMemory initialization."""
|
||||
|
||||
def test_init_creates_memory(self) -> None:
|
||||
"""Test that init creates memory instance."""
|
||||
mock_session = AsyncMock()
|
||||
memory = SemanticMemory(session=mock_session)
|
||||
|
||||
assert memory._session is mock_session
|
||||
|
||||
def test_init_with_embedding_generator(self) -> None:
|
||||
"""Test init with embedding generator."""
|
||||
mock_session = AsyncMock()
|
||||
mock_embedding_gen = AsyncMock()
|
||||
memory = SemanticMemory(
|
||||
session=mock_session, embedding_generator=mock_embedding_gen
|
||||
)
|
||||
|
||||
assert memory._embedding_generator is mock_embedding_gen
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_factory_method(self) -> None:
|
||||
"""Test create factory method."""
|
||||
mock_session = AsyncMock()
|
||||
memory = await SemanticMemory.create(session=mock_session)
|
||||
|
||||
assert memory is not None
|
||||
assert memory._session is mock_session
|
||||
|
||||
|
||||
class TestSemanticMemoryStoreFact:
|
||||
"""Tests for fact storage methods."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_session(self) -> AsyncMock:
|
||||
"""Create a mock database session."""
|
||||
session = AsyncMock()
|
||||
session.add = MagicMock()
|
||||
session.flush = AsyncMock()
|
||||
session.refresh = AsyncMock()
|
||||
# Mock no existing fact found
|
||||
mock_result = MagicMock()
|
||||
mock_result.scalar_one_or_none.return_value = None
|
||||
session.execute.return_value = mock_result
|
||||
return session
|
||||
|
||||
@pytest.fixture
|
||||
def memory(self, mock_session: AsyncMock) -> SemanticMemory:
|
||||
"""Create a SemanticMemory instance."""
|
||||
return SemanticMemory(session=mock_session)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_store_new_fact(
|
||||
self,
|
||||
memory: SemanticMemory,
|
||||
mock_session: AsyncMock,
|
||||
) -> None:
|
||||
"""Test storing a new fact."""
|
||||
fact_data = FactCreate(
|
||||
subject="Python",
|
||||
predicate="is_a",
|
||||
object="programming language",
|
||||
confidence=0.9,
|
||||
project_id=uuid4(),
|
||||
)
|
||||
|
||||
result = await memory.store_fact(fact_data)
|
||||
|
||||
assert result.subject == "Python"
|
||||
assert result.predicate == "is_a"
|
||||
assert result.object == "programming language"
|
||||
mock_session.add.assert_called_once()
|
||||
mock_session.flush.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_store_fact_reinforces_existing(
|
||||
self,
|
||||
memory: SemanticMemory,
|
||||
mock_session: AsyncMock,
|
||||
) -> None:
|
||||
"""Test that storing duplicate fact reinforces existing."""
|
||||
# Mock existing fact found - needs to be found first
|
||||
existing_mock = create_mock_fact_model(confidence=0.7)
|
||||
find_result = MagicMock()
|
||||
find_result.scalar_one_or_none.return_value = existing_mock
|
||||
|
||||
# Second find for reinforce_fact
|
||||
find_for_reinforce = MagicMock()
|
||||
find_for_reinforce.scalar_one_or_none.return_value = existing_mock
|
||||
|
||||
# Mock update result - returns the updated mock
|
||||
updated_mock = create_mock_fact_model(confidence=0.8)
|
||||
update_result = MagicMock()
|
||||
update_result.scalar_one.return_value = updated_mock
|
||||
|
||||
mock_session.execute.side_effect = [
|
||||
find_result, # _find_existing_fact
|
||||
find_for_reinforce, # reinforce_fact query
|
||||
update_result, # reinforce_fact update
|
||||
]
|
||||
|
||||
fact_data = FactCreate(
|
||||
subject="FastAPI",
|
||||
predicate="uses",
|
||||
object="Starlette",
|
||||
)
|
||||
|
||||
_ = await memory.store_fact(fact_data)
|
||||
|
||||
# Should have called execute three times (find + find + update)
|
||||
assert mock_session.execute.call_count == 3
|
||||
|
||||
|
||||
class TestSemanticMemorySearch:
|
||||
"""Tests for fact search methods."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_session(self) -> AsyncMock:
|
||||
"""Create a mock database session."""
|
||||
session = AsyncMock()
|
||||
mock_result = MagicMock()
|
||||
mock_result.scalars.return_value.all.return_value = []
|
||||
session.execute.return_value = mock_result
|
||||
return session
|
||||
|
||||
@pytest.fixture
|
||||
def memory(self, mock_session: AsyncMock) -> SemanticMemory:
|
||||
"""Create a SemanticMemory instance."""
|
||||
return SemanticMemory(session=mock_session)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_search_facts(
|
||||
self,
|
||||
memory: SemanticMemory,
|
||||
) -> None:
|
||||
"""Test searching for facts."""
|
||||
results = await memory.search_facts("Python programming")
|
||||
|
||||
assert isinstance(results, list)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_search_facts_with_project_filter(
|
||||
self,
|
||||
memory: SemanticMemory,
|
||||
) -> None:
|
||||
"""Test searching for facts with project filter."""
|
||||
project_id = uuid4()
|
||||
results = await memory.search_facts("Python", project_id=project_id)
|
||||
|
||||
assert isinstance(results, list)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_by_entity(
|
||||
self,
|
||||
memory: SemanticMemory,
|
||||
) -> None:
|
||||
"""Test getting facts by entity."""
|
||||
results = await memory.get_by_entity("FastAPI")
|
||||
|
||||
assert isinstance(results, list)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_by_subject(
|
||||
self,
|
||||
memory: SemanticMemory,
|
||||
) -> None:
|
||||
"""Test getting facts by subject."""
|
||||
results = await memory.get_by_subject("Python")
|
||||
|
||||
assert isinstance(results, list)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_by_id_not_found(
|
||||
self,
|
||||
memory: SemanticMemory,
|
||||
mock_session: AsyncMock,
|
||||
) -> None:
|
||||
"""Test get_by_id returns None when not found."""
|
||||
mock_result = MagicMock()
|
||||
mock_result.scalar_one_or_none.return_value = None
|
||||
mock_session.execute.return_value = mock_result
|
||||
|
||||
result = await memory.get_by_id(uuid4())
|
||||
|
||||
assert result is None
|
||||
|
||||
|
||||
class TestSemanticMemoryReinforcement:
|
||||
"""Tests for fact reinforcement."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_session(self) -> AsyncMock:
|
||||
"""Create a mock database session."""
|
||||
session = AsyncMock()
|
||||
return session
|
||||
|
||||
@pytest.fixture
|
||||
def memory(self, mock_session: AsyncMock) -> SemanticMemory:
|
||||
"""Create a SemanticMemory instance."""
|
||||
return SemanticMemory(session=mock_session)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_reinforce_fact(
|
||||
self,
|
||||
memory: SemanticMemory,
|
||||
mock_session: AsyncMock,
|
||||
) -> None:
|
||||
"""Test reinforcing a fact."""
|
||||
existing_mock = create_mock_fact_model(confidence=0.7)
|
||||
|
||||
# First query: find fact
|
||||
find_result = MagicMock()
|
||||
find_result.scalar_one_or_none.return_value = existing_mock
|
||||
|
||||
# Second query: update fact
|
||||
updated_mock = create_mock_fact_model(confidence=0.8)
|
||||
update_result = MagicMock()
|
||||
update_result.scalar_one.return_value = updated_mock
|
||||
|
||||
mock_session.execute.side_effect = [find_result, update_result]
|
||||
|
||||
result = await memory.reinforce_fact(existing_mock.id, confidence_boost=0.1)
|
||||
|
||||
assert result.confidence == 0.8
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_reinforce_fact_not_found(
|
||||
self,
|
||||
memory: SemanticMemory,
|
||||
mock_session: AsyncMock,
|
||||
) -> None:
|
||||
"""Test reinforcing a non-existent fact raises error."""
|
||||
mock_result = MagicMock()
|
||||
mock_result.scalar_one_or_none.return_value = None
|
||||
mock_session.execute.return_value = mock_result
|
||||
|
||||
with pytest.raises(ValueError, match="Fact not found"):
|
||||
await memory.reinforce_fact(uuid4())
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_deprecate_fact(
|
||||
self,
|
||||
memory: SemanticMemory,
|
||||
mock_session: AsyncMock,
|
||||
) -> None:
|
||||
"""Test deprecating a fact."""
|
||||
existing_mock = create_mock_fact_model(confidence=0.8)
|
||||
|
||||
find_result = MagicMock()
|
||||
find_result.scalar_one_or_none.return_value = existing_mock
|
||||
|
||||
deprecated_mock = create_mock_fact_model(confidence=0.0)
|
||||
update_result = MagicMock()
|
||||
update_result.scalar_one_or_none.return_value = deprecated_mock
|
||||
|
||||
mock_session.execute.side_effect = [find_result, update_result]
|
||||
|
||||
result = await memory.deprecate_fact(existing_mock.id, reason="Outdated")
|
||||
|
||||
assert result is not None
|
||||
assert result.confidence == 0.0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_deprecate_fact_not_found(
|
||||
self,
|
||||
memory: SemanticMemory,
|
||||
mock_session: AsyncMock,
|
||||
) -> None:
|
||||
"""Test deprecating non-existent fact returns None."""
|
||||
mock_result = MagicMock()
|
||||
mock_result.scalar_one_or_none.return_value = None
|
||||
mock_session.execute.return_value = mock_result
|
||||
|
||||
result = await memory.deprecate_fact(uuid4(), reason="Test")
|
||||
|
||||
assert result is None
|
||||
|
||||
|
||||
class TestSemanticMemoryConflictResolution:
|
||||
"""Tests for conflict resolution."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_session(self) -> AsyncMock:
|
||||
"""Create a mock database session."""
|
||||
session = AsyncMock()
|
||||
return session
|
||||
|
||||
@pytest.fixture
|
||||
def memory(self, mock_session: AsyncMock) -> SemanticMemory:
|
||||
"""Create a SemanticMemory instance."""
|
||||
return SemanticMemory(session=mock_session)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resolve_conflict_empty_list(
|
||||
self,
|
||||
memory: SemanticMemory,
|
||||
) -> None:
|
||||
"""Test resolving conflict with empty list."""
|
||||
result = await memory.resolve_conflict([])
|
||||
|
||||
assert result is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resolve_conflict_keeps_highest_confidence(
|
||||
self,
|
||||
memory: SemanticMemory,
|
||||
mock_session: AsyncMock,
|
||||
) -> None:
|
||||
"""Test that conflict resolution keeps highest confidence fact."""
|
||||
fact_low = create_mock_fact_model(confidence=0.5)
|
||||
fact_high = create_mock_fact_model(confidence=0.9)
|
||||
|
||||
# Mock finding the facts
|
||||
find_result = MagicMock()
|
||||
find_result.scalars.return_value.all.return_value = [fact_low, fact_high]
|
||||
|
||||
# Mock deprecation (find + update)
|
||||
find_one_result = MagicMock()
|
||||
find_one_result.scalar_one_or_none.return_value = fact_low
|
||||
update_result = MagicMock()
|
||||
update_result.scalar_one_or_none.return_value = fact_low
|
||||
|
||||
mock_session.execute.side_effect = [find_result, find_one_result, update_result]
|
||||
|
||||
result = await memory.resolve_conflict([fact_low.id, fact_high.id])
|
||||
|
||||
assert result is not None
|
||||
assert result.confidence == 0.9
|
||||
|
||||
|
||||
class TestSemanticMemoryStats:
|
||||
"""Tests for statistics methods."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_session(self) -> AsyncMock:
|
||||
"""Create a mock database session."""
|
||||
session = AsyncMock()
|
||||
return session
|
||||
|
||||
@pytest.fixture
|
||||
def memory(self, mock_session: AsyncMock) -> SemanticMemory:
|
||||
"""Create a SemanticMemory instance."""
|
||||
return SemanticMemory(session=mock_session)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_stats_empty(
|
||||
self,
|
||||
memory: SemanticMemory,
|
||||
mock_session: AsyncMock,
|
||||
) -> None:
|
||||
"""Test getting stats for empty project."""
|
||||
mock_result = MagicMock()
|
||||
mock_result.scalars.return_value.all.return_value = []
|
||||
mock_session.execute.return_value = mock_result
|
||||
|
||||
stats = await memory.get_stats(uuid4())
|
||||
|
||||
assert stats["total_facts"] == 0
|
||||
assert stats["avg_confidence"] == 0.0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_count(
|
||||
self,
|
||||
memory: SemanticMemory,
|
||||
mock_session: AsyncMock,
|
||||
) -> None:
|
||||
"""Test counting facts."""
|
||||
facts = [create_mock_fact_model() for _ in range(5)]
|
||||
mock_result = MagicMock()
|
||||
mock_result.scalars.return_value.all.return_value = facts
|
||||
mock_session.execute.return_value = mock_result
|
||||
|
||||
count = await memory.count(uuid4())
|
||||
|
||||
assert count == 5
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete(
|
||||
self,
|
||||
memory: SemanticMemory,
|
||||
mock_session: AsyncMock,
|
||||
) -> None:
|
||||
"""Test deleting a fact."""
|
||||
existing_mock = create_mock_fact_model()
|
||||
mock_result = MagicMock()
|
||||
mock_result.scalar_one_or_none.return_value = existing_mock
|
||||
mock_session.execute.return_value = mock_result
|
||||
mock_session.delete = AsyncMock()
|
||||
|
||||
result = await memory.delete(existing_mock.id)
|
||||
|
||||
assert result is True
|
||||
mock_session.delete.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_not_found(
|
||||
self,
|
||||
memory: SemanticMemory,
|
||||
mock_session: AsyncMock,
|
||||
) -> None:
|
||||
"""Test deleting non-existent fact."""
|
||||
mock_result = MagicMock()
|
||||
mock_result.scalar_one_or_none.return_value = None
|
||||
mock_session.execute.return_value = mock_result
|
||||
|
||||
result = await memory.delete(uuid4())
|
||||
|
||||
assert result is False
|
||||
298
backend/tests/unit/services/memory/semantic/test_verification.py
Normal file
298
backend/tests/unit/services/memory/semantic/test_verification.py
Normal file
@@ -0,0 +1,298 @@
|
||||
# tests/unit/services/memory/semantic/test_verification.py
|
||||
"""Unit tests for fact verification."""
|
||||
|
||||
from datetime import UTC, datetime
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
from app.services.memory.semantic.verification import (
|
||||
FactConflict,
|
||||
FactVerifier,
|
||||
VerificationResult,
|
||||
)
|
||||
|
||||
|
||||
def create_mock_fact_model(
|
||||
subject="FastAPI",
|
||||
predicate="uses",
|
||||
obj="Starlette",
|
||||
confidence=0.8,
|
||||
project_id=None,
|
||||
):
|
||||
"""Create a mock fact model for testing."""
|
||||
mock = MagicMock()
|
||||
mock.id = uuid4()
|
||||
mock.project_id = project_id
|
||||
mock.subject = subject
|
||||
mock.predicate = predicate
|
||||
mock.object = obj
|
||||
mock.confidence = confidence
|
||||
mock.source_episode_ids = []
|
||||
mock.first_learned = datetime.now(UTC)
|
||||
mock.last_reinforced = datetime.now(UTC)
|
||||
mock.reinforcement_count = 1
|
||||
mock.embedding = None
|
||||
mock.created_at = datetime.now(UTC)
|
||||
mock.updated_at = datetime.now(UTC)
|
||||
return mock
|
||||
|
||||
|
||||
class TestFactConflict:
|
||||
"""Tests for FactConflict dataclass."""
|
||||
|
||||
def test_to_dict(self) -> None:
|
||||
"""Test converting conflict to dictionary."""
|
||||
conflict = FactConflict(
|
||||
fact_a_id=uuid4(),
|
||||
fact_b_id=uuid4(),
|
||||
conflict_type="contradiction",
|
||||
description="Test conflict",
|
||||
suggested_resolution="Keep higher confidence",
|
||||
)
|
||||
|
||||
result = conflict.to_dict()
|
||||
|
||||
assert "fact_a_id" in result
|
||||
assert "fact_b_id" in result
|
||||
assert result["conflict_type"] == "contradiction"
|
||||
assert result["description"] == "Test conflict"
|
||||
|
||||
|
||||
class TestVerificationResult:
|
||||
"""Tests for VerificationResult dataclass."""
|
||||
|
||||
def test_default_values(self) -> None:
|
||||
"""Test default values."""
|
||||
result = VerificationResult(is_valid=True)
|
||||
|
||||
assert result.is_valid is True
|
||||
assert result.confidence_adjustment == 0.0
|
||||
assert result.conflicts == []
|
||||
assert result.supporting_facts == []
|
||||
assert result.messages == []
|
||||
|
||||
|
||||
class TestFactVerifier:
|
||||
"""Tests for FactVerifier class."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_session(self) -> AsyncMock:
|
||||
"""Create a mock database session."""
|
||||
session = AsyncMock()
|
||||
mock_result = MagicMock()
|
||||
mock_result.scalars.return_value.all.return_value = []
|
||||
session.execute.return_value = mock_result
|
||||
return session
|
||||
|
||||
@pytest.fixture
|
||||
def verifier(self, mock_session: AsyncMock) -> FactVerifier:
|
||||
"""Create a fact verifier."""
|
||||
return FactVerifier(session=mock_session)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_verify_fact_valid(
|
||||
self,
|
||||
verifier: FactVerifier,
|
||||
) -> None:
|
||||
"""Test verifying a valid fact with no conflicts."""
|
||||
result = await verifier.verify_fact(
|
||||
subject="Python",
|
||||
predicate="is_a",
|
||||
obj="programming language",
|
||||
)
|
||||
|
||||
assert result.is_valid is True
|
||||
assert len(result.conflicts) == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_verify_fact_with_support(
|
||||
self,
|
||||
verifier: FactVerifier,
|
||||
mock_session: AsyncMock,
|
||||
) -> None:
|
||||
"""Test verifying a fact with supporting facts."""
|
||||
# Mock finding supporting facts
|
||||
supporting = [create_mock_fact_model()]
|
||||
|
||||
# First query: contradictions (empty)
|
||||
contradiction_result = MagicMock()
|
||||
contradiction_result.scalars.return_value.all.return_value = []
|
||||
|
||||
# Second query: supporting facts
|
||||
support_result = MagicMock()
|
||||
support_result.scalars.return_value.all.return_value = supporting
|
||||
|
||||
mock_session.execute.side_effect = [contradiction_result, support_result]
|
||||
|
||||
result = await verifier.verify_fact(
|
||||
subject="Python",
|
||||
predicate="uses",
|
||||
obj="dynamic typing",
|
||||
)
|
||||
|
||||
assert result.is_valid is True
|
||||
assert len(result.supporting_facts) >= 1
|
||||
assert result.confidence_adjustment > 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_verify_fact_with_contradiction(
|
||||
self,
|
||||
verifier: FactVerifier,
|
||||
mock_session: AsyncMock,
|
||||
) -> None:
|
||||
"""Test verifying a fact with contradictions."""
|
||||
# Mock finding contradicting fact
|
||||
contradicting = create_mock_fact_model(
|
||||
subject="Python",
|
||||
predicate="does_not_use",
|
||||
obj="static typing",
|
||||
)
|
||||
|
||||
contradiction_result = MagicMock()
|
||||
contradiction_result.scalars.return_value.all.return_value = [contradicting]
|
||||
|
||||
support_result = MagicMock()
|
||||
support_result.scalars.return_value.all.return_value = []
|
||||
|
||||
mock_session.execute.side_effect = [contradiction_result, support_result]
|
||||
|
||||
result = await verifier.verify_fact(
|
||||
subject="Python",
|
||||
predicate="uses",
|
||||
obj="static typing",
|
||||
)
|
||||
|
||||
assert result.is_valid is False
|
||||
assert len(result.conflicts) >= 1
|
||||
assert result.confidence_adjustment < 0
|
||||
|
||||
def test_get_opposite_predicates(
|
||||
self,
|
||||
verifier: FactVerifier,
|
||||
) -> None:
|
||||
"""Test getting opposite predicates."""
|
||||
opposites = verifier._get_opposite_predicates("uses")
|
||||
|
||||
assert "does_not_use" in opposites
|
||||
|
||||
def test_get_opposite_predicates_unknown(
|
||||
self,
|
||||
verifier: FactVerifier,
|
||||
) -> None:
|
||||
"""Test getting opposites for unknown predicate."""
|
||||
opposites = verifier._get_opposite_predicates("unknown_predicate")
|
||||
|
||||
assert opposites == []
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_find_all_conflicts_empty(
|
||||
self,
|
||||
verifier: FactVerifier,
|
||||
mock_session: AsyncMock,
|
||||
) -> None:
|
||||
"""Test finding all conflicts in empty fact base."""
|
||||
mock_result = MagicMock()
|
||||
mock_result.scalars.return_value.all.return_value = []
|
||||
mock_session.execute.return_value = mock_result
|
||||
|
||||
conflicts = await verifier.find_all_conflicts()
|
||||
|
||||
assert conflicts == []
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_find_all_conflicts_no_conflicts(
|
||||
self,
|
||||
verifier: FactVerifier,
|
||||
mock_session: AsyncMock,
|
||||
) -> None:
|
||||
"""Test finding conflicts when there are none."""
|
||||
# Two facts with different subjects
|
||||
fact1 = create_mock_fact_model(subject="Python", predicate="uses")
|
||||
fact2 = create_mock_fact_model(subject="JavaScript", predicate="uses")
|
||||
|
||||
mock_result = MagicMock()
|
||||
mock_result.scalars.return_value.all.return_value = [fact1, fact2]
|
||||
mock_session.execute.return_value = mock_result
|
||||
|
||||
conflicts = await verifier.find_all_conflicts()
|
||||
|
||||
assert conflicts == []
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_find_all_conflicts_with_contradiction(
|
||||
self,
|
||||
verifier: FactVerifier,
|
||||
mock_session: AsyncMock,
|
||||
) -> None:
|
||||
"""Test finding contradicting facts."""
|
||||
# Two contradicting facts
|
||||
fact1 = create_mock_fact_model(
|
||||
subject="Python",
|
||||
predicate="best_practice",
|
||||
obj="Use type hints",
|
||||
)
|
||||
fact2 = create_mock_fact_model(
|
||||
subject="Python",
|
||||
predicate="anti_pattern",
|
||||
obj="Use type hints",
|
||||
)
|
||||
|
||||
mock_result = MagicMock()
|
||||
mock_result.scalars.return_value.all.return_value = [fact1, fact2]
|
||||
mock_session.execute.return_value = mock_result
|
||||
|
||||
conflicts = await verifier.find_all_conflicts()
|
||||
|
||||
assert len(conflicts) == 1
|
||||
assert conflicts[0].conflict_type == "contradiction"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_fact_reliability_score_not_found(
|
||||
self,
|
||||
verifier: FactVerifier,
|
||||
mock_session: AsyncMock,
|
||||
) -> None:
|
||||
"""Test reliability score for non-existent fact."""
|
||||
mock_result = MagicMock()
|
||||
mock_result.scalar_one_or_none.return_value = None
|
||||
mock_session.execute.return_value = mock_result
|
||||
|
||||
score = await verifier.get_fact_reliability_score(uuid4())
|
||||
|
||||
assert score == 0.0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_fact_reliability_score(
|
||||
self,
|
||||
verifier: FactVerifier,
|
||||
mock_session: AsyncMock,
|
||||
) -> None:
|
||||
"""Test calculating reliability score."""
|
||||
fact = create_mock_fact_model(confidence=0.8)
|
||||
fact.reinforcement_count = 5
|
||||
|
||||
# Query 1: Get fact
|
||||
fact_result = MagicMock()
|
||||
fact_result.scalar_one_or_none.return_value = fact
|
||||
|
||||
# Query 2: Supporting facts
|
||||
support_result = MagicMock()
|
||||
support_result.scalars.return_value.all.return_value = []
|
||||
|
||||
# Query 3: Contradictions
|
||||
conflict_result = MagicMock()
|
||||
conflict_result.scalars.return_value.all.return_value = []
|
||||
|
||||
mock_session.execute.side_effect = [
|
||||
fact_result,
|
||||
support_result,
|
||||
conflict_result,
|
||||
]
|
||||
|
||||
score = await verifier.get_fact_reliability_score(fact.id)
|
||||
|
||||
# Score should be >= confidence (0.8) due to reinforcement bonus
|
||||
assert score >= 0.8
|
||||
assert score <= 1.0
|
||||
Reference in New Issue
Block a user