# app/services/memory/semantic/memory.py
"""
Semantic Memory Implementation.

Provides fact storage and retrieval using subject-predicate-object triples.
Supports semantic search, confidence scoring, and fact reinforcement.
"""

import logging
import time
from datetime import UTC, datetime, timedelta
from typing import Any
from uuid import UUID

from sqlalchemy import and_, desc, func, or_, select, update
from sqlalchemy.ext.asyncio import AsyncSession

from app.models.memory.fact import Fact as FactModel
from app.services.memory.config import get_memory_settings
from app.services.memory.types import Episode, Fact, FactCreate, RetrievalResult

logger = logging.getLogger(__name__)


def _model_to_fact(model: FactModel) -> Fact:
    """Convert a SQLAlchemy model to a Fact dataclass."""
    # SQLAlchemy Column types are inferred as Column[T] by mypy, but at runtime
    # they return actual values. We use type: ignore to handle this mismatch.
    return Fact(
        id=model.id,  # type: ignore[arg-type]
        project_id=model.project_id,  # type: ignore[arg-type]
        subject=model.subject,  # type: ignore[arg-type]
        predicate=model.predicate,  # type: ignore[arg-type]
        object=model.object,  # type: ignore[arg-type]
        confidence=model.confidence,  # type: ignore[arg-type]
        source_episode_ids=model.source_episode_ids or [],  # type: ignore[arg-type]
        first_learned=model.first_learned,  # type: ignore[arg-type]
        last_reinforced=model.last_reinforced,  # type: ignore[arg-type]
        reinforcement_count=model.reinforcement_count,  # type: ignore[arg-type]
        embedding=None,  # Don't expose the raw embedding
        created_at=model.created_at,  # type: ignore[arg-type]
        updated_at=model.updated_at,  # type: ignore[arg-type]
    )


class SemanticMemory:
    """
    Semantic Memory Service.

    Provides fact storage and retrieval:
    - Store facts as subject-predicate-object triples
    - Semantic search over facts
    - Entity-based retrieval
    - Confidence scoring and decay
    - Fact reinforcement on repeated learning
    - Conflict resolution

    Performance target: <100ms P95 for retrieval
    """

    def __init__(
        self,
        session: AsyncSession,
        embedding_generator: Any | None = None,
    ) -> None:
        """
        Initialize semantic memory.

        Args:
            session: Database session
            embedding_generator: Optional embedding generator for semantic search
        """
        self._session = session
        self._embedding_generator = embedding_generator
        self._settings = get_memory_settings()

    @classmethod
    async def create(
        cls,
        session: AsyncSession,
        embedding_generator: Any | None = None,
    ) -> "SemanticMemory":
        """
        Factory method to create SemanticMemory.

        Args:
            session: Database session
            embedding_generator: Optional embedding generator

        Returns:
            Configured SemanticMemory instance
        """
        return cls(session=session, embedding_generator=embedding_generator)
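
    # Construction sketch (illustrative, not part of the public API): `create`
    # currently mirrors __init__ but leaves room for async setup later. The
    # `my_embedding_generator` name below is a placeholder:
    #
    #     memory = await SemanticMemory.create(session)
    #     memory = await SemanticMemory.create(
    #         session, embedding_generator=my_embedding_generator
    #     )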

    # =========================================================================
    # Fact Storage
    # =========================================================================

    async def store_fact(self, fact: FactCreate) -> Fact:
        """
        Store a new fact or reinforce an existing one.

        If a fact with the same triple (subject, predicate, object) exists in
        the same scope, it is reinforced instead of duplicated.

        Args:
            fact: Fact data to store

        Returns:
            The created or reinforced fact
        """
        # Check for an existing fact with the same triple
        existing = await self._find_existing_fact(
            project_id=fact.project_id,
            subject=fact.subject,
            predicate=fact.predicate,
            object=fact.object,
        )
        if existing is not None:
            # Reinforce the existing fact instead of creating a duplicate
            return await self.reinforce_fact(
                existing.id,  # type: ignore[arg-type]
                source_episode_ids=fact.source_episode_ids,
            )

        # Create a new fact
        now = datetime.now(UTC)

        # Generate an embedding if a generator is available
        embedding = None
        if self._embedding_generator is not None:
            embedding_text = self._create_embedding_text(fact)
            embedding = await self._embedding_generator.generate(embedding_text)

        model = FactModel(
            project_id=fact.project_id,
            subject=fact.subject,
            predicate=fact.predicate,
            object=fact.object,
            confidence=fact.confidence,
            source_episode_ids=fact.source_episode_ids,
            first_learned=now,
            last_reinforced=now,
            reinforcement_count=1,
            embedding=embedding,
        )
        self._session.add(model)
        await self._session.flush()
        await self._session.refresh(model)

        logger.info(
            f"Stored new fact: {fact.subject} - {fact.predicate} - {fact.object[:50]}..."
        )
        return _model_to_fact(model)

    async def _find_existing_fact(
        self,
        project_id: UUID | None,
        subject: str,
        predicate: str,
        object: str,
    ) -> FactModel | None:
        """Find an existing fact with the same triple in the same scope."""
        query = select(FactModel).where(
            and_(
                FactModel.subject == subject,
                FactModel.predicate == predicate,
                FactModel.object == object,
            )
        )
        if project_id is not None:
            query = query.where(FactModel.project_id == project_id)
        else:
            query = query.where(FactModel.project_id.is_(None))

        result = await self._session.execute(query)
        return result.scalar_one_or_none()

    def _create_embedding_text(self, fact: FactCreate) -> str:
        """Create the text used for embedding from fact data."""
        return f"{fact.subject} {fact.predicate} {fact.object}"
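
    # Storage behavior sketch (illustrative; field values are made up):
    # storing the same triple twice reinforces the existing row instead of
    # inserting a duplicate, so IDs are stable and confidence climbs by the
    # default boost of 0.1:
    #
    #     fc = FactCreate(subject="api", predicate="uses", object="postgres",
    #                     confidence=0.6, project_id=None, source_episode_ids=[])
    #     first = await memory.store_fact(fc)   # new fact, confidence 0.6
    #     second = await memory.store_fact(fc)  # reinforced, confidence 0.7
    #     assert first.id == second.id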

    # =========================================================================
    # Fact Retrieval
    # =========================================================================

    async def search_facts(
        self,
        query: str,
        project_id: UUID | None = None,
        limit: int = 10,
        min_confidence: float | None = None,
    ) -> list[Fact]:
        """
        Search for facts semantically similar to the query.

        Args:
            query: Search query
            project_id: Optional project to search within
            limit: Maximum results
            min_confidence: Optional minimum confidence filter

        Returns:
            List of matching facts
        """
        result = await self._search_facts_with_metadata(
            query=query,
            project_id=project_id,
            limit=limit,
            min_confidence=min_confidence,
        )
        return result.items

    async def _search_facts_with_metadata(
        self,
        query: str,
        project_id: UUID | None = None,
        limit: int = 10,
        min_confidence: float | None = None,
    ) -> RetrievalResult[Fact]:
        """Search facts and return the full result metadata."""
        start_time = time.perf_counter()
        # Explicit None check so an intentional min_confidence of 0.0 is not
        # silently replaced by the configured default.
        min_conf = (
            min_confidence
            if min_confidence is not None
            else self._settings.semantic_min_confidence
        )

        # Build the base query
        stmt = (
            select(FactModel)
            .where(FactModel.confidence >= min_conf)
            .order_by(desc(FactModel.confidence), desc(FactModel.last_reinforced))
            .limit(limit)
        )

        # Apply the project filter
        if project_id is not None:
            # Include both project-specific and global facts
            stmt = stmt.where(
                or_(
                    FactModel.project_id == project_id,
                    FactModel.project_id.is_(None),
                )
            )

        # TODO: Implement proper vector similarity search when pgvector is
        # integrated (see the pgvector sketch after get_by_entity below).
        # For now, do text-based search on subject/predicate/object.
        search_terms = query.lower().split()
        if search_terms:
            conditions = []
            for term in search_terms[:5]:  # Limit to 5 terms
                term_pattern = f"%{term}%"
                conditions.append(
                    or_(
                        FactModel.subject.ilike(term_pattern),
                        FactModel.predicate.ilike(term_pattern),
                        FactModel.object.ilike(term_pattern),
                    )
                )
            if conditions:
                stmt = stmt.where(or_(*conditions))

        result = await self._session.execute(stmt)
        models = list(result.scalars().all())

        latency_ms = (time.perf_counter() - start_time) * 1000

        return RetrievalResult(
            items=[_model_to_fact(m) for m in models],
            total_count=len(models),
            query=query,
            retrieval_type="semantic",
            latency_ms=latency_ms,
            metadata={"min_confidence": min_conf},
        )

    async def get_by_entity(
        self,
        entity: str,
        project_id: UUID | None = None,
        limit: int = 20,
    ) -> list[Fact]:
        """
        Get facts related to an entity (as subject or object).

        Args:
            entity: Entity to search for
            project_id: Optional project to search within
            limit: Maximum results

        Returns:
            List of facts mentioning the entity
        """
        start_time = time.perf_counter()

        stmt = (
            select(FactModel)
            .where(
                or_(
                    FactModel.subject.ilike(f"%{entity}%"),
                    FactModel.object.ilike(f"%{entity}%"),
                )
            )
            .order_by(desc(FactModel.confidence), desc(FactModel.last_reinforced))
            .limit(limit)
        )
        if project_id is not None:
            stmt = stmt.where(
                or_(
                    FactModel.project_id == project_id,
                    FactModel.project_id.is_(None),
                )
            )

        result = await self._session.execute(stmt)
        models = list(result.scalars().all())

        latency_ms = (time.perf_counter() - start_time) * 1000
        logger.debug(
            f"get_by_entity({entity}) returned {len(models)} facts in {latency_ms:.1f}ms"
        )
        return [_model_to_fact(m) for m in models]
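
    # pgvector sketch (not wired up yet; assumes FactModel.embedding becomes a
    # pgvector Vector column with pgvector's SQLAlchemy comparators available).
    # The ILIKE fallback in _search_facts_with_metadata could then be replaced
    # with a cosine-distance ordering, roughly:
    #
    #     query_vec = await self._embedding_generator.generate(query)
    #     stmt = (
    #         select(FactModel)
    #         .where(FactModel.confidence >= min_conf)
    #         .order_by(FactModel.embedding.cosine_distance(query_vec))
    #         .limit(limit)
    #     )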

    async def get_by_subject(
        self,
        subject: str,
        project_id: UUID | None = None,
        predicate: str | None = None,
        limit: int = 20,
    ) -> list[Fact]:
        """
        Get facts with a specific subject.

        Args:
            subject: Subject to search for
            project_id: Optional project to search within
            predicate: Optional predicate filter
            limit: Maximum results

        Returns:
            List of facts with a matching subject
        """
        stmt = (
            select(FactModel)
            .where(FactModel.subject == subject)
            .order_by(desc(FactModel.confidence))
            .limit(limit)
        )
        if predicate is not None:
            stmt = stmt.where(FactModel.predicate == predicate)
        if project_id is not None:
            stmt = stmt.where(
                or_(
                    FactModel.project_id == project_id,
                    FactModel.project_id.is_(None),
                )
            )

        result = await self._session.execute(stmt)
        models = list(result.scalars().all())
        return [_model_to_fact(m) for m in models]

    async def get_by_id(self, fact_id: UUID) -> Fact | None:
        """Get a fact by ID."""
        query = select(FactModel).where(FactModel.id == fact_id)
        result = await self._session.execute(query)
        model = result.scalar_one_or_none()
        return _model_to_fact(model) if model else None

    # =========================================================================
    # Fact Reinforcement
    # =========================================================================

    async def reinforce_fact(
        self,
        fact_id: UUID,
        confidence_boost: float = 0.1,
        source_episode_ids: list[UUID] | None = None,
    ) -> Fact:
        """
        Reinforce a fact, increasing its confidence.

        Args:
            fact_id: Fact to reinforce
            confidence_boost: Amount to increase confidence (default 0.1)
            source_episode_ids: Additional source episodes

        Returns:
            Updated fact

        Raises:
            ValueError: If the fact is not found
        """
        query = select(FactModel).where(FactModel.id == fact_id)
        result = await self._session.execute(query)
        model = result.scalar_one_or_none()
        if model is None:
            raise ValueError(f"Fact not found: {fact_id}")

        # Calculate the new confidence (capped at 1.0)
        current_confidence: float = model.confidence  # type: ignore[assignment]
        new_confidence = min(1.0, current_confidence + confidence_boost)

        # Merge source episode IDs
        current_sources: list[UUID] = model.source_episode_ids or []  # type: ignore[assignment]
        if source_episode_ids:
            # Add new sources, avoiding duplicates
            new_sources = list(set(current_sources + source_episode_ids))
        else:
            new_sources = current_sources

        now = datetime.now(UTC)
        stmt = (
            update(FactModel)
            .where(FactModel.id == fact_id)
            .values(
                confidence=new_confidence,
                source_episode_ids=new_sources,
                last_reinforced=now,
                reinforcement_count=FactModel.reinforcement_count + 1,
                updated_at=now,
            )
            .returning(FactModel)
        )
        result = await self._session.execute(stmt)
        updated_model = result.scalar_one()
        await self._session.flush()

        logger.info(
            f"Reinforced fact {fact_id}: confidence {current_confidence:.2f} -> {new_confidence:.2f}"
        )
        return _model_to_fact(updated_model)
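
    # Reinforcement math (worked example): with the default boost of 0.1, a
    # fact at 0.85 moves to min(1.0, 0.85 + 0.1) = 0.95, and one already at
    # 0.95 saturates at 1.0 rather than overshooting.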

    async def deprecate_fact(
        self,
        fact_id: UUID,
        reason: str,
        new_confidence: float = 0.0,
    ) -> Fact | None:
        """
        Deprecate a fact by lowering its confidence.

        Args:
            fact_id: Fact to deprecate
            reason: Reason for deprecation (logged, not persisted)
            new_confidence: New confidence level (default 0.0)

        Returns:
            Updated fact, or None if not found
        """
        query = select(FactModel).where(FactModel.id == fact_id)
        result = await self._session.execute(query)
        model = result.scalar_one_or_none()
        if model is None:
            return None

        now = datetime.now(UTC)
        stmt = (
            update(FactModel)
            .where(FactModel.id == fact_id)
            .values(
                confidence=max(0.0, new_confidence),
                updated_at=now,
            )
            .returning(FactModel)
        )
        result = await self._session.execute(stmt)
        updated_model = result.scalar_one_or_none()
        await self._session.flush()

        logger.info(f"Deprecated fact {fact_id}: {reason}")
        return _model_to_fact(updated_model) if updated_model else None

    # =========================================================================
    # Fact Extraction from Episodes
    # =========================================================================

    async def extract_facts_from_episode(
        self,
        episode: Episode,
    ) -> list[Fact]:
        """
        Extract facts from an episode.

        This is a placeholder for LLM-based fact extraction. In production,
        this would call an LLM to analyze the episode and extract
        subject-predicate-object triples.

        Args:
            episode: Episode to extract facts from

        Returns:
            List of extracted facts
        """
        # For now, extract basic facts from lessons learned
        extracted_facts: list[Fact] = []
        for lesson in episode.lessons_learned:
            if len(lesson) > 10:  # Skip very short lessons
                fact_create = FactCreate(
                    subject=episode.task_type,
                    predicate="lesson_learned",
                    object=lesson,
                    confidence=0.7,  # Lessons start with moderate confidence
                    project_id=episode.project_id,
                    source_episode_ids=[episode.id],
                )
                fact = await self.store_fact(fact_create)
                extracted_facts.append(fact)

        logger.debug(
            f"Extracted {len(extracted_facts)} facts from episode {episode.id}"
        )
        return extracted_facts
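
    # Extraction sketch (illustrative): an episode with task_type "deploy" and
    # the lesson "Run migrations before restarting workers" yields the triple
    # ("deploy", "lesson_learned", "Run migrations before restarting workers"),
    # stored at confidence 0.7 and sourced to the episode's ID.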

    # =========================================================================
    # Conflict Resolution
    # =========================================================================

    async def resolve_conflict(
        self,
        fact_ids: list[UUID],
        keep_fact_id: UUID | None = None,
    ) -> Fact | None:
        """
        Resolve a conflict between multiple facts.

        If keep_fact_id is specified, that fact is kept and the others are
        deprecated. Otherwise, the fact with the highest confidence is kept.

        Args:
            fact_ids: IDs of conflicting facts
            keep_fact_id: Optional ID of the fact to keep

        Returns:
            The winning fact, or None if no facts were found
        """
        if not fact_ids:
            return None

        # Load all facts
        query = select(FactModel).where(FactModel.id.in_(fact_ids))
        result = await self._session.execute(query)
        models = list(result.scalars().all())
        if not models:
            return None

        # Determine the winner
        if keep_fact_id is not None:
            winner = next((m for m in models if m.id == keep_fact_id), None)
            if winner is None:
                # Fall back to the highest confidence
                winner = max(models, key=lambda m: m.confidence)
        else:
            # Keep the fact with the highest confidence
            winner = max(models, key=lambda m: m.confidence)

        # Deprecate the losers
        for model in models:
            if model.id != winner.id:
                await self.deprecate_fact(
                    model.id,  # type: ignore[arg-type]
                    reason=f"Conflict resolution: superseded by {winner.id}",
                )

        logger.info(
            f"Resolved conflict between {len(fact_ids)} facts, keeping {winner.id}"
        )
        return _model_to_fact(winner)

    # =========================================================================
    # Confidence Decay
    # =========================================================================

    async def apply_confidence_decay(
        self,
        project_id: UUID | None = None,
        decay_factor: float = 0.01,
    ) -> int:
        """
        Apply confidence decay to facts that have not been reinforced recently.

        Args:
            project_id: Optional project to apply decay to
            decay_factor: Confidence lost per day past the grace period
                (default 0.01)

        Returns:
            Number of facts affected
        """
        now = datetime.now(UTC)
        decay_days = self._settings.semantic_confidence_decay_days
        min_conf = self._settings.semantic_min_confidence

        # Calculate the cutoff date
        cutoff = now - timedelta(days=decay_days)

        # Find facts needing decay
        query = select(FactModel).where(
            and_(
                FactModel.last_reinforced < cutoff,
                FactModel.confidence > min_conf,
            )
        )
        if project_id is not None:
            query = query.where(FactModel.project_id == project_id)

        result = await self._session.execute(query)
        models = list(result.scalars().all())

        # Apply decay
        updated_count = 0
        for model in models:
            # Days since the last reinforcement
            days_since = (now - model.last_reinforced).days
            # Linear decay, proportional to the days past the grace period
            decay = decay_factor * (days_since - decay_days)
            new_confidence = max(min_conf, model.confidence - decay)
            if new_confidence != model.confidence:
                await self._session.execute(
                    update(FactModel)
                    .where(FactModel.id == model.id)
                    .values(confidence=new_confidence, updated_at=now)
                )
                updated_count += 1

        await self._session.flush()
        logger.info(f"Applied confidence decay to {updated_count} facts")
        return updated_count
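
    # Decay math (worked example): with decay_factor=0.01 and a
    # semantic_confidence_decay_days setting of, say, 30, a fact last
    # reinforced 40 days ago loses 0.01 * (40 - 30) = 0.10 confidence,
    # floored at the configured minimum.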

    # =========================================================================
    # Statistics
    # =========================================================================

    async def get_stats(self, project_id: UUID | None = None) -> dict[str, Any]:
        """
        Get statistics about semantic memory.

        Args:
            project_id: Optional project to get stats for

        Returns:
            Dictionary with statistics
        """
        # Get all facts for this scope
        query = select(FactModel)
        if project_id is not None:
            query = query.where(
                or_(
                    FactModel.project_id == project_id,
                    FactModel.project_id.is_(None),
                )
            )

        result = await self._session.execute(query)
        models = list(result.scalars().all())

        if not models:
            return {
                "total_facts": 0,
                "avg_confidence": 0.0,
                "avg_reinforcement_count": 0.0,
                "high_confidence_count": 0,
                "low_confidence_count": 0,
            }

        confidences = [m.confidence for m in models]
        reinforcements = [m.reinforcement_count for m in models]
        return {
            "total_facts": len(models),
            "avg_confidence": sum(confidences) / len(confidences),
            "avg_reinforcement_count": sum(reinforcements) / len(reinforcements),
            "high_confidence_count": sum(1 for c in confidences if c >= 0.8),
            "low_confidence_count": sum(1 for c in confidences if c < 0.5),
        }

    async def count(self, project_id: UUID | None = None) -> int:
        """
        Count facts in scope.

        Args:
            project_id: Optional project to count for

        Returns:
            Number of facts
        """
        # Count in the database rather than loading every row into memory
        query = select(func.count()).select_from(FactModel)
        if project_id is not None:
            query = query.where(
                or_(
                    FactModel.project_id == project_id,
                    FactModel.project_id.is_(None),
                )
            )
        result = await self._session.execute(query)
        return int(result.scalar_one())

    async def delete(self, fact_id: UUID) -> bool:
        """
        Delete a fact.

        Args:
            fact_id: Fact to delete

        Returns:
            True if deleted, False if not found
        """
        query = select(FactModel).where(FactModel.id == fact_id)
        result = await self._session.execute(query)
        model = result.scalar_one_or_none()
        if model is None:
            return False

        await self._session.delete(model)
        await self._session.flush()
        logger.info(f"Deleted fact {fact_id}")
        return True
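

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only). `async_session_factory` is a placeholder
# for however the app builds AsyncSession instances; it is not defined in
# this module. Note the service only flushes, so the caller commits:
#
#     async with async_session_factory() as session:
#         memory = await SemanticMemory.create(session)
#         await memory.store_fact(
#             FactCreate(
#                 subject="payments-service",
#                 predicate="depends_on",
#                 object="stripe-api",
#                 confidence=0.8,
#                 project_id=None,
#                 source_episode_ids=[],
#             )
#         )
#         matches = await memory.search_facts("payments stripe", limit=5)
#         await session.commit()
# ---------------------------------------------------------------------------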