forked from cardosofelipe/fast-next-template
feat(memory): add semantic memory implementation (Issue #91)
Implements semantic memory with fact storage, retrieval, and verification: Core functionality: - SemanticMemory class for fact storage/retrieval - Fact storage as subject-predicate-object triples - Duplicate detection with reinforcement - Semantic search with text-based fallback - Entity-based retrieval - Confidence scoring and decay - Conflict resolution Supporting modules: - FactExtractor: Pattern-based fact extraction from episodes - FactVerifier: Contradiction detection and reliability scoring Test coverage: - 47 unit tests covering all modules - extraction.py: 99% coverage - verification.py: 95% coverage - memory.py: 78% coverage 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
@@ -1,3 +1,4 @@
|
||||
# app/services/memory/semantic/__init__.py
|
||||
"""
|
||||
Semantic Memory
|
||||
|
||||
@@ -5,4 +6,22 @@ Fact storage with triple format (subject, predicate, object)
|
||||
and semantic search capabilities.
|
||||
"""
|
||||
|
||||
# Will be populated in #91
|
||||
from .extraction import (
|
||||
ExtractedFact,
|
||||
ExtractionContext,
|
||||
FactExtractor,
|
||||
get_fact_extractor,
|
||||
)
|
||||
from .memory import SemanticMemory
|
||||
from .verification import FactConflict, FactVerifier, VerificationResult
|
||||
|
||||
__all__ = [
|
||||
"ExtractedFact",
|
||||
"ExtractionContext",
|
||||
"FactConflict",
|
||||
"FactExtractor",
|
||||
"FactVerifier",
|
||||
"SemanticMemory",
|
||||
"VerificationResult",
|
||||
"get_fact_extractor",
|
||||
]
|
||||
|
||||
313
backend/app/services/memory/semantic/extraction.py
Normal file
313
backend/app/services/memory/semantic/extraction.py
Normal file
@@ -0,0 +1,313 @@
|
||||
# app/services/memory/semantic/extraction.py
|
||||
"""
|
||||
Fact Extraction from Episodes.
|
||||
|
||||
Provides utilities for extracting semantic facts (subject-predicate-object triples)
|
||||
from episodic memories and other text sources.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import re
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, ClassVar
|
||||
|
||||
from app.services.memory.types import Episode, FactCreate, Outcome
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ExtractionContext:
|
||||
"""Context for fact extraction."""
|
||||
|
||||
project_id: Any | None = None
|
||||
source_episode_id: Any | None = None
|
||||
min_confidence: float = 0.5
|
||||
max_facts_per_source: int = 10
|
||||
|
||||
|
||||
@dataclass
|
||||
class ExtractedFact:
|
||||
"""A fact extracted from text before storage."""
|
||||
|
||||
subject: str
|
||||
predicate: str
|
||||
object: str
|
||||
confidence: float
|
||||
source_text: str = ""
|
||||
metadata: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
def to_fact_create(
|
||||
self,
|
||||
project_id: Any | None = None,
|
||||
source_episode_ids: list[Any] | None = None,
|
||||
) -> FactCreate:
|
||||
"""Convert to FactCreate for storage."""
|
||||
return FactCreate(
|
||||
subject=self.subject,
|
||||
predicate=self.predicate,
|
||||
object=self.object,
|
||||
confidence=self.confidence,
|
||||
project_id=project_id,
|
||||
source_episode_ids=source_episode_ids or [],
|
||||
)
|
||||
|
||||
|
||||
class FactExtractor:
|
||||
"""
|
||||
Extracts facts from episodes and text.
|
||||
|
||||
This is a rule-based extractor. In production, this would be
|
||||
replaced or augmented with LLM-based extraction for better accuracy.
|
||||
"""
|
||||
|
||||
# Common predicates we can detect
|
||||
PREDICATE_PATTERNS: ClassVar[dict[str, str]] = {
|
||||
"uses": r"(?:uses?|using|utilizes?)",
|
||||
"requires": r"(?:requires?|needs?|depends?\s+on)",
|
||||
"is_a": r"(?:is\s+a|is\s+an|are\s+a|are)",
|
||||
"has": r"(?:has|have|contains?)",
|
||||
"part_of": r"(?:part\s+of|belongs?\s+to|member\s+of)",
|
||||
"causes": r"(?:causes?|leads?\s+to|results?\s+in)",
|
||||
"prevents": r"(?:prevents?|avoids?|stops?)",
|
||||
"solves": r"(?:solves?|fixes?|resolves?)",
|
||||
}
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""Initialize extractor."""
|
||||
self._compiled_patterns = {
|
||||
pred: re.compile(pattern, re.IGNORECASE)
|
||||
for pred, pattern in self.PREDICATE_PATTERNS.items()
|
||||
}
|
||||
|
||||
def extract_from_episode(
|
||||
self,
|
||||
episode: Episode,
|
||||
context: ExtractionContext | None = None,
|
||||
) -> list[ExtractedFact]:
|
||||
"""
|
||||
Extract facts from an episode.
|
||||
|
||||
Args:
|
||||
episode: Episode to extract from
|
||||
context: Optional extraction context
|
||||
|
||||
Returns:
|
||||
List of extracted facts
|
||||
"""
|
||||
ctx = context or ExtractionContext()
|
||||
facts: list[ExtractedFact] = []
|
||||
|
||||
# Extract from task description
|
||||
task_facts = self._extract_from_text(
|
||||
episode.task_description,
|
||||
source_prefix=episode.task_type,
|
||||
)
|
||||
facts.extend(task_facts)
|
||||
|
||||
# Extract from lessons learned
|
||||
for lesson in episode.lessons_learned:
|
||||
lesson_facts = self._extract_from_lesson(lesson, episode)
|
||||
facts.extend(lesson_facts)
|
||||
|
||||
# Extract outcome-based facts
|
||||
outcome_facts = self._extract_outcome_facts(episode)
|
||||
facts.extend(outcome_facts)
|
||||
|
||||
# Limit and filter
|
||||
facts = [f for f in facts if f.confidence >= ctx.min_confidence]
|
||||
facts = facts[: ctx.max_facts_per_source]
|
||||
|
||||
logger.debug(f"Extracted {len(facts)} facts from episode {episode.id}")
|
||||
|
||||
return facts
|
||||
|
||||
def _extract_from_text(
|
||||
self,
|
||||
text: str,
|
||||
source_prefix: str = "",
|
||||
) -> list[ExtractedFact]:
|
||||
"""Extract facts from free-form text using pattern matching."""
|
||||
facts: list[ExtractedFact] = []
|
||||
|
||||
if not text or len(text) < 10:
|
||||
return facts
|
||||
|
||||
# Split into sentences
|
||||
sentences = re.split(r"[.!?]+", text)
|
||||
|
||||
for sentence in sentences:
|
||||
sentence = sentence.strip()
|
||||
if len(sentence) < 10:
|
||||
continue
|
||||
|
||||
# Try to match predicate patterns
|
||||
for predicate, pattern in self._compiled_patterns.items():
|
||||
match = pattern.search(sentence)
|
||||
if match:
|
||||
# Extract subject (text before predicate)
|
||||
subject = sentence[: match.start()].strip()
|
||||
# Extract object (text after predicate)
|
||||
obj = sentence[match.end() :].strip()
|
||||
|
||||
if len(subject) > 2 and len(obj) > 2:
|
||||
facts.append(
|
||||
ExtractedFact(
|
||||
subject=subject[:200], # Limit length
|
||||
predicate=predicate,
|
||||
object=obj[:500],
|
||||
confidence=0.6, # Medium confidence for pattern matching
|
||||
source_text=sentence,
|
||||
)
|
||||
)
|
||||
break # One fact per sentence
|
||||
|
||||
return facts
|
||||
|
||||
def _extract_from_lesson(
|
||||
self,
|
||||
lesson: str,
|
||||
episode: Episode,
|
||||
) -> list[ExtractedFact]:
|
||||
"""Extract facts from a lesson learned."""
|
||||
facts: list[ExtractedFact] = []
|
||||
|
||||
if not lesson or len(lesson) < 10:
|
||||
return facts
|
||||
|
||||
# Lessons are typically in the form "Always do X" or "Never do Y"
|
||||
# or "When X, do Y"
|
||||
|
||||
# Direct lesson fact
|
||||
facts.append(
|
||||
ExtractedFact(
|
||||
subject=episode.task_type,
|
||||
predicate="lesson_learned",
|
||||
object=lesson,
|
||||
confidence=0.8, # High confidence for explicit lessons
|
||||
source_text=lesson,
|
||||
metadata={"outcome": episode.outcome.value},
|
||||
)
|
||||
)
|
||||
|
||||
# Extract conditional patterns
|
||||
conditional_match = re.match(
|
||||
r"(?:when|if)\s+(.+?),\s*(.+)",
|
||||
lesson,
|
||||
re.IGNORECASE,
|
||||
)
|
||||
if conditional_match:
|
||||
condition, action = conditional_match.groups()
|
||||
facts.append(
|
||||
ExtractedFact(
|
||||
subject=condition.strip(),
|
||||
predicate="requires_action",
|
||||
object=action.strip(),
|
||||
confidence=0.7,
|
||||
source_text=lesson,
|
||||
)
|
||||
)
|
||||
|
||||
# Extract "always/never" patterns
|
||||
always_match = re.match(
|
||||
r"(?:always)\s+(.+)",
|
||||
lesson,
|
||||
re.IGNORECASE,
|
||||
)
|
||||
if always_match:
|
||||
facts.append(
|
||||
ExtractedFact(
|
||||
subject=episode.task_type,
|
||||
predicate="best_practice",
|
||||
object=always_match.group(1).strip(),
|
||||
confidence=0.85,
|
||||
source_text=lesson,
|
||||
)
|
||||
)
|
||||
|
||||
never_match = re.match(
|
||||
r"(?:never|avoid)\s+(.+)",
|
||||
lesson,
|
||||
re.IGNORECASE,
|
||||
)
|
||||
if never_match:
|
||||
facts.append(
|
||||
ExtractedFact(
|
||||
subject=episode.task_type,
|
||||
predicate="anti_pattern",
|
||||
object=never_match.group(1).strip(),
|
||||
confidence=0.85,
|
||||
source_text=lesson,
|
||||
)
|
||||
)
|
||||
|
||||
return facts
|
||||
|
||||
def _extract_outcome_facts(
|
||||
self,
|
||||
episode: Episode,
|
||||
) -> list[ExtractedFact]:
|
||||
"""Extract facts based on episode outcome."""
|
||||
facts: list[ExtractedFact] = []
|
||||
|
||||
# Create fact based on outcome
|
||||
if episode.outcome == Outcome.SUCCESS:
|
||||
if episode.outcome_details:
|
||||
facts.append(
|
||||
ExtractedFact(
|
||||
subject=episode.task_type,
|
||||
predicate="successful_approach",
|
||||
object=episode.outcome_details[:500],
|
||||
confidence=0.75,
|
||||
source_text=episode.outcome_details,
|
||||
)
|
||||
)
|
||||
elif episode.outcome == Outcome.FAILURE:
|
||||
if episode.outcome_details:
|
||||
facts.append(
|
||||
ExtractedFact(
|
||||
subject=episode.task_type,
|
||||
predicate="known_failure_mode",
|
||||
object=episode.outcome_details[:500],
|
||||
confidence=0.8, # High confidence for failures
|
||||
source_text=episode.outcome_details,
|
||||
)
|
||||
)
|
||||
|
||||
return facts
|
||||
|
||||
def extract_from_text(
|
||||
self,
|
||||
text: str,
|
||||
context: ExtractionContext | None = None,
|
||||
) -> list[ExtractedFact]:
|
||||
"""
|
||||
Extract facts from arbitrary text.
|
||||
|
||||
Args:
|
||||
text: Text to extract from
|
||||
context: Optional extraction context
|
||||
|
||||
Returns:
|
||||
List of extracted facts
|
||||
"""
|
||||
ctx = context or ExtractionContext()
|
||||
|
||||
facts = self._extract_from_text(text)
|
||||
|
||||
# Filter by confidence
|
||||
facts = [f for f in facts if f.confidence >= ctx.min_confidence]
|
||||
|
||||
return facts[: ctx.max_facts_per_source]
|
||||
|
||||
|
||||
# Singleton extractor instance
|
||||
_extractor: FactExtractor | None = None
|
||||
|
||||
|
||||
def get_fact_extractor() -> FactExtractor:
|
||||
"""Get the singleton fact extractor instance."""
|
||||
global _extractor
|
||||
if _extractor is None:
|
||||
_extractor = FactExtractor()
|
||||
return _extractor
|
||||
742
backend/app/services/memory/semantic/memory.py
Normal file
742
backend/app/services/memory/semantic/memory.py
Normal file
@@ -0,0 +1,742 @@
|
||||
# app/services/memory/semantic/memory.py
|
||||
"""
|
||||
Semantic Memory Implementation.
|
||||
|
||||
Provides fact storage and retrieval using subject-predicate-object triples.
|
||||
Supports semantic search, confidence scoring, and fact reinforcement.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import time
|
||||
from datetime import UTC, datetime
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import and_, desc, or_, select, update
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.models.memory.fact import Fact as FactModel
|
||||
from app.services.memory.config import get_memory_settings
|
||||
from app.services.memory.types import Episode, Fact, FactCreate, RetrievalResult
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _model_to_fact(model: FactModel) -> Fact:
|
||||
"""Convert SQLAlchemy model to Fact dataclass."""
|
||||
# SQLAlchemy Column types are inferred as Column[T] by mypy, but at runtime
|
||||
# they return actual values. We use type: ignore to handle this mismatch.
|
||||
return Fact(
|
||||
id=model.id, # type: ignore[arg-type]
|
||||
project_id=model.project_id, # type: ignore[arg-type]
|
||||
subject=model.subject, # type: ignore[arg-type]
|
||||
predicate=model.predicate, # type: ignore[arg-type]
|
||||
object=model.object, # type: ignore[arg-type]
|
||||
confidence=model.confidence, # type: ignore[arg-type]
|
||||
source_episode_ids=model.source_episode_ids or [], # type: ignore[arg-type]
|
||||
first_learned=model.first_learned, # type: ignore[arg-type]
|
||||
last_reinforced=model.last_reinforced, # type: ignore[arg-type]
|
||||
reinforcement_count=model.reinforcement_count, # type: ignore[arg-type]
|
||||
embedding=None, # Don't expose raw embedding
|
||||
created_at=model.created_at, # type: ignore[arg-type]
|
||||
updated_at=model.updated_at, # type: ignore[arg-type]
|
||||
)
|
||||
|
||||
|
||||
class SemanticMemory:
|
||||
"""
|
||||
Semantic Memory Service.
|
||||
|
||||
Provides fact storage and retrieval:
|
||||
- Store facts as subject-predicate-object triples
|
||||
- Semantic search over facts
|
||||
- Entity-based retrieval
|
||||
- Confidence scoring and decay
|
||||
- Fact reinforcement on repeated learning
|
||||
- Conflict resolution
|
||||
|
||||
Performance target: <100ms P95 for retrieval
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
session: AsyncSession,
|
||||
embedding_generator: Any | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize semantic memory.
|
||||
|
||||
Args:
|
||||
session: Database session
|
||||
embedding_generator: Optional embedding generator for semantic search
|
||||
"""
|
||||
self._session = session
|
||||
self._embedding_generator = embedding_generator
|
||||
self._settings = get_memory_settings()
|
||||
|
||||
@classmethod
|
||||
async def create(
|
||||
cls,
|
||||
session: AsyncSession,
|
||||
embedding_generator: Any | None = None,
|
||||
) -> "SemanticMemory":
|
||||
"""
|
||||
Factory method to create SemanticMemory.
|
||||
|
||||
Args:
|
||||
session: Database session
|
||||
embedding_generator: Optional embedding generator
|
||||
|
||||
Returns:
|
||||
Configured SemanticMemory instance
|
||||
"""
|
||||
return cls(session=session, embedding_generator=embedding_generator)
|
||||
|
||||
# =========================================================================
|
||||
# Fact Storage
|
||||
# =========================================================================
|
||||
|
||||
async def store_fact(self, fact: FactCreate) -> Fact:
|
||||
"""
|
||||
Store a new fact or reinforce an existing one.
|
||||
|
||||
If a fact with the same triple (subject, predicate, object) exists
|
||||
in the same scope, it will be reinforced instead of duplicated.
|
||||
|
||||
Args:
|
||||
fact: Fact data to store
|
||||
|
||||
Returns:
|
||||
The created or reinforced fact
|
||||
"""
|
||||
# Check for existing fact with same triple
|
||||
existing = await self._find_existing_fact(
|
||||
project_id=fact.project_id,
|
||||
subject=fact.subject,
|
||||
predicate=fact.predicate,
|
||||
object=fact.object,
|
||||
)
|
||||
|
||||
if existing is not None:
|
||||
# Reinforce existing fact
|
||||
return await self.reinforce_fact(
|
||||
existing.id, # type: ignore[arg-type]
|
||||
source_episode_ids=fact.source_episode_ids,
|
||||
)
|
||||
|
||||
# Create new fact
|
||||
now = datetime.now(UTC)
|
||||
|
||||
# Generate embedding if possible
|
||||
embedding = None
|
||||
if self._embedding_generator is not None:
|
||||
embedding_text = self._create_embedding_text(fact)
|
||||
embedding = await self._embedding_generator.generate(embedding_text)
|
||||
|
||||
model = FactModel(
|
||||
project_id=fact.project_id,
|
||||
subject=fact.subject,
|
||||
predicate=fact.predicate,
|
||||
object=fact.object,
|
||||
confidence=fact.confidence,
|
||||
source_episode_ids=fact.source_episode_ids,
|
||||
first_learned=now,
|
||||
last_reinforced=now,
|
||||
reinforcement_count=1,
|
||||
embedding=embedding,
|
||||
)
|
||||
|
||||
self._session.add(model)
|
||||
await self._session.flush()
|
||||
await self._session.refresh(model)
|
||||
|
||||
logger.info(
|
||||
f"Stored new fact: {fact.subject} - {fact.predicate} - {fact.object[:50]}..."
|
||||
)
|
||||
|
||||
return _model_to_fact(model)
|
||||
|
||||
async def _find_existing_fact(
|
||||
self,
|
||||
project_id: UUID | None,
|
||||
subject: str,
|
||||
predicate: str,
|
||||
object: str,
|
||||
) -> FactModel | None:
|
||||
"""Find an existing fact with the same triple in the same scope."""
|
||||
query = select(FactModel).where(
|
||||
and_(
|
||||
FactModel.subject == subject,
|
||||
FactModel.predicate == predicate,
|
||||
FactModel.object == object,
|
||||
)
|
||||
)
|
||||
|
||||
if project_id is not None:
|
||||
query = query.where(FactModel.project_id == project_id)
|
||||
else:
|
||||
query = query.where(FactModel.project_id.is_(None))
|
||||
|
||||
result = await self._session.execute(query)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
def _create_embedding_text(self, fact: FactCreate) -> str:
|
||||
"""Create text for embedding from fact data."""
|
||||
return f"{fact.subject} {fact.predicate} {fact.object}"
|
||||
|
||||
# =========================================================================
|
||||
# Fact Retrieval
|
||||
# =========================================================================
|
||||
|
||||
async def search_facts(
|
||||
self,
|
||||
query: str,
|
||||
project_id: UUID | None = None,
|
||||
limit: int = 10,
|
||||
min_confidence: float | None = None,
|
||||
) -> list[Fact]:
|
||||
"""
|
||||
Search for facts semantically similar to the query.
|
||||
|
||||
Args:
|
||||
query: Search query
|
||||
project_id: Optional project to search within
|
||||
limit: Maximum results
|
||||
min_confidence: Optional minimum confidence filter
|
||||
|
||||
Returns:
|
||||
List of matching facts
|
||||
"""
|
||||
result = await self._search_facts_with_metadata(
|
||||
query=query,
|
||||
project_id=project_id,
|
||||
limit=limit,
|
||||
min_confidence=min_confidence,
|
||||
)
|
||||
return result.items
|
||||
|
||||
async def _search_facts_with_metadata(
|
||||
self,
|
||||
query: str,
|
||||
project_id: UUID | None = None,
|
||||
limit: int = 10,
|
||||
min_confidence: float | None = None,
|
||||
) -> RetrievalResult[Fact]:
|
||||
"""Search facts with full result metadata."""
|
||||
start_time = time.perf_counter()
|
||||
|
||||
min_conf = min_confidence or self._settings.semantic_min_confidence
|
||||
|
||||
# Build base query
|
||||
stmt = (
|
||||
select(FactModel)
|
||||
.where(FactModel.confidence >= min_conf)
|
||||
.order_by(desc(FactModel.confidence), desc(FactModel.last_reinforced))
|
||||
.limit(limit)
|
||||
)
|
||||
|
||||
# Apply project filter
|
||||
if project_id is not None:
|
||||
# Include both project-specific and global facts
|
||||
stmt = stmt.where(
|
||||
or_(
|
||||
FactModel.project_id == project_id,
|
||||
FactModel.project_id.is_(None),
|
||||
)
|
||||
)
|
||||
|
||||
# TODO: Implement proper vector similarity search when pgvector is integrated
|
||||
# For now, do text-based search on subject/predicate/object
|
||||
search_terms = query.lower().split()
|
||||
if search_terms:
|
||||
conditions = []
|
||||
for term in search_terms[:5]: # Limit to 5 terms
|
||||
term_pattern = f"%{term}%"
|
||||
conditions.append(
|
||||
or_(
|
||||
FactModel.subject.ilike(term_pattern),
|
||||
FactModel.predicate.ilike(term_pattern),
|
||||
FactModel.object.ilike(term_pattern),
|
||||
)
|
||||
)
|
||||
if conditions:
|
||||
stmt = stmt.where(or_(*conditions))
|
||||
|
||||
result = await self._session.execute(stmt)
|
||||
models = list(result.scalars().all())
|
||||
|
||||
latency_ms = (time.perf_counter() - start_time) * 1000
|
||||
|
||||
return RetrievalResult(
|
||||
items=[_model_to_fact(m) for m in models],
|
||||
total_count=len(models),
|
||||
query=query,
|
||||
retrieval_type="semantic",
|
||||
latency_ms=latency_ms,
|
||||
metadata={"min_confidence": min_conf},
|
||||
)
|
||||
|
||||
async def get_by_entity(
|
||||
self,
|
||||
entity: str,
|
||||
project_id: UUID | None = None,
|
||||
limit: int = 20,
|
||||
) -> list[Fact]:
|
||||
"""
|
||||
Get facts related to an entity (as subject or object).
|
||||
|
||||
Args:
|
||||
entity: Entity to search for
|
||||
project_id: Optional project to search within
|
||||
limit: Maximum results
|
||||
|
||||
Returns:
|
||||
List of facts mentioning the entity
|
||||
"""
|
||||
start_time = time.perf_counter()
|
||||
|
||||
stmt = (
|
||||
select(FactModel)
|
||||
.where(
|
||||
or_(
|
||||
FactModel.subject.ilike(f"%{entity}%"),
|
||||
FactModel.object.ilike(f"%{entity}%"),
|
||||
)
|
||||
)
|
||||
.order_by(desc(FactModel.confidence), desc(FactModel.last_reinforced))
|
||||
.limit(limit)
|
||||
)
|
||||
|
||||
if project_id is not None:
|
||||
stmt = stmt.where(
|
||||
or_(
|
||||
FactModel.project_id == project_id,
|
||||
FactModel.project_id.is_(None),
|
||||
)
|
||||
)
|
||||
|
||||
result = await self._session.execute(stmt)
|
||||
models = list(result.scalars().all())
|
||||
|
||||
latency_ms = (time.perf_counter() - start_time) * 1000
|
||||
logger.debug(
|
||||
f"get_by_entity({entity}) returned {len(models)} facts in {latency_ms:.1f}ms"
|
||||
)
|
||||
|
||||
return [_model_to_fact(m) for m in models]
|
||||
|
||||
async def get_by_subject(
|
||||
self,
|
||||
subject: str,
|
||||
project_id: UUID | None = None,
|
||||
predicate: str | None = None,
|
||||
limit: int = 20,
|
||||
) -> list[Fact]:
|
||||
"""
|
||||
Get facts with a specific subject.
|
||||
|
||||
Args:
|
||||
subject: Subject to search for
|
||||
project_id: Optional project to search within
|
||||
predicate: Optional predicate filter
|
||||
limit: Maximum results
|
||||
|
||||
Returns:
|
||||
List of facts with matching subject
|
||||
"""
|
||||
stmt = (
|
||||
select(FactModel)
|
||||
.where(FactModel.subject == subject)
|
||||
.order_by(desc(FactModel.confidence))
|
||||
.limit(limit)
|
||||
)
|
||||
|
||||
if predicate is not None:
|
||||
stmt = stmt.where(FactModel.predicate == predicate)
|
||||
|
||||
if project_id is not None:
|
||||
stmt = stmt.where(
|
||||
or_(
|
||||
FactModel.project_id == project_id,
|
||||
FactModel.project_id.is_(None),
|
||||
)
|
||||
)
|
||||
|
||||
result = await self._session.execute(stmt)
|
||||
models = list(result.scalars().all())
|
||||
|
||||
return [_model_to_fact(m) for m in models]
|
||||
|
||||
async def get_by_id(self, fact_id: UUID) -> Fact | None:
|
||||
"""Get a fact by ID."""
|
||||
query = select(FactModel).where(FactModel.id == fact_id)
|
||||
result = await self._session.execute(query)
|
||||
model = result.scalar_one_or_none()
|
||||
return _model_to_fact(model) if model else None
|
||||
|
||||
# =========================================================================
|
||||
# Fact Reinforcement
|
||||
# =========================================================================
|
||||
|
||||
async def reinforce_fact(
|
||||
self,
|
||||
fact_id: UUID,
|
||||
confidence_boost: float = 0.1,
|
||||
source_episode_ids: list[UUID] | None = None,
|
||||
) -> Fact:
|
||||
"""
|
||||
Reinforce a fact, increasing its confidence.
|
||||
|
||||
Args:
|
||||
fact_id: Fact to reinforce
|
||||
confidence_boost: Amount to increase confidence (default 0.1)
|
||||
source_episode_ids: Additional source episodes
|
||||
|
||||
Returns:
|
||||
Updated fact
|
||||
|
||||
Raises:
|
||||
ValueError: If fact not found
|
||||
"""
|
||||
query = select(FactModel).where(FactModel.id == fact_id)
|
||||
result = await self._session.execute(query)
|
||||
model = result.scalar_one_or_none()
|
||||
|
||||
if model is None:
|
||||
raise ValueError(f"Fact not found: {fact_id}")
|
||||
|
||||
# Calculate new confidence (max 1.0)
|
||||
current_confidence: float = model.confidence # type: ignore[assignment]
|
||||
new_confidence = min(1.0, current_confidence + confidence_boost)
|
||||
|
||||
# Merge source episode IDs
|
||||
current_sources: list[UUID] = model.source_episode_ids or [] # type: ignore[assignment]
|
||||
if source_episode_ids:
|
||||
# Add new sources, avoiding duplicates
|
||||
new_sources = list(set(current_sources + source_episode_ids))
|
||||
else:
|
||||
new_sources = current_sources
|
||||
|
||||
now = datetime.now(UTC)
|
||||
stmt = (
|
||||
update(FactModel)
|
||||
.where(FactModel.id == fact_id)
|
||||
.values(
|
||||
confidence=new_confidence,
|
||||
source_episode_ids=new_sources,
|
||||
last_reinforced=now,
|
||||
reinforcement_count=FactModel.reinforcement_count + 1,
|
||||
updated_at=now,
|
||||
)
|
||||
.returning(FactModel)
|
||||
)
|
||||
|
||||
result = await self._session.execute(stmt)
|
||||
updated_model = result.scalar_one()
|
||||
await self._session.flush()
|
||||
|
||||
logger.info(
|
||||
f"Reinforced fact {fact_id}: confidence {current_confidence:.2f} -> {new_confidence:.2f}"
|
||||
)
|
||||
|
||||
return _model_to_fact(updated_model)
|
||||
|
||||
async def deprecate_fact(
|
||||
self,
|
||||
fact_id: UUID,
|
||||
reason: str,
|
||||
new_confidence: float = 0.0,
|
||||
) -> Fact | None:
|
||||
"""
|
||||
Deprecate a fact by lowering its confidence.
|
||||
|
||||
Args:
|
||||
fact_id: Fact to deprecate
|
||||
reason: Reason for deprecation
|
||||
new_confidence: New confidence level (default 0.0)
|
||||
|
||||
Returns:
|
||||
Updated fact or None if not found
|
||||
"""
|
||||
query = select(FactModel).where(FactModel.id == fact_id)
|
||||
result = await self._session.execute(query)
|
||||
model = result.scalar_one_or_none()
|
||||
|
||||
if model is None:
|
||||
return None
|
||||
|
||||
now = datetime.now(UTC)
|
||||
stmt = (
|
||||
update(FactModel)
|
||||
.where(FactModel.id == fact_id)
|
||||
.values(
|
||||
confidence=max(0.0, new_confidence),
|
||||
updated_at=now,
|
||||
)
|
||||
.returning(FactModel)
|
||||
)
|
||||
|
||||
result = await self._session.execute(stmt)
|
||||
updated_model = result.scalar_one_or_none()
|
||||
await self._session.flush()
|
||||
|
||||
logger.info(f"Deprecated fact {fact_id}: {reason}")
|
||||
|
||||
return _model_to_fact(updated_model) if updated_model else None
|
||||
|
||||
# =========================================================================
|
||||
# Fact Extraction from Episodes
|
||||
# =========================================================================
|
||||
|
||||
async def extract_facts_from_episode(
|
||||
self,
|
||||
episode: Episode,
|
||||
) -> list[Fact]:
|
||||
"""
|
||||
Extract facts from an episode.
|
||||
|
||||
This is a placeholder for LLM-based fact extraction.
|
||||
In production, this would call an LLM to analyze the episode
|
||||
and extract subject-predicate-object triples.
|
||||
|
||||
Args:
|
||||
episode: Episode to extract facts from
|
||||
|
||||
Returns:
|
||||
List of extracted facts
|
||||
"""
|
||||
# For now, extract basic facts from lessons learned
|
||||
extracted_facts: list[Fact] = []
|
||||
|
||||
for lesson in episode.lessons_learned:
|
||||
if len(lesson) > 10: # Skip very short lessons
|
||||
fact_create = FactCreate(
|
||||
subject=episode.task_type,
|
||||
predicate="lesson_learned",
|
||||
object=lesson,
|
||||
confidence=0.7, # Lessons start with moderate confidence
|
||||
project_id=episode.project_id,
|
||||
source_episode_ids=[episode.id],
|
||||
)
|
||||
fact = await self.store_fact(fact_create)
|
||||
extracted_facts.append(fact)
|
||||
|
||||
logger.debug(
|
||||
f"Extracted {len(extracted_facts)} facts from episode {episode.id}"
|
||||
)
|
||||
|
||||
return extracted_facts
|
||||
|
||||
# =========================================================================
|
||||
# Conflict Resolution
|
||||
# =========================================================================
|
||||
|
||||
async def resolve_conflict(
|
||||
self,
|
||||
fact_ids: list[UUID],
|
||||
keep_fact_id: UUID | None = None,
|
||||
) -> Fact | None:
|
||||
"""
|
||||
Resolve a conflict between multiple facts.
|
||||
|
||||
If keep_fact_id is specified, that fact is kept and others are deprecated.
|
||||
Otherwise, the fact with highest confidence is kept.
|
||||
|
||||
Args:
|
||||
fact_ids: IDs of conflicting facts
|
||||
keep_fact_id: Optional ID of fact to keep
|
||||
|
||||
Returns:
|
||||
The winning fact, or None if no facts found
|
||||
"""
|
||||
if not fact_ids:
|
||||
return None
|
||||
|
||||
# Load all facts
|
||||
query = select(FactModel).where(FactModel.id.in_(fact_ids))
|
||||
result = await self._session.execute(query)
|
||||
models = list(result.scalars().all())
|
||||
|
||||
if not models:
|
||||
return None
|
||||
|
||||
# Determine winner
|
||||
if keep_fact_id is not None:
|
||||
winner = next((m for m in models if m.id == keep_fact_id), None)
|
||||
if winner is None:
|
||||
# Fallback to highest confidence
|
||||
winner = max(models, key=lambda m: m.confidence)
|
||||
else:
|
||||
# Keep the fact with highest confidence
|
||||
winner = max(models, key=lambda m: m.confidence)
|
||||
|
||||
# Deprecate losers
|
||||
for model in models:
|
||||
if model.id != winner.id:
|
||||
await self.deprecate_fact(
|
||||
model.id, # type: ignore[arg-type]
|
||||
reason=f"Conflict resolution: superseded by {winner.id}",
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Resolved conflict between {len(fact_ids)} facts, keeping {winner.id}"
|
||||
)
|
||||
|
||||
return _model_to_fact(winner)
|
||||
|
||||
# =========================================================================
|
||||
# Confidence Decay
|
||||
# =========================================================================
|
||||
|
||||
async def apply_confidence_decay(
|
||||
self,
|
||||
project_id: UUID | None = None,
|
||||
decay_factor: float = 0.01,
|
||||
) -> int:
|
||||
"""
|
||||
Apply confidence decay to facts that haven't been reinforced recently.
|
||||
|
||||
Args:
|
||||
project_id: Optional project to apply decay to
|
||||
decay_factor: Decay factor per day (default 0.01)
|
||||
|
||||
Returns:
|
||||
Number of facts affected
|
||||
"""
|
||||
now = datetime.now(UTC)
|
||||
decay_days = self._settings.semantic_confidence_decay_days
|
||||
min_conf = self._settings.semantic_min_confidence
|
||||
|
||||
# Calculate cutoff date
|
||||
from datetime import timedelta
|
||||
|
||||
cutoff = now - timedelta(days=decay_days)
|
||||
|
||||
# Find facts needing decay
|
||||
query = select(FactModel).where(
|
||||
and_(
|
||||
FactModel.last_reinforced < cutoff,
|
||||
FactModel.confidence > min_conf,
|
||||
)
|
||||
)
|
||||
|
||||
if project_id is not None:
|
||||
query = query.where(FactModel.project_id == project_id)
|
||||
|
||||
result = await self._session.execute(query)
|
||||
models = list(result.scalars().all())
|
||||
|
||||
# Apply decay
|
||||
updated_count = 0
|
||||
for model in models:
|
||||
# Calculate days since last reinforcement
|
||||
days_since: float = (now - model.last_reinforced).days
|
||||
|
||||
# Calculate decay: exponential decay based on days
|
||||
decay = decay_factor * (days_since - decay_days)
|
||||
new_confidence = max(min_conf, model.confidence - decay)
|
||||
|
||||
if new_confidence != model.confidence:
|
||||
await self._session.execute(
|
||||
update(FactModel)
|
||||
.where(FactModel.id == model.id)
|
||||
.values(confidence=new_confidence, updated_at=now)
|
||||
)
|
||||
updated_count += 1
|
||||
|
||||
await self._session.flush()
|
||||
logger.info(f"Applied confidence decay to {updated_count} facts")
|
||||
|
||||
return updated_count
|
||||
|
||||
# =========================================================================
|
||||
# Statistics
|
||||
# =========================================================================
|
||||
|
||||
async def get_stats(self, project_id: UUID | None = None) -> dict[str, Any]:
|
||||
"""
|
||||
Get statistics about semantic memory.
|
||||
|
||||
Args:
|
||||
project_id: Optional project to get stats for
|
||||
|
||||
Returns:
|
||||
Dictionary with statistics
|
||||
"""
|
||||
# Get all facts for this scope
|
||||
query = select(FactModel)
|
||||
if project_id is not None:
|
||||
query = query.where(
|
||||
or_(
|
||||
FactModel.project_id == project_id,
|
||||
FactModel.project_id.is_(None),
|
||||
)
|
||||
)
|
||||
|
||||
result = await self._session.execute(query)
|
||||
models = list(result.scalars().all())
|
||||
|
||||
if not models:
|
||||
return {
|
||||
"total_facts": 0,
|
||||
"avg_confidence": 0.0,
|
||||
"avg_reinforcement_count": 0.0,
|
||||
"high_confidence_count": 0,
|
||||
"low_confidence_count": 0,
|
||||
}
|
||||
|
||||
confidences = [m.confidence for m in models]
|
||||
reinforcements = [m.reinforcement_count for m in models]
|
||||
|
||||
return {
|
||||
"total_facts": len(models),
|
||||
"avg_confidence": sum(confidences) / len(confidences),
|
||||
"avg_reinforcement_count": sum(reinforcements) / len(reinforcements),
|
||||
"high_confidence_count": sum(1 for c in confidences if c >= 0.8),
|
||||
"low_confidence_count": sum(1 for c in confidences if c < 0.5),
|
||||
}
|
||||
|
||||
async def count(self, project_id: UUID | None = None) -> int:
|
||||
"""
|
||||
Count facts in scope.
|
||||
|
||||
Args:
|
||||
project_id: Optional project to count for
|
||||
|
||||
Returns:
|
||||
Number of facts
|
||||
"""
|
||||
query = select(FactModel)
|
||||
if project_id is not None:
|
||||
query = query.where(
|
||||
or_(
|
||||
FactModel.project_id == project_id,
|
||||
FactModel.project_id.is_(None),
|
||||
)
|
||||
)
|
||||
|
||||
result = await self._session.execute(query)
|
||||
return len(list(result.scalars().all()))
|
||||
|
||||
async def delete(self, fact_id: UUID) -> bool:
|
||||
"""
|
||||
Delete a fact.
|
||||
|
||||
Args:
|
||||
fact_id: Fact to delete
|
||||
|
||||
Returns:
|
||||
True if deleted, False if not found
|
||||
"""
|
||||
query = select(FactModel).where(FactModel.id == fact_id)
|
||||
result = await self._session.execute(query)
|
||||
model = result.scalar_one_or_none()
|
||||
|
||||
if model is None:
|
||||
return False
|
||||
|
||||
await self._session.delete(model)
|
||||
await self._session.flush()
|
||||
|
||||
logger.info(f"Deleted fact {fact_id}")
|
||||
return True
|
||||
363
backend/app/services/memory/semantic/verification.py
Normal file
363
backend/app/services/memory/semantic/verification.py
Normal file
@@ -0,0 +1,363 @@
|
||||
# app/services/memory/semantic/verification.py
|
||||
"""
|
||||
Fact Verification.
|
||||
|
||||
Provides utilities for verifying facts, detecting conflicts,
|
||||
and managing fact consistency.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, ClassVar
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import and_, or_, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.models.memory.fact import Fact as FactModel
|
||||
from app.services.memory.types import Fact
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class VerificationResult:
|
||||
"""Result of fact verification."""
|
||||
|
||||
is_valid: bool
|
||||
confidence_adjustment: float = 0.0
|
||||
conflicts: list["FactConflict"] = field(default_factory=list)
|
||||
supporting_facts: list[Fact] = field(default_factory=list)
|
||||
messages: list[str] = field(default_factory=list)
|
||||
|
||||
|
||||
@dataclass
|
||||
class FactConflict:
|
||||
"""Represents a conflict between two facts."""
|
||||
|
||||
fact_a_id: UUID
|
||||
fact_b_id: UUID
|
||||
conflict_type: str # "contradiction", "superseded", "partial_overlap"
|
||||
description: str
|
||||
suggested_resolution: str | None = None
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
"""Convert to dictionary."""
|
||||
return {
|
||||
"fact_a_id": str(self.fact_a_id),
|
||||
"fact_b_id": str(self.fact_b_id),
|
||||
"conflict_type": self.conflict_type,
|
||||
"description": self.description,
|
||||
"suggested_resolution": self.suggested_resolution,
|
||||
}
|
||||
|
||||
|
||||
class FactVerifier:
|
||||
"""
|
||||
Verifies facts and detects conflicts.
|
||||
|
||||
Provides methods to:
|
||||
- Check if a fact conflicts with existing facts
|
||||
- Find supporting evidence for a fact
|
||||
- Detect contradictions in the fact base
|
||||
"""
|
||||
|
||||
# Predicates that are opposites/contradictions
|
||||
CONTRADICTORY_PREDICATES: ClassVar[set[tuple[str, str]]] = {
|
||||
("uses", "does_not_use"),
|
||||
("requires", "does_not_require"),
|
||||
("is_a", "is_not_a"),
|
||||
("causes", "prevents"),
|
||||
("allows", "prevents"),
|
||||
("supports", "does_not_support"),
|
||||
("best_practice", "anti_pattern"),
|
||||
}
|
||||
|
||||
def __init__(self, session: AsyncSession) -> None:
|
||||
"""Initialize verifier with database session."""
|
||||
self._session = session
|
||||
|
||||
async def verify_fact(
|
||||
self,
|
||||
subject: str,
|
||||
predicate: str,
|
||||
obj: str,
|
||||
project_id: UUID | None = None,
|
||||
) -> VerificationResult:
|
||||
"""
|
||||
Verify a fact against existing facts.
|
||||
|
||||
Args:
|
||||
subject: Fact subject
|
||||
predicate: Fact predicate
|
||||
obj: Fact object
|
||||
project_id: Optional project scope
|
||||
|
||||
Returns:
|
||||
VerificationResult with verification details
|
||||
"""
|
||||
result = VerificationResult(is_valid=True)
|
||||
|
||||
# Check for direct contradictions
|
||||
conflicts = await self._find_contradictions(
|
||||
subject=subject,
|
||||
predicate=predicate,
|
||||
obj=obj,
|
||||
project_id=project_id,
|
||||
)
|
||||
result.conflicts = conflicts
|
||||
|
||||
if conflicts:
|
||||
result.is_valid = False
|
||||
result.messages.append(f"Found {len(conflicts)} conflicting fact(s)")
|
||||
# Reduce confidence based on conflicts
|
||||
result.confidence_adjustment = -0.1 * len(conflicts)
|
||||
|
||||
# Find supporting facts
|
||||
supporting = await self._find_supporting_facts(
|
||||
subject=subject,
|
||||
predicate=predicate,
|
||||
project_id=project_id,
|
||||
)
|
||||
result.supporting_facts = supporting
|
||||
|
||||
if supporting:
|
||||
result.messages.append(f"Found {len(supporting)} supporting fact(s)")
|
||||
# Boost confidence based on support
|
||||
result.confidence_adjustment += 0.05 * min(len(supporting), 3)
|
||||
|
||||
return result
|
||||
|
||||
async def _find_contradictions(
|
||||
self,
|
||||
subject: str,
|
||||
predicate: str,
|
||||
obj: str,
|
||||
project_id: UUID | None = None,
|
||||
) -> list[FactConflict]:
|
||||
"""Find facts that contradict the given fact."""
|
||||
conflicts: list[FactConflict] = []
|
||||
|
||||
# Find opposite predicates
|
||||
opposite_predicates = self._get_opposite_predicates(predicate)
|
||||
|
||||
if not opposite_predicates:
|
||||
return conflicts
|
||||
|
||||
# Search for contradicting facts
|
||||
query = select(FactModel).where(
|
||||
and_(
|
||||
FactModel.subject == subject,
|
||||
FactModel.predicate.in_(opposite_predicates),
|
||||
)
|
||||
)
|
||||
|
||||
if project_id is not None:
|
||||
query = query.where(
|
||||
or_(
|
||||
FactModel.project_id == project_id,
|
||||
FactModel.project_id.is_(None),
|
||||
)
|
||||
)
|
||||
|
||||
result = await self._session.execute(query)
|
||||
models = list(result.scalars().all())
|
||||
|
||||
for model in models:
|
||||
conflicts.append(
|
||||
FactConflict(
|
||||
fact_a_id=model.id, # type: ignore[arg-type]
|
||||
fact_b_id=UUID(
|
||||
"00000000-0000-0000-0000-000000000000"
|
||||
), # Placeholder for new fact
|
||||
conflict_type="contradiction",
|
||||
description=(
|
||||
f"'{subject} {predicate} {obj}' contradicts "
|
||||
f"'{model.subject} {model.predicate} {model.object}'"
|
||||
),
|
||||
suggested_resolution="Keep fact with higher confidence",
|
||||
)
|
||||
)
|
||||
|
||||
return conflicts
|
||||
|
||||
def _get_opposite_predicates(self, predicate: str) -> list[str]:
|
||||
"""Get predicates that are opposite to the given predicate."""
|
||||
opposites: list[str] = []
|
||||
|
||||
for pair in self.CONTRADICTORY_PREDICATES:
|
||||
if predicate in pair:
|
||||
opposites.extend(p for p in pair if p != predicate)
|
||||
|
||||
return opposites
|
||||
|
||||
async def _find_supporting_facts(
|
||||
self,
|
||||
subject: str,
|
||||
predicate: str,
|
||||
project_id: UUID | None = None,
|
||||
) -> list[Fact]:
|
||||
"""Find facts that support the given fact."""
|
||||
# Find facts with same subject and predicate
|
||||
query = (
|
||||
select(FactModel)
|
||||
.where(
|
||||
and_(
|
||||
FactModel.subject == subject,
|
||||
FactModel.predicate == predicate,
|
||||
FactModel.confidence >= 0.5,
|
||||
)
|
||||
)
|
||||
.limit(10)
|
||||
)
|
||||
|
||||
if project_id is not None:
|
||||
query = query.where(
|
||||
or_(
|
||||
FactModel.project_id == project_id,
|
||||
FactModel.project_id.is_(None),
|
||||
)
|
||||
)
|
||||
|
||||
result = await self._session.execute(query)
|
||||
models = list(result.scalars().all())
|
||||
|
||||
return [self._model_to_fact(m) for m in models]
|
||||
|
||||
async def find_all_conflicts(
|
||||
self,
|
||||
project_id: UUID | None = None,
|
||||
) -> list[FactConflict]:
|
||||
"""
|
||||
Find all conflicts in the fact base.
|
||||
|
||||
Args:
|
||||
project_id: Optional project scope
|
||||
|
||||
Returns:
|
||||
List of all detected conflicts
|
||||
"""
|
||||
conflicts: list[FactConflict] = []
|
||||
|
||||
# Get all facts
|
||||
query = select(FactModel)
|
||||
if project_id is not None:
|
||||
query = query.where(
|
||||
or_(
|
||||
FactModel.project_id == project_id,
|
||||
FactModel.project_id.is_(None),
|
||||
)
|
||||
)
|
||||
|
||||
result = await self._session.execute(query)
|
||||
models = list(result.scalars().all())
|
||||
|
||||
# Check each pair for conflicts
|
||||
for i, fact_a in enumerate(models):
|
||||
for fact_b in models[i + 1 :]:
|
||||
conflict = self._check_pair_conflict(fact_a, fact_b)
|
||||
if conflict:
|
||||
conflicts.append(conflict)
|
||||
|
||||
logger.info(f"Found {len(conflicts)} conflicts in fact base")
|
||||
|
||||
return conflicts
|
||||
|
||||
def _check_pair_conflict(
|
||||
self,
|
||||
fact_a: FactModel,
|
||||
fact_b: FactModel,
|
||||
) -> FactConflict | None:
|
||||
"""Check if two facts conflict."""
|
||||
# Same subject?
|
||||
if fact_a.subject != fact_b.subject:
|
||||
return None
|
||||
|
||||
# Contradictory predicates?
|
||||
opposite = self._get_opposite_predicates(fact_a.predicate) # type: ignore[arg-type]
|
||||
if fact_b.predicate not in opposite:
|
||||
return None
|
||||
|
||||
return FactConflict(
|
||||
fact_a_id=fact_a.id, # type: ignore[arg-type]
|
||||
fact_b_id=fact_b.id, # type: ignore[arg-type]
|
||||
conflict_type="contradiction",
|
||||
description=(
|
||||
f"'{fact_a.subject} {fact_a.predicate} {fact_a.object}' "
|
||||
f"contradicts '{fact_b.subject} {fact_b.predicate} {fact_b.object}'"
|
||||
),
|
||||
suggested_resolution="Deprecate fact with lower confidence",
|
||||
)
|
||||
|
||||
async def get_fact_reliability_score(
|
||||
self,
|
||||
fact_id: UUID,
|
||||
) -> float:
|
||||
"""
|
||||
Calculate a reliability score for a fact.
|
||||
|
||||
Based on:
|
||||
- Confidence score
|
||||
- Number of reinforcements
|
||||
- Number of supporting facts
|
||||
- Absence of conflicts
|
||||
|
||||
Args:
|
||||
fact_id: Fact to score
|
||||
|
||||
Returns:
|
||||
Reliability score (0.0 to 1.0)
|
||||
"""
|
||||
query = select(FactModel).where(FactModel.id == fact_id)
|
||||
result = await self._session.execute(query)
|
||||
model = result.scalar_one_or_none()
|
||||
|
||||
if model is None:
|
||||
return 0.0
|
||||
|
||||
# Base score from confidence - explicitly typed to avoid Column type issues
|
||||
score: float = float(model.confidence)
|
||||
|
||||
# Boost for reinforcements (diminishing returns)
|
||||
reinforcement_boost = min(0.2, float(model.reinforcement_count) * 0.02)
|
||||
score += reinforcement_boost
|
||||
|
||||
# Find supporting facts
|
||||
supporting = await self._find_supporting_facts(
|
||||
subject=model.subject, # type: ignore[arg-type]
|
||||
predicate=model.predicate, # type: ignore[arg-type]
|
||||
project_id=model.project_id, # type: ignore[arg-type]
|
||||
)
|
||||
support_boost = min(0.1, len(supporting) * 0.02)
|
||||
score += support_boost
|
||||
|
||||
# Check for conflicts
|
||||
conflicts = await self._find_contradictions(
|
||||
subject=model.subject, # type: ignore[arg-type]
|
||||
predicate=model.predicate, # type: ignore[arg-type]
|
||||
obj=model.object, # type: ignore[arg-type]
|
||||
project_id=model.project_id, # type: ignore[arg-type]
|
||||
)
|
||||
conflict_penalty = min(0.3, len(conflicts) * 0.1)
|
||||
score -= conflict_penalty
|
||||
|
||||
# Clamp to valid range
|
||||
return max(0.0, min(1.0, score))
|
||||
|
||||
def _model_to_fact(self, model: FactModel) -> Fact:
|
||||
"""Convert SQLAlchemy model to Fact dataclass."""
|
||||
return Fact(
|
||||
id=model.id, # type: ignore[arg-type]
|
||||
project_id=model.project_id, # type: ignore[arg-type]
|
||||
subject=model.subject, # type: ignore[arg-type]
|
||||
predicate=model.predicate, # type: ignore[arg-type]
|
||||
object=model.object, # type: ignore[arg-type]
|
||||
confidence=model.confidence, # type: ignore[arg-type]
|
||||
source_episode_ids=model.source_episode_ids or [], # type: ignore[arg-type]
|
||||
first_learned=model.first_learned, # type: ignore[arg-type]
|
||||
last_reinforced=model.last_reinforced, # type: ignore[arg-type]
|
||||
reinforcement_count=model.reinforcement_count, # type: ignore[arg-type]
|
||||
embedding=None,
|
||||
created_at=model.created_at, # type: ignore[arg-type]
|
||||
updated_at=model.updated_at, # type: ignore[arg-type]
|
||||
)
|
||||
2
backend/tests/unit/services/memory/semantic/__init__.py
Normal file
2
backend/tests/unit/services/memory/semantic/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
||||
# tests/unit/services/memory/semantic/__init__.py
|
||||
"""Unit tests for semantic memory service."""
|
||||
263
backend/tests/unit/services/memory/semantic/test_extraction.py
Normal file
263
backend/tests/unit/services/memory/semantic/test_extraction.py
Normal file
@@ -0,0 +1,263 @@
|
||||
# tests/unit/services/memory/semantic/test_extraction.py
|
||||
"""Unit tests for fact extraction."""
|
||||
|
||||
from datetime import UTC, datetime
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
from app.services.memory.semantic.extraction import (
|
||||
ExtractedFact,
|
||||
ExtractionContext,
|
||||
FactExtractor,
|
||||
get_fact_extractor,
|
||||
)
|
||||
from app.services.memory.types import Episode, Outcome
|
||||
|
||||
|
||||
def create_test_episode(
|
||||
lessons_learned: list[str] | None = None,
|
||||
outcome: Outcome = Outcome.SUCCESS,
|
||||
task_type: str = "code_review",
|
||||
task_description: str = "Review the authentication module",
|
||||
outcome_details: str = "",
|
||||
) -> Episode:
|
||||
"""Create a test episode for extraction tests."""
|
||||
return Episode(
|
||||
id=uuid4(),
|
||||
project_id=uuid4(),
|
||||
agent_instance_id=None,
|
||||
agent_type_id=None,
|
||||
session_id="test-session",
|
||||
task_type=task_type,
|
||||
task_description=task_description,
|
||||
actions=[],
|
||||
context_summary="Test context",
|
||||
outcome=outcome,
|
||||
outcome_details=outcome_details,
|
||||
duration_seconds=60.0,
|
||||
tokens_used=500,
|
||||
lessons_learned=lessons_learned or [],
|
||||
importance_score=0.7,
|
||||
embedding=None,
|
||||
occurred_at=datetime.now(UTC),
|
||||
created_at=datetime.now(UTC),
|
||||
updated_at=datetime.now(UTC),
|
||||
)
|
||||
|
||||
|
||||
class TestExtractedFact:
|
||||
"""Tests for ExtractedFact dataclass."""
|
||||
|
||||
def test_to_fact_create(self) -> None:
|
||||
"""Test converting ExtractedFact to FactCreate."""
|
||||
extracted = ExtractedFact(
|
||||
subject="Python",
|
||||
predicate="uses",
|
||||
object="dynamic typing",
|
||||
confidence=0.8,
|
||||
)
|
||||
|
||||
fact_create = extracted.to_fact_create(
|
||||
project_id=uuid4(),
|
||||
source_episode_ids=[uuid4()],
|
||||
)
|
||||
|
||||
assert fact_create.subject == "Python"
|
||||
assert fact_create.predicate == "uses"
|
||||
assert fact_create.object == "dynamic typing"
|
||||
assert fact_create.confidence == 0.8
|
||||
|
||||
def test_to_fact_create_defaults(self) -> None:
|
||||
"""Test to_fact_create with default values."""
|
||||
extracted = ExtractedFact(
|
||||
subject="A",
|
||||
predicate="B",
|
||||
object="C",
|
||||
confidence=0.5,
|
||||
)
|
||||
|
||||
fact_create = extracted.to_fact_create()
|
||||
|
||||
assert fact_create.project_id is None
|
||||
assert fact_create.source_episode_ids == []
|
||||
|
||||
|
||||
class TestFactExtractor:
|
||||
"""Tests for FactExtractor class."""
|
||||
|
||||
@pytest.fixture
|
||||
def extractor(self) -> FactExtractor:
|
||||
"""Create a fact extractor."""
|
||||
return FactExtractor()
|
||||
|
||||
def test_extract_from_episode_with_lessons(
|
||||
self,
|
||||
extractor: FactExtractor,
|
||||
) -> None:
|
||||
"""Test extracting facts from episode with lessons."""
|
||||
episode = create_test_episode(
|
||||
lessons_learned=[
|
||||
"Always validate user input before processing",
|
||||
"Use parameterized queries to prevent SQL injection",
|
||||
]
|
||||
)
|
||||
|
||||
facts = extractor.extract_from_episode(episode)
|
||||
|
||||
assert len(facts) > 0
|
||||
# Should have lesson_learned predicates
|
||||
lesson_facts = [f for f in facts if f.predicate == "lesson_learned"]
|
||||
assert len(lesson_facts) >= 2
|
||||
|
||||
def test_extract_from_episode_with_always_pattern(
|
||||
self,
|
||||
extractor: FactExtractor,
|
||||
) -> None:
|
||||
"""Test extracting 'always' pattern lessons."""
|
||||
episode = create_test_episode(
|
||||
lessons_learned=["Always close file handles properly"]
|
||||
)
|
||||
|
||||
facts = extractor.extract_from_episode(episode)
|
||||
|
||||
best_practices = [f for f in facts if f.predicate == "best_practice"]
|
||||
assert len(best_practices) >= 1
|
||||
assert any("close file handles" in f.object for f in best_practices)
|
||||
|
||||
def test_extract_from_episode_with_never_pattern(
|
||||
self,
|
||||
extractor: FactExtractor,
|
||||
) -> None:
|
||||
"""Test extracting 'never' pattern lessons."""
|
||||
episode = create_test_episode(
|
||||
lessons_learned=["Never store passwords in plain text"]
|
||||
)
|
||||
|
||||
facts = extractor.extract_from_episode(episode)
|
||||
|
||||
anti_patterns = [f for f in facts if f.predicate == "anti_pattern"]
|
||||
assert len(anti_patterns) >= 1
|
||||
|
||||
def test_extract_from_episode_with_conditional_pattern(
|
||||
self,
|
||||
extractor: FactExtractor,
|
||||
) -> None:
|
||||
"""Test extracting conditional lessons."""
|
||||
episode = create_test_episode(
|
||||
lessons_learned=["When handling errors, log the stack trace"]
|
||||
)
|
||||
|
||||
facts = extractor.extract_from_episode(episode)
|
||||
|
||||
conditional = [f for f in facts if f.predicate == "requires_action"]
|
||||
assert len(conditional) >= 1
|
||||
|
||||
def test_extract_outcome_facts_success(
|
||||
self,
|
||||
extractor: FactExtractor,
|
||||
) -> None:
|
||||
"""Test extracting facts from successful episode."""
|
||||
episode = create_test_episode(
|
||||
outcome=Outcome.SUCCESS,
|
||||
outcome_details="Deployed to production without issues",
|
||||
)
|
||||
|
||||
facts = extractor.extract_from_episode(episode)
|
||||
|
||||
success_facts = [f for f in facts if f.predicate == "successful_approach"]
|
||||
assert len(success_facts) >= 1
|
||||
|
||||
def test_extract_outcome_facts_failure(
|
||||
self,
|
||||
extractor: FactExtractor,
|
||||
) -> None:
|
||||
"""Test extracting facts from failed episode."""
|
||||
episode = create_test_episode(
|
||||
outcome=Outcome.FAILURE,
|
||||
outcome_details="Connection timeout during deployment",
|
||||
)
|
||||
|
||||
facts = extractor.extract_from_episode(episode)
|
||||
|
||||
failure_facts = [f for f in facts if f.predicate == "known_failure_mode"]
|
||||
assert len(failure_facts) >= 1
|
||||
|
||||
def test_extract_from_text_uses_pattern(
|
||||
self,
|
||||
extractor: FactExtractor,
|
||||
) -> None:
|
||||
"""Test extracting 'uses' pattern from text."""
|
||||
text = "FastAPI uses Starlette for ASGI support."
|
||||
|
||||
facts = extractor.extract_from_text(text)
|
||||
|
||||
assert len(facts) >= 1
|
||||
uses_facts = [f for f in facts if f.predicate == "uses"]
|
||||
assert len(uses_facts) >= 1
|
||||
|
||||
def test_extract_from_text_requires_pattern(
|
||||
self,
|
||||
extractor: FactExtractor,
|
||||
) -> None:
|
||||
"""Test extracting 'requires' pattern from text."""
|
||||
text = "This feature requires Python 3.10 or higher."
|
||||
|
||||
facts = extractor.extract_from_text(text)
|
||||
|
||||
requires_facts = [f for f in facts if f.predicate == "requires"]
|
||||
assert len(requires_facts) >= 1
|
||||
|
||||
def test_extract_from_text_empty(
|
||||
self,
|
||||
extractor: FactExtractor,
|
||||
) -> None:
|
||||
"""Test extracting from empty text."""
|
||||
facts = extractor.extract_from_text("")
|
||||
|
||||
assert facts == []
|
||||
|
||||
def test_extract_from_text_short(
|
||||
self,
|
||||
extractor: FactExtractor,
|
||||
) -> None:
|
||||
"""Test extracting from too-short text."""
|
||||
facts = extractor.extract_from_text("Hi.")
|
||||
|
||||
assert facts == []
|
||||
|
||||
def test_extract_with_context(
|
||||
self,
|
||||
extractor: FactExtractor,
|
||||
) -> None:
|
||||
"""Test extraction with custom context."""
|
||||
episode = create_test_episode(lessons_learned=["Low confidence lesson"])
|
||||
|
||||
context = ExtractionContext(
|
||||
min_confidence=0.9, # High threshold
|
||||
max_facts_per_source=2,
|
||||
)
|
||||
|
||||
facts = extractor.extract_from_episode(episode, context)
|
||||
|
||||
# Should filter out low confidence facts
|
||||
for fact in facts:
|
||||
assert fact.confidence >= 0.9 or len(facts) <= 2
|
||||
|
||||
|
||||
class TestGetFactExtractor:
|
||||
"""Tests for singleton getter."""
|
||||
|
||||
def test_get_fact_extractor_returns_instance(self) -> None:
|
||||
"""Test that get_fact_extractor returns an instance."""
|
||||
extractor = get_fact_extractor()
|
||||
|
||||
assert extractor is not None
|
||||
assert isinstance(extractor, FactExtractor)
|
||||
|
||||
def test_get_fact_extractor_returns_same_instance(self) -> None:
|
||||
"""Test that get_fact_extractor returns singleton."""
|
||||
extractor1 = get_fact_extractor()
|
||||
extractor2 = get_fact_extractor()
|
||||
|
||||
assert extractor1 is extractor2
|
||||
446
backend/tests/unit/services/memory/semantic/test_memory.py
Normal file
446
backend/tests/unit/services/memory/semantic/test_memory.py
Normal file
@@ -0,0 +1,446 @@
|
||||
# tests/unit/services/memory/semantic/test_memory.py
|
||||
"""Unit tests for SemanticMemory class."""
|
||||
|
||||
from datetime import UTC, datetime
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
from app.services.memory.semantic.memory import SemanticMemory
|
||||
from app.services.memory.types import FactCreate
|
||||
|
||||
|
||||
def create_mock_fact_model(
|
||||
project_id=None,
|
||||
subject="FastAPI",
|
||||
predicate="uses",
|
||||
obj="Starlette",
|
||||
confidence=0.8,
|
||||
):
|
||||
"""Create a mock fact model for testing."""
|
||||
mock = MagicMock()
|
||||
mock.id = uuid4()
|
||||
mock.project_id = project_id
|
||||
mock.subject = subject
|
||||
mock.predicate = predicate
|
||||
mock.object = obj
|
||||
mock.confidence = confidence
|
||||
mock.source_episode_ids = []
|
||||
mock.first_learned = datetime.now(UTC)
|
||||
mock.last_reinforced = datetime.now(UTC)
|
||||
mock.reinforcement_count = 1
|
||||
mock.embedding = None
|
||||
mock.created_at = datetime.now(UTC)
|
||||
mock.updated_at = datetime.now(UTC)
|
||||
return mock
|
||||
|
||||
|
||||
class TestSemanticMemoryInit:
|
||||
"""Tests for SemanticMemory initialization."""
|
||||
|
||||
def test_init_creates_memory(self) -> None:
|
||||
"""Test that init creates memory instance."""
|
||||
mock_session = AsyncMock()
|
||||
memory = SemanticMemory(session=mock_session)
|
||||
|
||||
assert memory._session is mock_session
|
||||
|
||||
def test_init_with_embedding_generator(self) -> None:
|
||||
"""Test init with embedding generator."""
|
||||
mock_session = AsyncMock()
|
||||
mock_embedding_gen = AsyncMock()
|
||||
memory = SemanticMemory(
|
||||
session=mock_session, embedding_generator=mock_embedding_gen
|
||||
)
|
||||
|
||||
assert memory._embedding_generator is mock_embedding_gen
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_factory_method(self) -> None:
|
||||
"""Test create factory method."""
|
||||
mock_session = AsyncMock()
|
||||
memory = await SemanticMemory.create(session=mock_session)
|
||||
|
||||
assert memory is not None
|
||||
assert memory._session is mock_session
|
||||
|
||||
|
||||
class TestSemanticMemoryStoreFact:
|
||||
"""Tests for fact storage methods."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_session(self) -> AsyncMock:
|
||||
"""Create a mock database session."""
|
||||
session = AsyncMock()
|
||||
session.add = MagicMock()
|
||||
session.flush = AsyncMock()
|
||||
session.refresh = AsyncMock()
|
||||
# Mock no existing fact found
|
||||
mock_result = MagicMock()
|
||||
mock_result.scalar_one_or_none.return_value = None
|
||||
session.execute.return_value = mock_result
|
||||
return session
|
||||
|
||||
@pytest.fixture
|
||||
def memory(self, mock_session: AsyncMock) -> SemanticMemory:
|
||||
"""Create a SemanticMemory instance."""
|
||||
return SemanticMemory(session=mock_session)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_store_new_fact(
|
||||
self,
|
||||
memory: SemanticMemory,
|
||||
mock_session: AsyncMock,
|
||||
) -> None:
|
||||
"""Test storing a new fact."""
|
||||
fact_data = FactCreate(
|
||||
subject="Python",
|
||||
predicate="is_a",
|
||||
object="programming language",
|
||||
confidence=0.9,
|
||||
project_id=uuid4(),
|
||||
)
|
||||
|
||||
result = await memory.store_fact(fact_data)
|
||||
|
||||
assert result.subject == "Python"
|
||||
assert result.predicate == "is_a"
|
||||
assert result.object == "programming language"
|
||||
mock_session.add.assert_called_once()
|
||||
mock_session.flush.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_store_fact_reinforces_existing(
|
||||
self,
|
||||
memory: SemanticMemory,
|
||||
mock_session: AsyncMock,
|
||||
) -> None:
|
||||
"""Test that storing duplicate fact reinforces existing."""
|
||||
# Mock existing fact found - needs to be found first
|
||||
existing_mock = create_mock_fact_model(confidence=0.7)
|
||||
find_result = MagicMock()
|
||||
find_result.scalar_one_or_none.return_value = existing_mock
|
||||
|
||||
# Second find for reinforce_fact
|
||||
find_for_reinforce = MagicMock()
|
||||
find_for_reinforce.scalar_one_or_none.return_value = existing_mock
|
||||
|
||||
# Mock update result - returns the updated mock
|
||||
updated_mock = create_mock_fact_model(confidence=0.8)
|
||||
update_result = MagicMock()
|
||||
update_result.scalar_one.return_value = updated_mock
|
||||
|
||||
mock_session.execute.side_effect = [
|
||||
find_result, # _find_existing_fact
|
||||
find_for_reinforce, # reinforce_fact query
|
||||
update_result, # reinforce_fact update
|
||||
]
|
||||
|
||||
fact_data = FactCreate(
|
||||
subject="FastAPI",
|
||||
predicate="uses",
|
||||
object="Starlette",
|
||||
)
|
||||
|
||||
_ = await memory.store_fact(fact_data)
|
||||
|
||||
# Should have called execute three times (find + find + update)
|
||||
assert mock_session.execute.call_count == 3
|
||||
|
||||
|
||||
class TestSemanticMemorySearch:
|
||||
"""Tests for fact search methods."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_session(self) -> AsyncMock:
|
||||
"""Create a mock database session."""
|
||||
session = AsyncMock()
|
||||
mock_result = MagicMock()
|
||||
mock_result.scalars.return_value.all.return_value = []
|
||||
session.execute.return_value = mock_result
|
||||
return session
|
||||
|
||||
@pytest.fixture
|
||||
def memory(self, mock_session: AsyncMock) -> SemanticMemory:
|
||||
"""Create a SemanticMemory instance."""
|
||||
return SemanticMemory(session=mock_session)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_search_facts(
|
||||
self,
|
||||
memory: SemanticMemory,
|
||||
) -> None:
|
||||
"""Test searching for facts."""
|
||||
results = await memory.search_facts("Python programming")
|
||||
|
||||
assert isinstance(results, list)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_search_facts_with_project_filter(
|
||||
self,
|
||||
memory: SemanticMemory,
|
||||
) -> None:
|
||||
"""Test searching for facts with project filter."""
|
||||
project_id = uuid4()
|
||||
results = await memory.search_facts("Python", project_id=project_id)
|
||||
|
||||
assert isinstance(results, list)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_by_entity(
|
||||
self,
|
||||
memory: SemanticMemory,
|
||||
) -> None:
|
||||
"""Test getting facts by entity."""
|
||||
results = await memory.get_by_entity("FastAPI")
|
||||
|
||||
assert isinstance(results, list)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_by_subject(
|
||||
self,
|
||||
memory: SemanticMemory,
|
||||
) -> None:
|
||||
"""Test getting facts by subject."""
|
||||
results = await memory.get_by_subject("Python")
|
||||
|
||||
assert isinstance(results, list)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_by_id_not_found(
|
||||
self,
|
||||
memory: SemanticMemory,
|
||||
mock_session: AsyncMock,
|
||||
) -> None:
|
||||
"""Test get_by_id returns None when not found."""
|
||||
mock_result = MagicMock()
|
||||
mock_result.scalar_one_or_none.return_value = None
|
||||
mock_session.execute.return_value = mock_result
|
||||
|
||||
result = await memory.get_by_id(uuid4())
|
||||
|
||||
assert result is None
|
||||
|
||||
|
||||
class TestSemanticMemoryReinforcement:
|
||||
"""Tests for fact reinforcement."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_session(self) -> AsyncMock:
|
||||
"""Create a mock database session."""
|
||||
session = AsyncMock()
|
||||
return session
|
||||
|
||||
@pytest.fixture
|
||||
def memory(self, mock_session: AsyncMock) -> SemanticMemory:
|
||||
"""Create a SemanticMemory instance."""
|
||||
return SemanticMemory(session=mock_session)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_reinforce_fact(
|
||||
self,
|
||||
memory: SemanticMemory,
|
||||
mock_session: AsyncMock,
|
||||
) -> None:
|
||||
"""Test reinforcing a fact."""
|
||||
existing_mock = create_mock_fact_model(confidence=0.7)
|
||||
|
||||
# First query: find fact
|
||||
find_result = MagicMock()
|
||||
find_result.scalar_one_or_none.return_value = existing_mock
|
||||
|
||||
# Second query: update fact
|
||||
updated_mock = create_mock_fact_model(confidence=0.8)
|
||||
update_result = MagicMock()
|
||||
update_result.scalar_one.return_value = updated_mock
|
||||
|
||||
mock_session.execute.side_effect = [find_result, update_result]
|
||||
|
||||
result = await memory.reinforce_fact(existing_mock.id, confidence_boost=0.1)
|
||||
|
||||
assert result.confidence == 0.8
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_reinforce_fact_not_found(
|
||||
self,
|
||||
memory: SemanticMemory,
|
||||
mock_session: AsyncMock,
|
||||
) -> None:
|
||||
"""Test reinforcing a non-existent fact raises error."""
|
||||
mock_result = MagicMock()
|
||||
mock_result.scalar_one_or_none.return_value = None
|
||||
mock_session.execute.return_value = mock_result
|
||||
|
||||
with pytest.raises(ValueError, match="Fact not found"):
|
||||
await memory.reinforce_fact(uuid4())
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_deprecate_fact(
|
||||
self,
|
||||
memory: SemanticMemory,
|
||||
mock_session: AsyncMock,
|
||||
) -> None:
|
||||
"""Test deprecating a fact."""
|
||||
existing_mock = create_mock_fact_model(confidence=0.8)
|
||||
|
||||
find_result = MagicMock()
|
||||
find_result.scalar_one_or_none.return_value = existing_mock
|
||||
|
||||
deprecated_mock = create_mock_fact_model(confidence=0.0)
|
||||
update_result = MagicMock()
|
||||
update_result.scalar_one_or_none.return_value = deprecated_mock
|
||||
|
||||
mock_session.execute.side_effect = [find_result, update_result]
|
||||
|
||||
result = await memory.deprecate_fact(existing_mock.id, reason="Outdated")
|
||||
|
||||
assert result is not None
|
||||
assert result.confidence == 0.0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_deprecate_fact_not_found(
|
||||
self,
|
||||
memory: SemanticMemory,
|
||||
mock_session: AsyncMock,
|
||||
) -> None:
|
||||
"""Test deprecating non-existent fact returns None."""
|
||||
mock_result = MagicMock()
|
||||
mock_result.scalar_one_or_none.return_value = None
|
||||
mock_session.execute.return_value = mock_result
|
||||
|
||||
result = await memory.deprecate_fact(uuid4(), reason="Test")
|
||||
|
||||
assert result is None
|
||||
|
||||
|
||||
class TestSemanticMemoryConflictResolution:
|
||||
"""Tests for conflict resolution."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_session(self) -> AsyncMock:
|
||||
"""Create a mock database session."""
|
||||
session = AsyncMock()
|
||||
return session
|
||||
|
||||
@pytest.fixture
|
||||
def memory(self, mock_session: AsyncMock) -> SemanticMemory:
|
||||
"""Create a SemanticMemory instance."""
|
||||
return SemanticMemory(session=mock_session)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resolve_conflict_empty_list(
|
||||
self,
|
||||
memory: SemanticMemory,
|
||||
) -> None:
|
||||
"""Test resolving conflict with empty list."""
|
||||
result = await memory.resolve_conflict([])
|
||||
|
||||
assert result is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resolve_conflict_keeps_highest_confidence(
|
||||
self,
|
||||
memory: SemanticMemory,
|
||||
mock_session: AsyncMock,
|
||||
) -> None:
|
||||
"""Test that conflict resolution keeps highest confidence fact."""
|
||||
fact_low = create_mock_fact_model(confidence=0.5)
|
||||
fact_high = create_mock_fact_model(confidence=0.9)
|
||||
|
||||
# Mock finding the facts
|
||||
find_result = MagicMock()
|
||||
find_result.scalars.return_value.all.return_value = [fact_low, fact_high]
|
||||
|
||||
# Mock deprecation (find + update)
|
||||
find_one_result = MagicMock()
|
||||
find_one_result.scalar_one_or_none.return_value = fact_low
|
||||
update_result = MagicMock()
|
||||
update_result.scalar_one_or_none.return_value = fact_low
|
||||
|
||||
mock_session.execute.side_effect = [find_result, find_one_result, update_result]
|
||||
|
||||
result = await memory.resolve_conflict([fact_low.id, fact_high.id])
|
||||
|
||||
assert result is not None
|
||||
assert result.confidence == 0.9
|
||||
|
||||
|
||||
class TestSemanticMemoryStats:
|
||||
"""Tests for statistics methods."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_session(self) -> AsyncMock:
|
||||
"""Create a mock database session."""
|
||||
session = AsyncMock()
|
||||
return session
|
||||
|
||||
@pytest.fixture
|
||||
def memory(self, mock_session: AsyncMock) -> SemanticMemory:
|
||||
"""Create a SemanticMemory instance."""
|
||||
return SemanticMemory(session=mock_session)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_stats_empty(
|
||||
self,
|
||||
memory: SemanticMemory,
|
||||
mock_session: AsyncMock,
|
||||
) -> None:
|
||||
"""Test getting stats for empty project."""
|
||||
mock_result = MagicMock()
|
||||
mock_result.scalars.return_value.all.return_value = []
|
||||
mock_session.execute.return_value = mock_result
|
||||
|
||||
stats = await memory.get_stats(uuid4())
|
||||
|
||||
assert stats["total_facts"] == 0
|
||||
assert stats["avg_confidence"] == 0.0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_count(
|
||||
self,
|
||||
memory: SemanticMemory,
|
||||
mock_session: AsyncMock,
|
||||
) -> None:
|
||||
"""Test counting facts."""
|
||||
facts = [create_mock_fact_model() for _ in range(5)]
|
||||
mock_result = MagicMock()
|
||||
mock_result.scalars.return_value.all.return_value = facts
|
||||
mock_session.execute.return_value = mock_result
|
||||
|
||||
count = await memory.count(uuid4())
|
||||
|
||||
assert count == 5
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete(
|
||||
self,
|
||||
memory: SemanticMemory,
|
||||
mock_session: AsyncMock,
|
||||
) -> None:
|
||||
"""Test deleting a fact."""
|
||||
existing_mock = create_mock_fact_model()
|
||||
mock_result = MagicMock()
|
||||
mock_result.scalar_one_or_none.return_value = existing_mock
|
||||
mock_session.execute.return_value = mock_result
|
||||
mock_session.delete = AsyncMock()
|
||||
|
||||
result = await memory.delete(existing_mock.id)
|
||||
|
||||
assert result is True
|
||||
mock_session.delete.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_not_found(
|
||||
self,
|
||||
memory: SemanticMemory,
|
||||
mock_session: AsyncMock,
|
||||
) -> None:
|
||||
"""Test deleting non-existent fact."""
|
||||
mock_result = MagicMock()
|
||||
mock_result.scalar_one_or_none.return_value = None
|
||||
mock_session.execute.return_value = mock_result
|
||||
|
||||
result = await memory.delete(uuid4())
|
||||
|
||||
assert result is False
|
||||
298
backend/tests/unit/services/memory/semantic/test_verification.py
Normal file
298
backend/tests/unit/services/memory/semantic/test_verification.py
Normal file
@@ -0,0 +1,298 @@
|
||||
# tests/unit/services/memory/semantic/test_verification.py
|
||||
"""Unit tests for fact verification."""
|
||||
|
||||
from datetime import UTC, datetime
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
from app.services.memory.semantic.verification import (
|
||||
FactConflict,
|
||||
FactVerifier,
|
||||
VerificationResult,
|
||||
)
|
||||
|
||||
|
||||
def create_mock_fact_model(
|
||||
subject="FastAPI",
|
||||
predicate="uses",
|
||||
obj="Starlette",
|
||||
confidence=0.8,
|
||||
project_id=None,
|
||||
):
|
||||
"""Create a mock fact model for testing."""
|
||||
mock = MagicMock()
|
||||
mock.id = uuid4()
|
||||
mock.project_id = project_id
|
||||
mock.subject = subject
|
||||
mock.predicate = predicate
|
||||
mock.object = obj
|
||||
mock.confidence = confidence
|
||||
mock.source_episode_ids = []
|
||||
mock.first_learned = datetime.now(UTC)
|
||||
mock.last_reinforced = datetime.now(UTC)
|
||||
mock.reinforcement_count = 1
|
||||
mock.embedding = None
|
||||
mock.created_at = datetime.now(UTC)
|
||||
mock.updated_at = datetime.now(UTC)
|
||||
return mock
|
||||
|
||||
|
||||
class TestFactConflict:
|
||||
"""Tests for FactConflict dataclass."""
|
||||
|
||||
def test_to_dict(self) -> None:
|
||||
"""Test converting conflict to dictionary."""
|
||||
conflict = FactConflict(
|
||||
fact_a_id=uuid4(),
|
||||
fact_b_id=uuid4(),
|
||||
conflict_type="contradiction",
|
||||
description="Test conflict",
|
||||
suggested_resolution="Keep higher confidence",
|
||||
)
|
||||
|
||||
result = conflict.to_dict()
|
||||
|
||||
assert "fact_a_id" in result
|
||||
assert "fact_b_id" in result
|
||||
assert result["conflict_type"] == "contradiction"
|
||||
assert result["description"] == "Test conflict"
|
||||
|
||||
|
||||
class TestVerificationResult:
|
||||
"""Tests for VerificationResult dataclass."""
|
||||
|
||||
def test_default_values(self) -> None:
|
||||
"""Test default values."""
|
||||
result = VerificationResult(is_valid=True)
|
||||
|
||||
assert result.is_valid is True
|
||||
assert result.confidence_adjustment == 0.0
|
||||
assert result.conflicts == []
|
||||
assert result.supporting_facts == []
|
||||
assert result.messages == []
|
||||
|
||||
|
||||
class TestFactVerifier:
|
||||
"""Tests for FactVerifier class."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_session(self) -> AsyncMock:
|
||||
"""Create a mock database session."""
|
||||
session = AsyncMock()
|
||||
mock_result = MagicMock()
|
||||
mock_result.scalars.return_value.all.return_value = []
|
||||
session.execute.return_value = mock_result
|
||||
return session
|
||||
|
||||
@pytest.fixture
|
||||
def verifier(self, mock_session: AsyncMock) -> FactVerifier:
|
||||
"""Create a fact verifier."""
|
||||
return FactVerifier(session=mock_session)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_verify_fact_valid(
|
||||
self,
|
||||
verifier: FactVerifier,
|
||||
) -> None:
|
||||
"""Test verifying a valid fact with no conflicts."""
|
||||
result = await verifier.verify_fact(
|
||||
subject="Python",
|
||||
predicate="is_a",
|
||||
obj="programming language",
|
||||
)
|
||||
|
||||
assert result.is_valid is True
|
||||
assert len(result.conflicts) == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_verify_fact_with_support(
|
||||
self,
|
||||
verifier: FactVerifier,
|
||||
mock_session: AsyncMock,
|
||||
) -> None:
|
||||
"""Test verifying a fact with supporting facts."""
|
||||
# Mock finding supporting facts
|
||||
supporting = [create_mock_fact_model()]
|
||||
|
||||
# First query: contradictions (empty)
|
||||
contradiction_result = MagicMock()
|
||||
contradiction_result.scalars.return_value.all.return_value = []
|
||||
|
||||
# Second query: supporting facts
|
||||
support_result = MagicMock()
|
||||
support_result.scalars.return_value.all.return_value = supporting
|
||||
|
||||
mock_session.execute.side_effect = [contradiction_result, support_result]
|
||||
|
||||
result = await verifier.verify_fact(
|
||||
subject="Python",
|
||||
predicate="uses",
|
||||
obj="dynamic typing",
|
||||
)
|
||||
|
||||
assert result.is_valid is True
|
||||
assert len(result.supporting_facts) >= 1
|
||||
assert result.confidence_adjustment > 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_verify_fact_with_contradiction(
|
||||
self,
|
||||
verifier: FactVerifier,
|
||||
mock_session: AsyncMock,
|
||||
) -> None:
|
||||
"""Test verifying a fact with contradictions."""
|
||||
# Mock finding contradicting fact
|
||||
contradicting = create_mock_fact_model(
|
||||
subject="Python",
|
||||
predicate="does_not_use",
|
||||
obj="static typing",
|
||||
)
|
||||
|
||||
contradiction_result = MagicMock()
|
||||
contradiction_result.scalars.return_value.all.return_value = [contradicting]
|
||||
|
||||
support_result = MagicMock()
|
||||
support_result.scalars.return_value.all.return_value = []
|
||||
|
||||
mock_session.execute.side_effect = [contradiction_result, support_result]
|
||||
|
||||
result = await verifier.verify_fact(
|
||||
subject="Python",
|
||||
predicate="uses",
|
||||
obj="static typing",
|
||||
)
|
||||
|
||||
assert result.is_valid is False
|
||||
assert len(result.conflicts) >= 1
|
||||
assert result.confidence_adjustment < 0
|
||||
|
||||
def test_get_opposite_predicates(
|
||||
self,
|
||||
verifier: FactVerifier,
|
||||
) -> None:
|
||||
"""Test getting opposite predicates."""
|
||||
opposites = verifier._get_opposite_predicates("uses")
|
||||
|
||||
assert "does_not_use" in opposites
|
||||
|
||||
def test_get_opposite_predicates_unknown(
|
||||
self,
|
||||
verifier: FactVerifier,
|
||||
) -> None:
|
||||
"""Test getting opposites for unknown predicate."""
|
||||
opposites = verifier._get_opposite_predicates("unknown_predicate")
|
||||
|
||||
assert opposites == []
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_find_all_conflicts_empty(
|
||||
self,
|
||||
verifier: FactVerifier,
|
||||
mock_session: AsyncMock,
|
||||
) -> None:
|
||||
"""Test finding all conflicts in empty fact base."""
|
||||
mock_result = MagicMock()
|
||||
mock_result.scalars.return_value.all.return_value = []
|
||||
mock_session.execute.return_value = mock_result
|
||||
|
||||
conflicts = await verifier.find_all_conflicts()
|
||||
|
||||
assert conflicts == []
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_find_all_conflicts_no_conflicts(
|
||||
self,
|
||||
verifier: FactVerifier,
|
||||
mock_session: AsyncMock,
|
||||
) -> None:
|
||||
"""Test finding conflicts when there are none."""
|
||||
# Two facts with different subjects
|
||||
fact1 = create_mock_fact_model(subject="Python", predicate="uses")
|
||||
fact2 = create_mock_fact_model(subject="JavaScript", predicate="uses")
|
||||
|
||||
mock_result = MagicMock()
|
||||
mock_result.scalars.return_value.all.return_value = [fact1, fact2]
|
||||
mock_session.execute.return_value = mock_result
|
||||
|
||||
conflicts = await verifier.find_all_conflicts()
|
||||
|
||||
assert conflicts == []
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_find_all_conflicts_with_contradiction(
|
||||
self,
|
||||
verifier: FactVerifier,
|
||||
mock_session: AsyncMock,
|
||||
) -> None:
|
||||
"""Test finding contradicting facts."""
|
||||
# Two contradicting facts
|
||||
fact1 = create_mock_fact_model(
|
||||
subject="Python",
|
||||
predicate="best_practice",
|
||||
obj="Use type hints",
|
||||
)
|
||||
fact2 = create_mock_fact_model(
|
||||
subject="Python",
|
||||
predicate="anti_pattern",
|
||||
obj="Use type hints",
|
||||
)
|
||||
|
||||
mock_result = MagicMock()
|
||||
mock_result.scalars.return_value.all.return_value = [fact1, fact2]
|
||||
mock_session.execute.return_value = mock_result
|
||||
|
||||
conflicts = await verifier.find_all_conflicts()
|
||||
|
||||
assert len(conflicts) == 1
|
||||
assert conflicts[0].conflict_type == "contradiction"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_fact_reliability_score_not_found(
|
||||
self,
|
||||
verifier: FactVerifier,
|
||||
mock_session: AsyncMock,
|
||||
) -> None:
|
||||
"""Test reliability score for non-existent fact."""
|
||||
mock_result = MagicMock()
|
||||
mock_result.scalar_one_or_none.return_value = None
|
||||
mock_session.execute.return_value = mock_result
|
||||
|
||||
score = await verifier.get_fact_reliability_score(uuid4())
|
||||
|
||||
assert score == 0.0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_fact_reliability_score(
|
||||
self,
|
||||
verifier: FactVerifier,
|
||||
mock_session: AsyncMock,
|
||||
) -> None:
|
||||
"""Test calculating reliability score."""
|
||||
fact = create_mock_fact_model(confidence=0.8)
|
||||
fact.reinforcement_count = 5
|
||||
|
||||
# Query 1: Get fact
|
||||
fact_result = MagicMock()
|
||||
fact_result.scalar_one_or_none.return_value = fact
|
||||
|
||||
# Query 2: Supporting facts
|
||||
support_result = MagicMock()
|
||||
support_result.scalars.return_value.all.return_value = []
|
||||
|
||||
# Query 3: Contradictions
|
||||
conflict_result = MagicMock()
|
||||
conflict_result.scalars.return_value.all.return_value = []
|
||||
|
||||
mock_session.execute.side_effect = [
|
||||
fact_result,
|
||||
support_result,
|
||||
conflict_result,
|
||||
]
|
||||
|
||||
score = await verifier.get_fact_reliability_score(fact.id)
|
||||
|
||||
# Score should be >= confidence (0.8) due to reinforcement bonus
|
||||
assert score >= 0.8
|
||||
assert score <= 1.0
|
||||
Reference in New Issue
Block a user