feat(memory): add semantic memory implementation (Issue #91)

Implements semantic memory with fact storage, retrieval, and verification:

Core functionality:
- SemanticMemory class for fact storage/retrieval
- Fact storage as subject-predicate-object triples
- Duplicate detection with reinforcement
- Semantic search with text-based fallback
- Entity-based retrieval
- Confidence scoring and decay
- Conflict resolution

Supporting modules:
- FactExtractor: Pattern-based fact extraction from episodes
- FactVerifier: Contradiction detection and reliability scoring

Test coverage:
- 47 unit tests covering all modules
- extraction.py: 99% coverage
- verification.py: 95% coverage
- memory.py: 78% coverage

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
commit e946787a61 (parent 3554efe66a)
2026-01-05 02:23:06 +01:00
8 changed files with 2447 additions and 1 deletion


@@ -1,3 +1,4 @@
# app/services/memory/semantic/__init__.py
"""
Semantic Memory
@@ -5,4 +6,22 @@ Fact storage with triple format (subject, predicate, object)
and semantic search capabilities.
"""
from .extraction import (
ExtractedFact,
ExtractionContext,
FactExtractor,
get_fact_extractor,
)
from .memory import SemanticMemory
from .verification import FactConflict, FactVerifier, VerificationResult
__all__ = [
"ExtractedFact",
"ExtractionContext",
"FactConflict",
"FactExtractor",
"FactVerifier",
"SemanticMemory",
"VerificationResult",
"get_fact_extractor",
]
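
A minimal usage sketch of the exported API (not part of this commit; the `remember` coroutine and session wiring are illustrative, assuming an AsyncSession from the app's database layer):

# example_semantic_usage.py -- illustrative sketch, not part of this diff
from sqlalchemy.ext.asyncio import AsyncSession

from app.services.memory.semantic import SemanticMemory
from app.services.memory.types import FactCreate


async def remember(session: AsyncSession) -> None:
    """Store a fact, then retrieve it via text search."""
    memory = await SemanticMemory.create(session=session)
    # Duplicate triples in the same scope are reinforced, not re-inserted.
    await memory.store_fact(
        FactCreate(
            subject="FastAPI",
            predicate="uses",
            object="Starlette",
            confidence=0.8,
        )
    )
    # Text-based matching for now (see the pgvector TODO in memory.py).
    facts = await memory.search_facts("FastAPI", limit=5)
    for fact in facts:
        print(fact.subject, fact.predicate, fact.object, fact.confidence)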


@@ -0,0 +1,313 @@
# app/services/memory/semantic/extraction.py
"""
Fact Extraction from Episodes.
Provides utilities for extracting semantic facts (subject-predicate-object triples)
from episodic memories and other text sources.
"""
import logging
import re
from dataclasses import dataclass, field
from typing import Any, ClassVar
from app.services.memory.types import Episode, FactCreate, Outcome
logger = logging.getLogger(__name__)
@dataclass
class ExtractionContext:
"""Context for fact extraction."""
project_id: Any | None = None
source_episode_id: Any | None = None
min_confidence: float = 0.5
max_facts_per_source: int = 10
@dataclass
class ExtractedFact:
"""A fact extracted from text before storage."""
subject: str
predicate: str
object: str
confidence: float
source_text: str = ""
metadata: dict[str, Any] = field(default_factory=dict)
def to_fact_create(
self,
project_id: Any | None = None,
source_episode_ids: list[Any] | None = None,
) -> FactCreate:
"""Convert to FactCreate for storage."""
return FactCreate(
subject=self.subject,
predicate=self.predicate,
object=self.object,
confidence=self.confidence,
project_id=project_id,
source_episode_ids=source_episode_ids or [],
)
class FactExtractor:
"""
Extracts facts from episodes and text.
This is a rule-based extractor. In production, this would be
replaced or augmented with LLM-based extraction for better accuracy.
"""
# Common predicates we can detect
PREDICATE_PATTERNS: ClassVar[dict[str, str]] = {
"uses": r"(?:uses?|using|utilizes?)",
"requires": r"(?:requires?|needs?|depends?\s+on)",
"is_a": r"(?:is\s+a|is\s+an|are\s+a|are)",
"has": r"(?:has|have|contains?)",
"part_of": r"(?:part\s+of|belongs?\s+to|member\s+of)",
"causes": r"(?:causes?|leads?\s+to|results?\s+in)",
"prevents": r"(?:prevents?|avoids?|stops?)",
"solves": r"(?:solves?|fixes?|resolves?)",
}
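    # Example: "FastAPI uses Starlette" matches the "uses" pattern, splitting
    # into subject "FastAPI" and object "Starlette" in _extract_from_text below.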
def __init__(self) -> None:
"""Initialize extractor."""
self._compiled_patterns = {
pred: re.compile(pattern, re.IGNORECASE)
for pred, pattern in self.PREDICATE_PATTERNS.items()
}
def extract_from_episode(
self,
episode: Episode,
context: ExtractionContext | None = None,
) -> list[ExtractedFact]:
"""
Extract facts from an episode.
Args:
episode: Episode to extract from
context: Optional extraction context
Returns:
List of extracted facts
"""
ctx = context or ExtractionContext()
facts: list[ExtractedFact] = []
# Extract from task description
task_facts = self._extract_from_text(
episode.task_description,
source_prefix=episode.task_type,
)
facts.extend(task_facts)
# Extract from lessons learned
for lesson in episode.lessons_learned:
lesson_facts = self._extract_from_lesson(lesson, episode)
facts.extend(lesson_facts)
# Extract outcome-based facts
outcome_facts = self._extract_outcome_facts(episode)
facts.extend(outcome_facts)
# Limit and filter
facts = [f for f in facts if f.confidence >= ctx.min_confidence]
facts = facts[: ctx.max_facts_per_source]
logger.debug(f"Extracted {len(facts)} facts from episode {episode.id}")
return facts
def _extract_from_text(
self,
text: str,
source_prefix: str = "",
) -> list[ExtractedFact]:
"""Extract facts from free-form text using pattern matching."""
facts: list[ExtractedFact] = []
if not text or len(text) < 10:
return facts
# Split into sentences
sentences = re.split(r"[.!?]+", text)
for sentence in sentences:
sentence = sentence.strip()
if len(sentence) < 10:
continue
# Try to match predicate patterns
for predicate, pattern in self._compiled_patterns.items():
match = pattern.search(sentence)
if match:
# Extract subject (text before predicate)
subject = sentence[: match.start()].strip()
# Extract object (text after predicate)
obj = sentence[match.end() :].strip()
if len(subject) > 2 and len(obj) > 2:
facts.append(
ExtractedFact(
subject=subject[:200], # Limit length
predicate=predicate,
object=obj[:500],
confidence=0.6, # Medium confidence for pattern matching
source_text=sentence,
)
)
break # One fact per sentence
return facts
def _extract_from_lesson(
self,
lesson: str,
episode: Episode,
) -> list[ExtractedFact]:
"""Extract facts from a lesson learned."""
facts: list[ExtractedFact] = []
if not lesson or len(lesson) < 10:
return facts
# Lessons are typically in the form "Always do X" or "Never do Y"
# or "When X, do Y"
# Direct lesson fact
facts.append(
ExtractedFact(
subject=episode.task_type,
predicate="lesson_learned",
object=lesson,
confidence=0.8, # High confidence for explicit lessons
source_text=lesson,
metadata={"outcome": episode.outcome.value},
)
)
# Extract conditional patterns
conditional_match = re.match(
r"(?:when|if)\s+(.+?),\s*(.+)",
lesson,
re.IGNORECASE,
)
if conditional_match:
condition, action = conditional_match.groups()
facts.append(
ExtractedFact(
subject=condition.strip(),
predicate="requires_action",
object=action.strip(),
confidence=0.7,
source_text=lesson,
)
)
# Extract "always/never" patterns
always_match = re.match(
r"(?:always)\s+(.+)",
lesson,
re.IGNORECASE,
)
if always_match:
facts.append(
ExtractedFact(
subject=episode.task_type,
predicate="best_practice",
object=always_match.group(1).strip(),
confidence=0.85,
source_text=lesson,
)
)
never_match = re.match(
r"(?:never|avoid)\s+(.+)",
lesson,
re.IGNORECASE,
)
if never_match:
facts.append(
ExtractedFact(
subject=episode.task_type,
predicate="anti_pattern",
object=never_match.group(1).strip(),
confidence=0.85,
source_text=lesson,
)
)
return facts
def _extract_outcome_facts(
self,
episode: Episode,
) -> list[ExtractedFact]:
"""Extract facts based on episode outcome."""
facts: list[ExtractedFact] = []
# Create fact based on outcome
if episode.outcome == Outcome.SUCCESS:
if episode.outcome_details:
facts.append(
ExtractedFact(
subject=episode.task_type,
predicate="successful_approach",
object=episode.outcome_details[:500],
confidence=0.75,
source_text=episode.outcome_details,
)
)
elif episode.outcome == Outcome.FAILURE:
if episode.outcome_details:
facts.append(
ExtractedFact(
subject=episode.task_type,
predicate="known_failure_mode",
object=episode.outcome_details[:500],
confidence=0.8, # High confidence for failures
source_text=episode.outcome_details,
)
)
return facts
def extract_from_text(
self,
text: str,
context: ExtractionContext | None = None,
) -> list[ExtractedFact]:
"""
Extract facts from arbitrary text.
Args:
text: Text to extract from
context: Optional extraction context
Returns:
List of extracted facts
"""
ctx = context or ExtractionContext()
facts = self._extract_from_text(text)
# Filter by confidence
facts = [f for f in facts if f.confidence >= ctx.min_confidence]
return facts[: ctx.max_facts_per_source]
# Singleton extractor instance
_extractor: FactExtractor | None = None
def get_fact_extractor() -> FactExtractor:
"""Get the singleton fact extractor instance."""
global _extractor
if _extractor is None:
_extractor = FactExtractor()
return _extractor
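
A quick sketch of the extractor on free-form text (the sample sentence is illustrative; the triple shown is what the "requires" pattern above produces):

# illustrative sketch, not part of this diff
from app.services.memory.semantic.extraction import get_fact_extractor

extractor = get_fact_extractor()
# "requires" matches one of the PREDICATE_PATTERNS, yielding roughly
# ("The auth module", "requires", "a valid session token") at confidence 0.6.
facts = extractor.extract_from_text("The auth module requires a valid session token.")
for fact in facts:
    print(fact.subject, fact.predicate, fact.object)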


@@ -0,0 +1,742 @@
# app/services/memory/semantic/memory.py
"""
Semantic Memory Implementation.
Provides fact storage and retrieval using subject-predicate-object triples.
Supports semantic search, confidence scoring, and fact reinforcement.
"""
import logging
import time
from datetime import UTC, datetime, timedelta
from typing import Any
from uuid import UUID
from sqlalchemy import and_, desc, or_, select, update
from sqlalchemy.ext.asyncio import AsyncSession
from app.models.memory.fact import Fact as FactModel
from app.services.memory.config import get_memory_settings
from app.services.memory.types import Episode, Fact, FactCreate, RetrievalResult
logger = logging.getLogger(__name__)
def _model_to_fact(model: FactModel) -> Fact:
"""Convert SQLAlchemy model to Fact dataclass."""
# SQLAlchemy Column types are inferred as Column[T] by mypy, but at runtime
# they return actual values. We use type: ignore to handle this mismatch.
return Fact(
id=model.id, # type: ignore[arg-type]
project_id=model.project_id, # type: ignore[arg-type]
subject=model.subject, # type: ignore[arg-type]
predicate=model.predicate, # type: ignore[arg-type]
object=model.object, # type: ignore[arg-type]
confidence=model.confidence, # type: ignore[arg-type]
source_episode_ids=model.source_episode_ids or [], # type: ignore[arg-type]
first_learned=model.first_learned, # type: ignore[arg-type]
last_reinforced=model.last_reinforced, # type: ignore[arg-type]
reinforcement_count=model.reinforcement_count, # type: ignore[arg-type]
embedding=None, # Don't expose raw embedding
created_at=model.created_at, # type: ignore[arg-type]
updated_at=model.updated_at, # type: ignore[arg-type]
)
class SemanticMemory:
"""
Semantic Memory Service.
Provides fact storage and retrieval:
- Store facts as subject-predicate-object triples
- Semantic search over facts
- Entity-based retrieval
- Confidence scoring and decay
- Fact reinforcement on repeated learning
- Conflict resolution
Performance target: <100ms P95 for retrieval
"""
def __init__(
self,
session: AsyncSession,
embedding_generator: Any | None = None,
) -> None:
"""
Initialize semantic memory.
Args:
session: Database session
embedding_generator: Optional embedding generator for semantic search
"""
self._session = session
self._embedding_generator = embedding_generator
self._settings = get_memory_settings()
@classmethod
async def create(
cls,
session: AsyncSession,
embedding_generator: Any | None = None,
) -> "SemanticMemory":
"""
Factory method to create SemanticMemory.
Args:
session: Database session
embedding_generator: Optional embedding generator
Returns:
Configured SemanticMemory instance
"""
return cls(session=session, embedding_generator=embedding_generator)
# =========================================================================
# Fact Storage
# =========================================================================
async def store_fact(self, fact: FactCreate) -> Fact:
"""
Store a new fact or reinforce an existing one.
If a fact with the same triple (subject, predicate, object) exists
in the same scope, it will be reinforced instead of duplicated.
Args:
fact: Fact data to store
Returns:
The created or reinforced fact
"""
# Check for existing fact with same triple
existing = await self._find_existing_fact(
project_id=fact.project_id,
subject=fact.subject,
predicate=fact.predicate,
object=fact.object,
)
if existing is not None:
# Reinforce existing fact
return await self.reinforce_fact(
existing.id, # type: ignore[arg-type]
source_episode_ids=fact.source_episode_ids,
)
# Create new fact
now = datetime.now(UTC)
# Generate embedding if possible
embedding = None
if self._embedding_generator is not None:
embedding_text = self._create_embedding_text(fact)
embedding = await self._embedding_generator.generate(embedding_text)
model = FactModel(
project_id=fact.project_id,
subject=fact.subject,
predicate=fact.predicate,
object=fact.object,
confidence=fact.confidence,
source_episode_ids=fact.source_episode_ids,
first_learned=now,
last_reinforced=now,
reinforcement_count=1,
embedding=embedding,
)
self._session.add(model)
await self._session.flush()
await self._session.refresh(model)
logger.info(
f"Stored new fact: {fact.subject} - {fact.predicate} - {fact.object[:50]}..."
)
return _model_to_fact(model)
async def _find_existing_fact(
self,
project_id: UUID | None,
subject: str,
predicate: str,
object: str,
) -> FactModel | None:
"""Find an existing fact with the same triple in the same scope."""
query = select(FactModel).where(
and_(
FactModel.subject == subject,
FactModel.predicate == predicate,
FactModel.object == object,
)
)
if project_id is not None:
query = query.where(FactModel.project_id == project_id)
else:
query = query.where(FactModel.project_id.is_(None))
result = await self._session.execute(query)
return result.scalar_one_or_none()
def _create_embedding_text(self, fact: FactCreate) -> str:
"""Create text for embedding from fact data."""
return f"{fact.subject} {fact.predicate} {fact.object}"
# =========================================================================
# Fact Retrieval
# =========================================================================
async def search_facts(
self,
query: str,
project_id: UUID | None = None,
limit: int = 10,
min_confidence: float | None = None,
) -> list[Fact]:
"""
Search for facts semantically similar to the query.
Args:
query: Search query
project_id: Optional project to search within
limit: Maximum results
min_confidence: Optional minimum confidence filter
Returns:
List of matching facts
"""
result = await self._search_facts_with_metadata(
query=query,
project_id=project_id,
limit=limit,
min_confidence=min_confidence,
)
return result.items
async def _search_facts_with_metadata(
self,
query: str,
project_id: UUID | None = None,
limit: int = 10,
min_confidence: float | None = None,
) -> RetrievalResult[Fact]:
"""Search facts with full result metadata."""
start_time = time.perf_counter()
        min_conf = (
            min_confidence
            if min_confidence is not None
            else self._settings.semantic_min_confidence
        )
# Build base query
stmt = (
select(FactModel)
.where(FactModel.confidence >= min_conf)
.order_by(desc(FactModel.confidence), desc(FactModel.last_reinforced))
.limit(limit)
)
# Apply project filter
if project_id is not None:
# Include both project-specific and global facts
stmt = stmt.where(
or_(
FactModel.project_id == project_id,
FactModel.project_id.is_(None),
)
)
# TODO: Implement proper vector similarity search when pgvector is integrated
# For now, do text-based search on subject/predicate/object
search_terms = query.lower().split()
if search_terms:
conditions = []
for term in search_terms[:5]: # Limit to 5 terms
term_pattern = f"%{term}%"
conditions.append(
or_(
FactModel.subject.ilike(term_pattern),
FactModel.predicate.ilike(term_pattern),
FactModel.object.ilike(term_pattern),
)
)
if conditions:
stmt = stmt.where(or_(*conditions))
result = await self._session.execute(stmt)
models = list(result.scalars().all())
latency_ms = (time.perf_counter() - start_time) * 1000
return RetrievalResult(
items=[_model_to_fact(m) for m in models],
total_count=len(models),
query=query,
retrieval_type="semantic",
latency_ms=latency_ms,
metadata={"min_confidence": min_conf},
)
async def get_by_entity(
self,
entity: str,
project_id: UUID | None = None,
limit: int = 20,
) -> list[Fact]:
"""
Get facts related to an entity (as subject or object).
Args:
entity: Entity to search for
project_id: Optional project to search within
limit: Maximum results
Returns:
List of facts mentioning the entity
"""
start_time = time.perf_counter()
stmt = (
select(FactModel)
.where(
or_(
FactModel.subject.ilike(f"%{entity}%"),
FactModel.object.ilike(f"%{entity}%"),
)
)
.order_by(desc(FactModel.confidence), desc(FactModel.last_reinforced))
.limit(limit)
)
if project_id is not None:
stmt = stmt.where(
or_(
FactModel.project_id == project_id,
FactModel.project_id.is_(None),
)
)
result = await self._session.execute(stmt)
models = list(result.scalars().all())
latency_ms = (time.perf_counter() - start_time) * 1000
logger.debug(
f"get_by_entity({entity}) returned {len(models)} facts in {latency_ms:.1f}ms"
)
return [_model_to_fact(m) for m in models]
async def get_by_subject(
self,
subject: str,
project_id: UUID | None = None,
predicate: str | None = None,
limit: int = 20,
) -> list[Fact]:
"""
Get facts with a specific subject.
Args:
subject: Subject to search for
project_id: Optional project to search within
predicate: Optional predicate filter
limit: Maximum results
Returns:
List of facts with matching subject
"""
stmt = (
select(FactModel)
.where(FactModel.subject == subject)
.order_by(desc(FactModel.confidence))
.limit(limit)
)
if predicate is not None:
stmt = stmt.where(FactModel.predicate == predicate)
if project_id is not None:
stmt = stmt.where(
or_(
FactModel.project_id == project_id,
FactModel.project_id.is_(None),
)
)
result = await self._session.execute(stmt)
models = list(result.scalars().all())
return [_model_to_fact(m) for m in models]
async def get_by_id(self, fact_id: UUID) -> Fact | None:
"""Get a fact by ID."""
query = select(FactModel).where(FactModel.id == fact_id)
result = await self._session.execute(query)
model = result.scalar_one_or_none()
return _model_to_fact(model) if model else None
# =========================================================================
# Fact Reinforcement
# =========================================================================
async def reinforce_fact(
self,
fact_id: UUID,
confidence_boost: float = 0.1,
source_episode_ids: list[UUID] | None = None,
) -> Fact:
"""
Reinforce a fact, increasing its confidence.
Args:
fact_id: Fact to reinforce
confidence_boost: Amount to increase confidence (default 0.1)
source_episode_ids: Additional source episodes
Returns:
Updated fact
Raises:
ValueError: If fact not found
"""
query = select(FactModel).where(FactModel.id == fact_id)
result = await self._session.execute(query)
model = result.scalar_one_or_none()
if model is None:
raise ValueError(f"Fact not found: {fact_id}")
# Calculate new confidence (max 1.0)
current_confidence: float = model.confidence # type: ignore[assignment]
new_confidence = min(1.0, current_confidence + confidence_boost)
# Merge source episode IDs
current_sources: list[UUID] = model.source_episode_ids or [] # type: ignore[assignment]
if source_episode_ids:
# Add new sources, avoiding duplicates
new_sources = list(set(current_sources + source_episode_ids))
else:
new_sources = current_sources
now = datetime.now(UTC)
stmt = (
update(FactModel)
.where(FactModel.id == fact_id)
.values(
confidence=new_confidence,
source_episode_ids=new_sources,
last_reinforced=now,
reinforcement_count=FactModel.reinforcement_count + 1,
updated_at=now,
)
.returning(FactModel)
)
result = await self._session.execute(stmt)
updated_model = result.scalar_one()
await self._session.flush()
logger.info(
f"Reinforced fact {fact_id}: confidence {current_confidence:.2f} -> {new_confidence:.2f}"
)
return _model_to_fact(updated_model)
async def deprecate_fact(
self,
fact_id: UUID,
reason: str,
new_confidence: float = 0.0,
) -> Fact | None:
"""
Deprecate a fact by lowering its confidence.
Args:
fact_id: Fact to deprecate
reason: Reason for deprecation
new_confidence: New confidence level (default 0.0)
Returns:
Updated fact or None if not found
"""
query = select(FactModel).where(FactModel.id == fact_id)
result = await self._session.execute(query)
model = result.scalar_one_or_none()
if model is None:
return None
now = datetime.now(UTC)
stmt = (
update(FactModel)
.where(FactModel.id == fact_id)
.values(
confidence=max(0.0, new_confidence),
updated_at=now,
)
.returning(FactModel)
)
result = await self._session.execute(stmt)
updated_model = result.scalar_one_or_none()
await self._session.flush()
logger.info(f"Deprecated fact {fact_id}: {reason}")
return _model_to_fact(updated_model) if updated_model else None
# =========================================================================
# Fact Extraction from Episodes
# =========================================================================
async def extract_facts_from_episode(
self,
episode: Episode,
) -> list[Fact]:
"""
Extract facts from an episode.
This is a placeholder for LLM-based fact extraction.
In production, this would call an LLM to analyze the episode
and extract subject-predicate-object triples.
Args:
episode: Episode to extract facts from
Returns:
List of extracted facts
"""
# For now, extract basic facts from lessons learned
extracted_facts: list[Fact] = []
for lesson in episode.lessons_learned:
if len(lesson) > 10: # Skip very short lessons
fact_create = FactCreate(
subject=episode.task_type,
predicate="lesson_learned",
object=lesson,
confidence=0.7, # Lessons start with moderate confidence
project_id=episode.project_id,
source_episode_ids=[episode.id],
)
fact = await self.store_fact(fact_create)
extracted_facts.append(fact)
logger.debug(
f"Extracted {len(extracted_facts)} facts from episode {episode.id}"
)
return extracted_facts
# =========================================================================
# Conflict Resolution
# =========================================================================
async def resolve_conflict(
self,
fact_ids: list[UUID],
keep_fact_id: UUID | None = None,
) -> Fact | None:
"""
Resolve a conflict between multiple facts.
If keep_fact_id is specified, that fact is kept and others are deprecated.
Otherwise, the fact with highest confidence is kept.
Args:
fact_ids: IDs of conflicting facts
keep_fact_id: Optional ID of fact to keep
Returns:
The winning fact, or None if no facts found
"""
if not fact_ids:
return None
# Load all facts
query = select(FactModel).where(FactModel.id.in_(fact_ids))
result = await self._session.execute(query)
models = list(result.scalars().all())
if not models:
return None
# Determine winner
if keep_fact_id is not None:
winner = next((m for m in models if m.id == keep_fact_id), None)
if winner is None:
# Fallback to highest confidence
winner = max(models, key=lambda m: m.confidence)
else:
# Keep the fact with highest confidence
winner = max(models, key=lambda m: m.confidence)
# Deprecate losers
for model in models:
if model.id != winner.id:
await self.deprecate_fact(
model.id, # type: ignore[arg-type]
reason=f"Conflict resolution: superseded by {winner.id}",
)
logger.info(
f"Resolved conflict between {len(fact_ids)} facts, keeping {winner.id}"
)
return _model_to_fact(winner)
# =========================================================================
# Confidence Decay
# =========================================================================
async def apply_confidence_decay(
self,
project_id: UUID | None = None,
decay_factor: float = 0.01,
) -> int:
"""
Apply confidence decay to facts that haven't been reinforced recently.
Args:
project_id: Optional project to apply decay to
decay_factor: Decay factor per day (default 0.01)
Returns:
Number of facts affected
"""
now = datetime.now(UTC)
decay_days = self._settings.semantic_confidence_decay_days
min_conf = self._settings.semantic_min_confidence
        # Calculate cutoff date
cutoff = now - timedelta(days=decay_days)
# Find facts needing decay
query = select(FactModel).where(
and_(
FactModel.last_reinforced < cutoff,
FactModel.confidence > min_conf,
)
)
if project_id is not None:
query = query.where(FactModel.project_id == project_id)
result = await self._session.execute(query)
models = list(result.scalars().all())
# Apply decay
updated_count = 0
for model in models:
# Calculate days since last reinforcement
days_since: float = (now - model.last_reinforced).days
            # Linear decay proportional to days beyond the decay window
decay = decay_factor * (days_since - decay_days)
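            # e.g. decay_factor=0.01, days_since=40, decay_days=30 -> decay of 0.10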
new_confidence = max(min_conf, model.confidence - decay)
if new_confidence != model.confidence:
await self._session.execute(
update(FactModel)
.where(FactModel.id == model.id)
.values(confidence=new_confidence, updated_at=now)
)
updated_count += 1
await self._session.flush()
logger.info(f"Applied confidence decay to {updated_count} facts")
return updated_count
# =========================================================================
# Statistics
# =========================================================================
async def get_stats(self, project_id: UUID | None = None) -> dict[str, Any]:
"""
Get statistics about semantic memory.
Args:
project_id: Optional project to get stats for
Returns:
Dictionary with statistics
"""
# Get all facts for this scope
query = select(FactModel)
if project_id is not None:
query = query.where(
or_(
FactModel.project_id == project_id,
FactModel.project_id.is_(None),
)
)
result = await self._session.execute(query)
models = list(result.scalars().all())
if not models:
return {
"total_facts": 0,
"avg_confidence": 0.0,
"avg_reinforcement_count": 0.0,
"high_confidence_count": 0,
"low_confidence_count": 0,
}
confidences = [m.confidence for m in models]
reinforcements = [m.reinforcement_count for m in models]
return {
"total_facts": len(models),
"avg_confidence": sum(confidences) / len(confidences),
"avg_reinforcement_count": sum(reinforcements) / len(reinforcements),
"high_confidence_count": sum(1 for c in confidences if c >= 0.8),
"low_confidence_count": sum(1 for c in confidences if c < 0.5),
}
async def count(self, project_id: UUID | None = None) -> int:
"""
Count facts in scope.
Args:
project_id: Optional project to count for
Returns:
Number of facts
"""
query = select(FactModel)
if project_id is not None:
query = query.where(
or_(
FactModel.project_id == project_id,
FactModel.project_id.is_(None),
)
)
result = await self._session.execute(query)
return len(list(result.scalars().all()))
async def delete(self, fact_id: UUID) -> bool:
"""
Delete a fact.
Args:
fact_id: Fact to delete
Returns:
True if deleted, False if not found
"""
query = select(FactModel).where(FactModel.id == fact_id)
result = await self._session.execute(query)
model = result.scalar_one_or_none()
if model is None:
return False
await self._session.delete(model)
await self._session.flush()
logger.info(f"Deleted fact {fact_id}")
return True
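
A sketch of the maintenance paths above (resolve_conflict and apply_confidence_decay); the `maintain` coroutine and the list of conflicting IDs are illustrative:

# illustrative sketch, not part of this diff
from uuid import UUID

from sqlalchemy.ext.asyncio import AsyncSession

from app.services.memory.semantic.memory import SemanticMemory


async def maintain(session: AsyncSession, conflicting_ids: list[UUID]) -> None:
    """Resolve a known conflict, then decay stale facts."""
    memory = SemanticMemory(session=session)
    # Keeps the highest-confidence fact; the others are deprecated to 0.0.
    winner = await memory.resolve_conflict(conflicting_ids)
    if winner is not None:
        print(f"kept {winner.id}")
    # Linear decay for facts not reinforced within the configured window.
    touched = await memory.apply_confidence_decay(decay_factor=0.01)
    print(f"decayed {touched} facts")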


@@ -0,0 +1,363 @@
# app/services/memory/semantic/verification.py
"""
Fact Verification.
Provides utilities for verifying facts, detecting conflicts,
and managing fact consistency.
"""
import logging
from dataclasses import dataclass, field
from typing import Any, ClassVar
from uuid import UUID
from sqlalchemy import and_, or_, select
from sqlalchemy.ext.asyncio import AsyncSession
from app.models.memory.fact import Fact as FactModel
from app.services.memory.types import Fact
logger = logging.getLogger(__name__)
@dataclass
class VerificationResult:
"""Result of fact verification."""
is_valid: bool
confidence_adjustment: float = 0.0
conflicts: list["FactConflict"] = field(default_factory=list)
supporting_facts: list[Fact] = field(default_factory=list)
messages: list[str] = field(default_factory=list)
@dataclass
class FactConflict:
"""Represents a conflict between two facts."""
fact_a_id: UUID
fact_b_id: UUID
conflict_type: str # "contradiction", "superseded", "partial_overlap"
description: str
suggested_resolution: str | None = None
def to_dict(self) -> dict[str, Any]:
"""Convert to dictionary."""
return {
"fact_a_id": str(self.fact_a_id),
"fact_b_id": str(self.fact_b_id),
"conflict_type": self.conflict_type,
"description": self.description,
"suggested_resolution": self.suggested_resolution,
}
class FactVerifier:
"""
Verifies facts and detects conflicts.
Provides methods to:
- Check if a fact conflicts with existing facts
- Find supporting evidence for a fact
- Detect contradictions in the fact base
"""
# Predicates that are opposites/contradictions
CONTRADICTORY_PREDICATES: ClassVar[set[tuple[str, str]]] = {
("uses", "does_not_use"),
("requires", "does_not_require"),
("is_a", "is_not_a"),
("causes", "prevents"),
("allows", "prevents"),
("supports", "does_not_support"),
("best_practice", "anti_pattern"),
}
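    # Pairs are unordered: _get_opposite_predicates checks both positions,
    # so ("causes", "prevents") flags contradictions in either direction.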
def __init__(self, session: AsyncSession) -> None:
"""Initialize verifier with database session."""
self._session = session
async def verify_fact(
self,
subject: str,
predicate: str,
obj: str,
project_id: UUID | None = None,
) -> VerificationResult:
"""
Verify a fact against existing facts.
Args:
subject: Fact subject
predicate: Fact predicate
obj: Fact object
project_id: Optional project scope
Returns:
VerificationResult with verification details
"""
result = VerificationResult(is_valid=True)
# Check for direct contradictions
conflicts = await self._find_contradictions(
subject=subject,
predicate=predicate,
obj=obj,
project_id=project_id,
)
result.conflicts = conflicts
if conflicts:
result.is_valid = False
result.messages.append(f"Found {len(conflicts)} conflicting fact(s)")
# Reduce confidence based on conflicts
result.confidence_adjustment = -0.1 * len(conflicts)
# Find supporting facts
supporting = await self._find_supporting_facts(
subject=subject,
predicate=predicate,
project_id=project_id,
)
result.supporting_facts = supporting
if supporting:
result.messages.append(f"Found {len(supporting)} supporting fact(s)")
# Boost confidence based on support
result.confidence_adjustment += 0.05 * min(len(supporting), 3)
return result
async def _find_contradictions(
self,
subject: str,
predicate: str,
obj: str,
project_id: UUID | None = None,
) -> list[FactConflict]:
"""Find facts that contradict the given fact."""
conflicts: list[FactConflict] = []
# Find opposite predicates
opposite_predicates = self._get_opposite_predicates(predicate)
if not opposite_predicates:
return conflicts
# Search for contradicting facts
query = select(FactModel).where(
and_(
FactModel.subject == subject,
FactModel.predicate.in_(opposite_predicates),
)
)
if project_id is not None:
query = query.where(
or_(
FactModel.project_id == project_id,
FactModel.project_id.is_(None),
)
)
result = await self._session.execute(query)
models = list(result.scalars().all())
for model in models:
conflicts.append(
FactConflict(
fact_a_id=model.id, # type: ignore[arg-type]
fact_b_id=UUID(
"00000000-0000-0000-0000-000000000000"
), # Placeholder for new fact
conflict_type="contradiction",
description=(
f"'{subject} {predicate} {obj}' contradicts "
f"'{model.subject} {model.predicate} {model.object}'"
),
suggested_resolution="Keep fact with higher confidence",
)
)
return conflicts
def _get_opposite_predicates(self, predicate: str) -> list[str]:
"""Get predicates that are opposite to the given predicate."""
opposites: list[str] = []
for pair in self.CONTRADICTORY_PREDICATES:
if predicate in pair:
opposites.extend(p for p in pair if p != predicate)
return opposites
async def _find_supporting_facts(
self,
subject: str,
predicate: str,
project_id: UUID | None = None,
) -> list[Fact]:
"""Find facts that support the given fact."""
# Find facts with same subject and predicate
query = (
select(FactModel)
.where(
and_(
FactModel.subject == subject,
FactModel.predicate == predicate,
FactModel.confidence >= 0.5,
)
)
.limit(10)
)
if project_id is not None:
query = query.where(
or_(
FactModel.project_id == project_id,
FactModel.project_id.is_(None),
)
)
result = await self._session.execute(query)
models = list(result.scalars().all())
return [self._model_to_fact(m) for m in models]
async def find_all_conflicts(
self,
project_id: UUID | None = None,
) -> list[FactConflict]:
"""
Find all conflicts in the fact base.
Args:
project_id: Optional project scope
Returns:
List of all detected conflicts
"""
conflicts: list[FactConflict] = []
# Get all facts
query = select(FactModel)
if project_id is not None:
query = query.where(
or_(
FactModel.project_id == project_id,
FactModel.project_id.is_(None),
)
)
result = await self._session.execute(query)
models = list(result.scalars().all())
        # Pairwise O(n^2) scan over all facts; acceptable while the fact base is small
for i, fact_a in enumerate(models):
for fact_b in models[i + 1 :]:
conflict = self._check_pair_conflict(fact_a, fact_b)
if conflict:
conflicts.append(conflict)
logger.info(f"Found {len(conflicts)} conflicts in fact base")
return conflicts
def _check_pair_conflict(
self,
fact_a: FactModel,
fact_b: FactModel,
) -> FactConflict | None:
"""Check if two facts conflict."""
# Same subject?
if fact_a.subject != fact_b.subject:
return None
# Contradictory predicates?
opposite = self._get_opposite_predicates(fact_a.predicate) # type: ignore[arg-type]
if fact_b.predicate not in opposite:
return None
return FactConflict(
fact_a_id=fact_a.id, # type: ignore[arg-type]
fact_b_id=fact_b.id, # type: ignore[arg-type]
conflict_type="contradiction",
description=(
f"'{fact_a.subject} {fact_a.predicate} {fact_a.object}' "
f"contradicts '{fact_b.subject} {fact_b.predicate} {fact_b.object}'"
),
suggested_resolution="Deprecate fact with lower confidence",
)
async def get_fact_reliability_score(
self,
fact_id: UUID,
) -> float:
"""
Calculate a reliability score for a fact.
Based on:
- Confidence score
- Number of reinforcements
- Number of supporting facts
- Absence of conflicts
Args:
fact_id: Fact to score
Returns:
Reliability score (0.0 to 1.0)
"""
query = select(FactModel).where(FactModel.id == fact_id)
result = await self._session.execute(query)
model = result.scalar_one_or_none()
if model is None:
return 0.0
# Base score from confidence - explicitly typed to avoid Column type issues
score: float = float(model.confidence)
# Boost for reinforcements (diminishing returns)
reinforcement_boost = min(0.2, float(model.reinforcement_count) * 0.02)
score += reinforcement_boost
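        # e.g. confidence 0.8 with 5 reinforcements -> boost 0.10 (score so far 0.90)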
# Find supporting facts
supporting = await self._find_supporting_facts(
subject=model.subject, # type: ignore[arg-type]
predicate=model.predicate, # type: ignore[arg-type]
project_id=model.project_id, # type: ignore[arg-type]
)
support_boost = min(0.1, len(supporting) * 0.02)
score += support_boost
# Check for conflicts
conflicts = await self._find_contradictions(
subject=model.subject, # type: ignore[arg-type]
predicate=model.predicate, # type: ignore[arg-type]
obj=model.object, # type: ignore[arg-type]
project_id=model.project_id, # type: ignore[arg-type]
)
conflict_penalty = min(0.3, len(conflicts) * 0.1)
score -= conflict_penalty
# Clamp to valid range
return max(0.0, min(1.0, score))
def _model_to_fact(self, model: FactModel) -> Fact:
"""Convert SQLAlchemy model to Fact dataclass."""
return Fact(
id=model.id, # type: ignore[arg-type]
project_id=model.project_id, # type: ignore[arg-type]
subject=model.subject, # type: ignore[arg-type]
predicate=model.predicate, # type: ignore[arg-type]
object=model.object, # type: ignore[arg-type]
confidence=model.confidence, # type: ignore[arg-type]
source_episode_ids=model.source_episode_ids or [], # type: ignore[arg-type]
first_learned=model.first_learned, # type: ignore[arg-type]
last_reinforced=model.last_reinforced, # type: ignore[arg-type]
reinforcement_count=model.reinforcement_count, # type: ignore[arg-type]
embedding=None,
created_at=model.created_at, # type: ignore[arg-type]
updated_at=model.updated_at, # type: ignore[arg-type]
)
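
A sketch of the verification flow; the `check` coroutine, session wiring, and candidate triple are illustrative:

# illustrative sketch, not part of this diff
from sqlalchemy.ext.asyncio import AsyncSession

from app.services.memory.semantic.verification import FactVerifier


async def check(session: AsyncSession) -> None:
    """Verify a candidate fact before storing it."""
    verifier = FactVerifier(session=session)
    result = await verifier.verify_fact(
        subject="Python",
        predicate="uses",
        obj="static typing",
    )
    if not result.is_valid:
        # Each conflict names the stored fact it contradicts and a suggested resolution.
        for conflict in result.conflicts:
            print(conflict.to_dict())
    # confidence_adjustment can be folded into the candidate's confidence
    # before calling SemanticMemory.store_fact.
    print(result.confidence_adjustment)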


@@ -0,0 +1,2 @@
# tests/unit/services/memory/semantic/__init__.py
"""Unit tests for semantic memory service."""


@@ -0,0 +1,263 @@
# tests/unit/services/memory/semantic/test_extraction.py
"""Unit tests for fact extraction."""
from datetime import UTC, datetime
from uuid import uuid4
import pytest
from app.services.memory.semantic.extraction import (
ExtractedFact,
ExtractionContext,
FactExtractor,
get_fact_extractor,
)
from app.services.memory.types import Episode, Outcome
def create_test_episode(
lessons_learned: list[str] | None = None,
outcome: Outcome = Outcome.SUCCESS,
task_type: str = "code_review",
task_description: str = "Review the authentication module",
outcome_details: str = "",
) -> Episode:
"""Create a test episode for extraction tests."""
return Episode(
id=uuid4(),
project_id=uuid4(),
agent_instance_id=None,
agent_type_id=None,
session_id="test-session",
task_type=task_type,
task_description=task_description,
actions=[],
context_summary="Test context",
outcome=outcome,
outcome_details=outcome_details,
duration_seconds=60.0,
tokens_used=500,
lessons_learned=lessons_learned or [],
importance_score=0.7,
embedding=None,
occurred_at=datetime.now(UTC),
created_at=datetime.now(UTC),
updated_at=datetime.now(UTC),
)
class TestExtractedFact:
"""Tests for ExtractedFact dataclass."""
def test_to_fact_create(self) -> None:
"""Test converting ExtractedFact to FactCreate."""
extracted = ExtractedFact(
subject="Python",
predicate="uses",
object="dynamic typing",
confidence=0.8,
)
fact_create = extracted.to_fact_create(
project_id=uuid4(),
source_episode_ids=[uuid4()],
)
assert fact_create.subject == "Python"
assert fact_create.predicate == "uses"
assert fact_create.object == "dynamic typing"
assert fact_create.confidence == 0.8
def test_to_fact_create_defaults(self) -> None:
"""Test to_fact_create with default values."""
extracted = ExtractedFact(
subject="A",
predicate="B",
object="C",
confidence=0.5,
)
fact_create = extracted.to_fact_create()
assert fact_create.project_id is None
assert fact_create.source_episode_ids == []
class TestFactExtractor:
"""Tests for FactExtractor class."""
@pytest.fixture
def extractor(self) -> FactExtractor:
"""Create a fact extractor."""
return FactExtractor()
def test_extract_from_episode_with_lessons(
self,
extractor: FactExtractor,
) -> None:
"""Test extracting facts from episode with lessons."""
episode = create_test_episode(
lessons_learned=[
"Always validate user input before processing",
"Use parameterized queries to prevent SQL injection",
]
)
facts = extractor.extract_from_episode(episode)
assert len(facts) > 0
# Should have lesson_learned predicates
lesson_facts = [f for f in facts if f.predicate == "lesson_learned"]
assert len(lesson_facts) >= 2
def test_extract_from_episode_with_always_pattern(
self,
extractor: FactExtractor,
) -> None:
"""Test extracting 'always' pattern lessons."""
episode = create_test_episode(
lessons_learned=["Always close file handles properly"]
)
facts = extractor.extract_from_episode(episode)
best_practices = [f for f in facts if f.predicate == "best_practice"]
assert len(best_practices) >= 1
assert any("close file handles" in f.object for f in best_practices)
def test_extract_from_episode_with_never_pattern(
self,
extractor: FactExtractor,
) -> None:
"""Test extracting 'never' pattern lessons."""
episode = create_test_episode(
lessons_learned=["Never store passwords in plain text"]
)
facts = extractor.extract_from_episode(episode)
anti_patterns = [f for f in facts if f.predicate == "anti_pattern"]
assert len(anti_patterns) >= 1
def test_extract_from_episode_with_conditional_pattern(
self,
extractor: FactExtractor,
) -> None:
"""Test extracting conditional lessons."""
episode = create_test_episode(
lessons_learned=["When handling errors, log the stack trace"]
)
facts = extractor.extract_from_episode(episode)
conditional = [f for f in facts if f.predicate == "requires_action"]
assert len(conditional) >= 1
def test_extract_outcome_facts_success(
self,
extractor: FactExtractor,
) -> None:
"""Test extracting facts from successful episode."""
episode = create_test_episode(
outcome=Outcome.SUCCESS,
outcome_details="Deployed to production without issues",
)
facts = extractor.extract_from_episode(episode)
success_facts = [f for f in facts if f.predicate == "successful_approach"]
assert len(success_facts) >= 1
def test_extract_outcome_facts_failure(
self,
extractor: FactExtractor,
) -> None:
"""Test extracting facts from failed episode."""
episode = create_test_episode(
outcome=Outcome.FAILURE,
outcome_details="Connection timeout during deployment",
)
facts = extractor.extract_from_episode(episode)
failure_facts = [f for f in facts if f.predicate == "known_failure_mode"]
assert len(failure_facts) >= 1
def test_extract_from_text_uses_pattern(
self,
extractor: FactExtractor,
) -> None:
"""Test extracting 'uses' pattern from text."""
text = "FastAPI uses Starlette for ASGI support."
facts = extractor.extract_from_text(text)
assert len(facts) >= 1
uses_facts = [f for f in facts if f.predicate == "uses"]
assert len(uses_facts) >= 1
def test_extract_from_text_requires_pattern(
self,
extractor: FactExtractor,
) -> None:
"""Test extracting 'requires' pattern from text."""
text = "This feature requires Python 3.10 or higher."
facts = extractor.extract_from_text(text)
requires_facts = [f for f in facts if f.predicate == "requires"]
assert len(requires_facts) >= 1
def test_extract_from_text_empty(
self,
extractor: FactExtractor,
) -> None:
"""Test extracting from empty text."""
facts = extractor.extract_from_text("")
assert facts == []
def test_extract_from_text_short(
self,
extractor: FactExtractor,
) -> None:
"""Test extracting from too-short text."""
facts = extractor.extract_from_text("Hi.")
assert facts == []
def test_extract_with_context(
self,
extractor: FactExtractor,
) -> None:
"""Test extraction with custom context."""
episode = create_test_episode(lessons_learned=["Low confidence lesson"])
context = ExtractionContext(
min_confidence=0.9, # High threshold
max_facts_per_source=2,
)
facts = extractor.extract_from_episode(episode, context)
        # Both the confidence filter and the max_facts cap should apply
        assert len(facts) <= 2
        for fact in facts:
            assert fact.confidence >= 0.9
class TestGetFactExtractor:
"""Tests for singleton getter."""
def test_get_fact_extractor_returns_instance(self) -> None:
"""Test that get_fact_extractor returns an instance."""
extractor = get_fact_extractor()
assert extractor is not None
assert isinstance(extractor, FactExtractor)
def test_get_fact_extractor_returns_same_instance(self) -> None:
"""Test that get_fact_extractor returns singleton."""
extractor1 = get_fact_extractor()
extractor2 = get_fact_extractor()
assert extractor1 is extractor2


@@ -0,0 +1,446 @@
# tests/unit/services/memory/semantic/test_memory.py
"""Unit tests for SemanticMemory class."""
from datetime import UTC, datetime
from unittest.mock import AsyncMock, MagicMock
from uuid import uuid4
import pytest
from app.services.memory.semantic.memory import SemanticMemory
from app.services.memory.types import FactCreate
def create_mock_fact_model(
project_id=None,
subject="FastAPI",
predicate="uses",
obj="Starlette",
confidence=0.8,
) -> MagicMock:
"""Create a mock fact model for testing."""
mock = MagicMock()
mock.id = uuid4()
mock.project_id = project_id
mock.subject = subject
mock.predicate = predicate
mock.object = obj
mock.confidence = confidence
mock.source_episode_ids = []
mock.first_learned = datetime.now(UTC)
mock.last_reinforced = datetime.now(UTC)
mock.reinforcement_count = 1
mock.embedding = None
mock.created_at = datetime.now(UTC)
mock.updated_at = datetime.now(UTC)
return mock
class TestSemanticMemoryInit:
"""Tests for SemanticMemory initialization."""
def test_init_creates_memory(self) -> None:
"""Test that init creates memory instance."""
mock_session = AsyncMock()
memory = SemanticMemory(session=mock_session)
assert memory._session is mock_session
def test_init_with_embedding_generator(self) -> None:
"""Test init with embedding generator."""
mock_session = AsyncMock()
mock_embedding_gen = AsyncMock()
memory = SemanticMemory(
session=mock_session, embedding_generator=mock_embedding_gen
)
assert memory._embedding_generator is mock_embedding_gen
@pytest.mark.asyncio
async def test_create_factory_method(self) -> None:
"""Test create factory method."""
mock_session = AsyncMock()
memory = await SemanticMemory.create(session=mock_session)
assert memory is not None
assert memory._session is mock_session
class TestSemanticMemoryStoreFact:
"""Tests for fact storage methods."""
@pytest.fixture
def mock_session(self) -> AsyncMock:
"""Create a mock database session."""
session = AsyncMock()
session.add = MagicMock()
session.flush = AsyncMock()
session.refresh = AsyncMock()
# Mock no existing fact found
mock_result = MagicMock()
mock_result.scalar_one_or_none.return_value = None
session.execute.return_value = mock_result
return session
@pytest.fixture
def memory(self, mock_session: AsyncMock) -> SemanticMemory:
"""Create a SemanticMemory instance."""
return SemanticMemory(session=mock_session)
@pytest.mark.asyncio
async def test_store_new_fact(
self,
memory: SemanticMemory,
mock_session: AsyncMock,
) -> None:
"""Test storing a new fact."""
fact_data = FactCreate(
subject="Python",
predicate="is_a",
object="programming language",
confidence=0.9,
project_id=uuid4(),
)
result = await memory.store_fact(fact_data)
assert result.subject == "Python"
assert result.predicate == "is_a"
assert result.object == "programming language"
mock_session.add.assert_called_once()
mock_session.flush.assert_called_once()
@pytest.mark.asyncio
async def test_store_fact_reinforces_existing(
self,
memory: SemanticMemory,
mock_session: AsyncMock,
) -> None:
"""Test that storing duplicate fact reinforces existing."""
# Mock existing fact found - needs to be found first
existing_mock = create_mock_fact_model(confidence=0.7)
find_result = MagicMock()
find_result.scalar_one_or_none.return_value = existing_mock
# Second find for reinforce_fact
find_for_reinforce = MagicMock()
find_for_reinforce.scalar_one_or_none.return_value = existing_mock
# Mock update result - returns the updated mock
updated_mock = create_mock_fact_model(confidence=0.8)
update_result = MagicMock()
update_result.scalar_one.return_value = updated_mock
mock_session.execute.side_effect = [
find_result, # _find_existing_fact
find_for_reinforce, # reinforce_fact query
update_result, # reinforce_fact update
]
fact_data = FactCreate(
subject="FastAPI",
predicate="uses",
object="Starlette",
)
_ = await memory.store_fact(fact_data)
# Should have called execute three times (find + find + update)
assert mock_session.execute.call_count == 3
class TestSemanticMemorySearch:
"""Tests for fact search methods."""
@pytest.fixture
def mock_session(self) -> AsyncMock:
"""Create a mock database session."""
session = AsyncMock()
mock_result = MagicMock()
mock_result.scalars.return_value.all.return_value = []
session.execute.return_value = mock_result
return session
@pytest.fixture
def memory(self, mock_session: AsyncMock) -> SemanticMemory:
"""Create a SemanticMemory instance."""
return SemanticMemory(session=mock_session)
@pytest.mark.asyncio
async def test_search_facts(
self,
memory: SemanticMemory,
) -> None:
"""Test searching for facts."""
results = await memory.search_facts("Python programming")
assert isinstance(results, list)
@pytest.mark.asyncio
async def test_search_facts_with_project_filter(
self,
memory: SemanticMemory,
) -> None:
"""Test searching for facts with project filter."""
project_id = uuid4()
results = await memory.search_facts("Python", project_id=project_id)
assert isinstance(results, list)
@pytest.mark.asyncio
async def test_get_by_entity(
self,
memory: SemanticMemory,
) -> None:
"""Test getting facts by entity."""
results = await memory.get_by_entity("FastAPI")
assert isinstance(results, list)
@pytest.mark.asyncio
async def test_get_by_subject(
self,
memory: SemanticMemory,
) -> None:
"""Test getting facts by subject."""
results = await memory.get_by_subject("Python")
assert isinstance(results, list)
@pytest.mark.asyncio
async def test_get_by_id_not_found(
self,
memory: SemanticMemory,
mock_session: AsyncMock,
) -> None:
"""Test get_by_id returns None when not found."""
mock_result = MagicMock()
mock_result.scalar_one_or_none.return_value = None
mock_session.execute.return_value = mock_result
result = await memory.get_by_id(uuid4())
assert result is None
class TestSemanticMemoryReinforcement:
"""Tests for fact reinforcement."""
@pytest.fixture
def mock_session(self) -> AsyncMock:
"""Create a mock database session."""
session = AsyncMock()
return session
@pytest.fixture
def memory(self, mock_session: AsyncMock) -> SemanticMemory:
"""Create a SemanticMemory instance."""
return SemanticMemory(session=mock_session)
@pytest.mark.asyncio
async def test_reinforce_fact(
self,
memory: SemanticMemory,
mock_session: AsyncMock,
) -> None:
"""Test reinforcing a fact."""
existing_mock = create_mock_fact_model(confidence=0.7)
# First query: find fact
find_result = MagicMock()
find_result.scalar_one_or_none.return_value = existing_mock
# Second query: update fact
updated_mock = create_mock_fact_model(confidence=0.8)
update_result = MagicMock()
update_result.scalar_one.return_value = updated_mock
mock_session.execute.side_effect = [find_result, update_result]
result = await memory.reinforce_fact(existing_mock.id, confidence_boost=0.1)
assert result.confidence == 0.8
@pytest.mark.asyncio
async def test_reinforce_fact_not_found(
self,
memory: SemanticMemory,
mock_session: AsyncMock,
) -> None:
"""Test reinforcing a non-existent fact raises error."""
mock_result = MagicMock()
mock_result.scalar_one_or_none.return_value = None
mock_session.execute.return_value = mock_result
with pytest.raises(ValueError, match="Fact not found"):
await memory.reinforce_fact(uuid4())
@pytest.mark.asyncio
async def test_deprecate_fact(
self,
memory: SemanticMemory,
mock_session: AsyncMock,
) -> None:
"""Test deprecating a fact."""
existing_mock = create_mock_fact_model(confidence=0.8)
find_result = MagicMock()
find_result.scalar_one_or_none.return_value = existing_mock
deprecated_mock = create_mock_fact_model(confidence=0.0)
update_result = MagicMock()
update_result.scalar_one_or_none.return_value = deprecated_mock
mock_session.execute.side_effect = [find_result, update_result]
result = await memory.deprecate_fact(existing_mock.id, reason="Outdated")
assert result is not None
assert result.confidence == 0.0
@pytest.mark.asyncio
async def test_deprecate_fact_not_found(
self,
memory: SemanticMemory,
mock_session: AsyncMock,
) -> None:
"""Test deprecating non-existent fact returns None."""
mock_result = MagicMock()
mock_result.scalar_one_or_none.return_value = None
mock_session.execute.return_value = mock_result
result = await memory.deprecate_fact(uuid4(), reason="Test")
assert result is None
class TestSemanticMemoryConflictResolution:
"""Tests for conflict resolution."""
@pytest.fixture
def mock_session(self) -> AsyncMock:
"""Create a mock database session."""
session = AsyncMock()
return session
@pytest.fixture
def memory(self, mock_session: AsyncMock) -> SemanticMemory:
"""Create a SemanticMemory instance."""
return SemanticMemory(session=mock_session)
@pytest.mark.asyncio
async def test_resolve_conflict_empty_list(
self,
memory: SemanticMemory,
) -> None:
"""Test resolving conflict with empty list."""
result = await memory.resolve_conflict([])
assert result is None
@pytest.mark.asyncio
async def test_resolve_conflict_keeps_highest_confidence(
self,
memory: SemanticMemory,
mock_session: AsyncMock,
) -> None:
"""Test that conflict resolution keeps highest confidence fact."""
fact_low = create_mock_fact_model(confidence=0.5)
fact_high = create_mock_fact_model(confidence=0.9)
# Mock finding the facts
find_result = MagicMock()
find_result.scalars.return_value.all.return_value = [fact_low, fact_high]
# Mock deprecation (find + update)
find_one_result = MagicMock()
find_one_result.scalar_one_or_none.return_value = fact_low
update_result = MagicMock()
update_result.scalar_one_or_none.return_value = fact_low
mock_session.execute.side_effect = [find_result, find_one_result, update_result]
result = await memory.resolve_conflict([fact_low.id, fact_high.id])
assert result is not None
assert result.confidence == 0.9
class TestSemanticMemoryStats:
"""Tests for statistics methods."""
@pytest.fixture
def mock_session(self) -> AsyncMock:
"""Create a mock database session."""
session = AsyncMock()
return session
@pytest.fixture
def memory(self, mock_session: AsyncMock) -> SemanticMemory:
"""Create a SemanticMemory instance."""
return SemanticMemory(session=mock_session)
@pytest.mark.asyncio
async def test_get_stats_empty(
self,
memory: SemanticMemory,
mock_session: AsyncMock,
) -> None:
"""Test getting stats for empty project."""
mock_result = MagicMock()
mock_result.scalars.return_value.all.return_value = []
mock_session.execute.return_value = mock_result
stats = await memory.get_stats(uuid4())
assert stats["total_facts"] == 0
assert stats["avg_confidence"] == 0.0
@pytest.mark.asyncio
async def test_count(
self,
memory: SemanticMemory,
mock_session: AsyncMock,
) -> None:
"""Test counting facts."""
facts = [create_mock_fact_model() for _ in range(5)]
mock_result = MagicMock()
mock_result.scalars.return_value.all.return_value = facts
mock_session.execute.return_value = mock_result
count = await memory.count(uuid4())
assert count == 5
@pytest.mark.asyncio
async def test_delete(
self,
memory: SemanticMemory,
mock_session: AsyncMock,
) -> None:
"""Test deleting a fact."""
existing_mock = create_mock_fact_model()
mock_result = MagicMock()
mock_result.scalar_one_or_none.return_value = existing_mock
mock_session.execute.return_value = mock_result
mock_session.delete = AsyncMock()
result = await memory.delete(existing_mock.id)
assert result is True
mock_session.delete.assert_called_once()
@pytest.mark.asyncio
async def test_delete_not_found(
self,
memory: SemanticMemory,
mock_session: AsyncMock,
) -> None:
"""Test deleting non-existent fact."""
mock_result = MagicMock()
mock_result.scalar_one_or_none.return_value = None
mock_session.execute.return_value = mock_result
result = await memory.delete(uuid4())
assert result is False


@@ -0,0 +1,298 @@
# tests/unit/services/memory/semantic/test_verification.py
"""Unit tests for fact verification."""
from datetime import UTC, datetime
from unittest.mock import AsyncMock, MagicMock
from uuid import uuid4
import pytest
from app.services.memory.semantic.verification import (
FactConflict,
FactVerifier,
VerificationResult,
)
def create_mock_fact_model(
subject="FastAPI",
predicate="uses",
obj="Starlette",
confidence=0.8,
project_id=None,
) -> MagicMock:
"""Create a mock fact model for testing."""
mock = MagicMock()
mock.id = uuid4()
mock.project_id = project_id
mock.subject = subject
mock.predicate = predicate
mock.object = obj
mock.confidence = confidence
mock.source_episode_ids = []
mock.first_learned = datetime.now(UTC)
mock.last_reinforced = datetime.now(UTC)
mock.reinforcement_count = 1
mock.embedding = None
mock.created_at = datetime.now(UTC)
mock.updated_at = datetime.now(UTC)
return mock
class TestFactConflict:
"""Tests for FactConflict dataclass."""
def test_to_dict(self) -> None:
"""Test converting conflict to dictionary."""
conflict = FactConflict(
fact_a_id=uuid4(),
fact_b_id=uuid4(),
conflict_type="contradiction",
description="Test conflict",
suggested_resolution="Keep higher confidence",
)
result = conflict.to_dict()
assert "fact_a_id" in result
assert "fact_b_id" in result
assert result["conflict_type"] == "contradiction"
assert result["description"] == "Test conflict"
class TestVerificationResult:
"""Tests for VerificationResult dataclass."""
def test_default_values(self) -> None:
"""Test default values."""
result = VerificationResult(is_valid=True)
assert result.is_valid is True
assert result.confidence_adjustment == 0.0
assert result.conflicts == []
assert result.supporting_facts == []
assert result.messages == []
class TestFactVerifier:
"""Tests for FactVerifier class."""
@pytest.fixture
def mock_session(self) -> AsyncMock:
"""Create a mock database session."""
session = AsyncMock()
mock_result = MagicMock()
mock_result.scalars.return_value.all.return_value = []
session.execute.return_value = mock_result
return session
@pytest.fixture
def verifier(self, mock_session: AsyncMock) -> FactVerifier:
"""Create a fact verifier."""
return FactVerifier(session=mock_session)
@pytest.mark.asyncio
async def test_verify_fact_valid(
self,
verifier: FactVerifier,
) -> None:
"""Test verifying a valid fact with no conflicts."""
result = await verifier.verify_fact(
subject="Python",
predicate="is_a",
obj="programming language",
)
assert result.is_valid is True
assert len(result.conflicts) == 0
@pytest.mark.asyncio
async def test_verify_fact_with_support(
self,
verifier: FactVerifier,
mock_session: AsyncMock,
) -> None:
"""Test verifying a fact with supporting facts."""
# Mock finding supporting facts
supporting = [create_mock_fact_model()]
# First query: contradictions (empty)
contradiction_result = MagicMock()
contradiction_result.scalars.return_value.all.return_value = []
# Second query: supporting facts
support_result = MagicMock()
support_result.scalars.return_value.all.return_value = supporting
mock_session.execute.side_effect = [contradiction_result, support_result]
result = await verifier.verify_fact(
subject="Python",
predicate="uses",
obj="dynamic typing",
)
assert result.is_valid is True
assert len(result.supporting_facts) >= 1
assert result.confidence_adjustment > 0
@pytest.mark.asyncio
async def test_verify_fact_with_contradiction(
self,
verifier: FactVerifier,
mock_session: AsyncMock,
) -> None:
"""Test verifying a fact with contradictions."""
# Mock finding contradicting fact
contradicting = create_mock_fact_model(
subject="Python",
predicate="does_not_use",
obj="static typing",
)
contradiction_result = MagicMock()
contradiction_result.scalars.return_value.all.return_value = [contradicting]
support_result = MagicMock()
support_result.scalars.return_value.all.return_value = []
mock_session.execute.side_effect = [contradiction_result, support_result]
result = await verifier.verify_fact(
subject="Python",
predicate="uses",
obj="static typing",
)
assert result.is_valid is False
assert len(result.conflicts) >= 1
assert result.confidence_adjustment < 0
def test_get_opposite_predicates(
self,
verifier: FactVerifier,
) -> None:
"""Test getting opposite predicates."""
opposites = verifier._get_opposite_predicates("uses")
assert "does_not_use" in opposites
def test_get_opposite_predicates_unknown(
self,
verifier: FactVerifier,
) -> None:
"""Test getting opposites for unknown predicate."""
opposites = verifier._get_opposite_predicates("unknown_predicate")
assert opposites == []
@pytest.mark.asyncio
async def test_find_all_conflicts_empty(
self,
verifier: FactVerifier,
mock_session: AsyncMock,
) -> None:
"""Test finding all conflicts in empty fact base."""
mock_result = MagicMock()
mock_result.scalars.return_value.all.return_value = []
mock_session.execute.return_value = mock_result
conflicts = await verifier.find_all_conflicts()
assert conflicts == []
@pytest.mark.asyncio
async def test_find_all_conflicts_no_conflicts(
self,
verifier: FactVerifier,
mock_session: AsyncMock,
) -> None:
"""Test finding conflicts when there are none."""
# Two facts with different subjects
fact1 = create_mock_fact_model(subject="Python", predicate="uses")
fact2 = create_mock_fact_model(subject="JavaScript", predicate="uses")
mock_result = MagicMock()
mock_result.scalars.return_value.all.return_value = [fact1, fact2]
mock_session.execute.return_value = mock_result
conflicts = await verifier.find_all_conflicts()
assert conflicts == []
@pytest.mark.asyncio
async def test_find_all_conflicts_with_contradiction(
self,
verifier: FactVerifier,
mock_session: AsyncMock,
) -> None:
"""Test finding contradicting facts."""
# Two contradicting facts
fact1 = create_mock_fact_model(
subject="Python",
predicate="best_practice",
obj="Use type hints",
)
fact2 = create_mock_fact_model(
subject="Python",
predicate="anti_pattern",
obj="Use type hints",
)
mock_result = MagicMock()
mock_result.scalars.return_value.all.return_value = [fact1, fact2]
mock_session.execute.return_value = mock_result
conflicts = await verifier.find_all_conflicts()
assert len(conflicts) == 1
assert conflicts[0].conflict_type == "contradiction"
@pytest.mark.asyncio
async def test_get_fact_reliability_score_not_found(
self,
verifier: FactVerifier,
mock_session: AsyncMock,
) -> None:
"""Test reliability score for non-existent fact."""
mock_result = MagicMock()
mock_result.scalar_one_or_none.return_value = None
mock_session.execute.return_value = mock_result
score = await verifier.get_fact_reliability_score(uuid4())
assert score == 0.0
@pytest.mark.asyncio
async def test_get_fact_reliability_score(
self,
verifier: FactVerifier,
mock_session: AsyncMock,
) -> None:
"""Test calculating reliability score."""
fact = create_mock_fact_model(confidence=0.8)
fact.reinforcement_count = 5
# Query 1: Get fact
fact_result = MagicMock()
fact_result.scalar_one_or_none.return_value = fact
# Query 2: Supporting facts
support_result = MagicMock()
support_result.scalars.return_value.all.return_value = []
# Query 3: Contradictions
conflict_result = MagicMock()
conflict_result.scalars.return_value.all.return_value = []
mock_session.execute.side_effect = [
fact_result,
support_result,
conflict_result,
]
score = await verifier.get_fact_reliability_score(fact.id)
# Score should be >= confidence (0.8) due to reinforcement bonus
assert score >= 0.8
assert score <= 1.0