diff --git a/backend/app/services/memory/procedural/memory.py b/backend/app/services/memory/procedural/memory.py index dca4eeb..997fdd5 100644 --- a/backend/app/services/memory/procedural/memory.py +++ b/backend/app/services/memory/procedural/memory.py @@ -22,6 +22,25 @@ from app.services.memory.types import Procedure, ProcedureCreate, RetrievalResul logger = logging.getLogger(__name__) +def _escape_like_pattern(pattern: str) -> str: + """ + Escape SQL LIKE/ILIKE special characters to prevent pattern injection. + + Characters escaped: + - % (matches zero or more characters) + - _ (matches exactly one character) + - \\ (escape character itself) + + Args: + pattern: Raw search pattern from user input + + Returns: + Escaped pattern safe for use in LIKE/ILIKE queries + """ + # Escape backslash first, then the wildcards + return pattern.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_") + + def _model_to_procedure(model: ProcedureModel) -> Procedure: """Convert SQLAlchemy model to Procedure dataclass.""" return Procedure( @@ -320,7 +339,9 @@ class ProceduralMemory: if search_terms: conditions = [] for term in search_terms: - term_pattern = f"%{term}%" + # Escape SQL wildcards to prevent pattern injection + escaped_term = _escape_like_pattern(term) + term_pattern = f"%{escaped_term}%" conditions.append( or_( ProcedureModel.trigger_pattern.ilike(term_pattern), @@ -368,6 +389,10 @@ class ProceduralMemory: Returns: Best matching procedure or None """ + # Escape SQL wildcards to prevent pattern injection + escaped_task_type = _escape_like_pattern(task_type) + task_type_pattern = f"%{escaped_task_type}%" + # Build query for procedures matching task type stmt = ( select(ProcedureModel) @@ -376,8 +401,8 @@ class ProceduralMemory: (ProcedureModel.success_count + ProcedureModel.failure_count) >= min_uses, or_( - ProcedureModel.trigger_pattern.ilike(f"%{task_type}%"), - ProcedureModel.name.ilike(f"%{task_type}%"), + ProcedureModel.trigger_pattern.ilike(task_type_pattern), + ProcedureModel.name.ilike(task_type_pattern), ), ) ) diff --git a/backend/app/services/memory/semantic/memory.py b/backend/app/services/memory/semantic/memory.py index bca103b..2dde3d0 100644 --- a/backend/app/services/memory/semantic/memory.py +++ b/backend/app/services/memory/semantic/memory.py @@ -22,6 +22,25 @@ from app.services.memory.types import Episode, Fact, FactCreate, RetrievalResult logger = logging.getLogger(__name__) +def _escape_like_pattern(pattern: str) -> str: + """ + Escape SQL LIKE/ILIKE special characters to prevent pattern injection. + + Characters escaped: + - % (matches zero or more characters) + - _ (matches exactly one character) + - \\ (escape character itself) + + Args: + pattern: Raw search pattern from user input + + Returns: + Escaped pattern safe for use in LIKE/ILIKE queries + """ + # Escape backslash first, then the wildcards + return pattern.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_") + + def _model_to_fact(model: FactModel) -> Fact: """Convert SQLAlchemy model to Fact dataclass.""" # SQLAlchemy Column types are inferred as Column[T] by mypy, but at runtime @@ -251,7 +270,9 @@ class SemanticMemory: if search_terms: conditions = [] for term in search_terms[:5]: # Limit to 5 terms - term_pattern = f"%{term}%" + # Escape SQL wildcards to prevent pattern injection + escaped_term = _escape_like_pattern(term) + term_pattern = f"%{escaped_term}%" conditions.append( or_( FactModel.subject.ilike(term_pattern), @@ -295,12 +316,16 @@ class SemanticMemory: """ start_time = time.perf_counter() + # Escape SQL wildcards to prevent pattern injection + escaped_entity = _escape_like_pattern(entity) + entity_pattern = f"%{escaped_entity}%" + stmt = ( select(FactModel) .where( or_( - FactModel.subject.ilike(f"%{entity}%"), - FactModel.object.ilike(f"%{entity}%"), + FactModel.subject.ilike(entity_pattern), + FactModel.object.ilike(entity_pattern), ) ) .order_by(desc(FactModel.confidence), desc(FactModel.last_reinforced))