From 22ecb5e989a5684c0ae26d957aacc3ab5320089c Mon Sep 17 00:00:00 2001 From: Felipe Cardoso Date: Sun, 4 Jan 2026 02:07:39 +0100 Subject: [PATCH] feat(context): Phase 1 - Foundation types, config and exceptions (#79) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Implements the foundation for Context Management Engine: Types (backend/app/services/context/types/): - BaseContext: Abstract base with ID, content, priority, scoring - SystemContext: System prompts, personas, instructions - KnowledgeContext: RAG results from Knowledge Base MCP - ConversationContext: Chat history with role support - TaskContext: Task/issue context with acceptance criteria - ToolContext: Tool definitions and execution results - AssembledContext: Final assembled context result Configuration (config.py): - Token budget allocation (system 5%, task 10%, knowledge 40%, etc.) - Scoring weights (relevance 50%, recency 30%, priority 20%) - Cache settings (TTL, prefix) - Performance settings (max assembly time, parallel scoring) - Environment variable overrides with CTX_ prefix Exceptions (exceptions.py): - ContextError: Base exception - BudgetExceededError: Token budget violations - TokenCountError: Token counting failures - CompressionError: Compression failures - AssemblyTimeoutError: Assembly timeout - ScoringError, FormattingError, CacheError - ContextNotFoundError, InvalidContextError All 86 tests pass. 
🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 --- backend/app/services/context/__init__.py | 105 ++++ .../app/services/context/adapters/__init__.py | 5 + .../app/services/context/assembly/__init__.py | 5 + .../app/services/context/budget/__init__.py | 5 + .../app/services/context/cache/__init__.py | 5 + .../services/context/compression/__init__.py | 5 + backend/app/services/context/config.py | 328 ++++++++++ backend/app/services/context/exceptions.py | 354 +++++++++++ .../context/prioritization/__init__.py | 5 + .../app/services/context/scoring/__init__.py | 5 + .../app/services/context/types/__init__.py | 49 ++ backend/app/services/context/types/base.py | 320 ++++++++++ .../services/context/types/conversation.py | 182 ++++++ .../app/services/context/types/knowledge.py | 143 +++++ backend/app/services/context/types/system.py | 138 +++++ backend/app/services/context/types/task.py | 195 ++++++ backend/app/services/context/types/tool.py | 207 +++++++ backend/tests/services/context/__init__.py | 1 + backend/tests/services/context/test_config.py | 243 ++++++++ .../tests/services/context/test_exceptions.py | 252 ++++++++ backend/tests/services/context/test_types.py | 579 ++++++++++++++++++ 21 files changed, 3131 insertions(+) create mode 100644 backend/app/services/context/__init__.py create mode 100644 backend/app/services/context/adapters/__init__.py create mode 100644 backend/app/services/context/assembly/__init__.py create mode 100644 backend/app/services/context/budget/__init__.py create mode 100644 backend/app/services/context/cache/__init__.py create mode 100644 backend/app/services/context/compression/__init__.py create mode 100644 backend/app/services/context/config.py create mode 100644 backend/app/services/context/exceptions.py create mode 100644 backend/app/services/context/prioritization/__init__.py create mode 100644 backend/app/services/context/scoring/__init__.py create mode 100644 
backend/app/services/context/types/__init__.py create mode 100644 backend/app/services/context/types/base.py create mode 100644 backend/app/services/context/types/conversation.py create mode 100644 backend/app/services/context/types/knowledge.py create mode 100644 backend/app/services/context/types/system.py create mode 100644 backend/app/services/context/types/task.py create mode 100644 backend/app/services/context/types/tool.py create mode 100644 backend/tests/services/context/__init__.py create mode 100644 backend/tests/services/context/test_config.py create mode 100644 backend/tests/services/context/test_exceptions.py create mode 100644 backend/tests/services/context/test_types.py diff --git a/backend/app/services/context/__init__.py b/backend/app/services/context/__init__.py new file mode 100644 index 0000000..40034e0 --- /dev/null +++ b/backend/app/services/context/__init__.py @@ -0,0 +1,105 @@ +""" +Context Management Engine + +Sophisticated context assembly and optimization for LLM requests. +Provides intelligent context selection, token budget management, +and model-specific formatting. 
+ +Usage: + from app.services.context import ( + ContextSettings, + get_context_settings, + SystemContext, + KnowledgeContext, + ConversationContext, + TaskContext, + ToolContext, + ) + + # Get settings + settings = get_context_settings() + + # Create context instances + system_ctx = SystemContext.create_persona( + name="Code Assistant", + description="You are a helpful code assistant.", + capabilities=["Write code", "Debug issues"], + ) +""" + +# Configuration +from .config import ( + ContextSettings, + get_context_settings, + get_default_settings, + reset_context_settings, +) + +# Exceptions +from .exceptions import ( + AssemblyTimeoutError, + BudgetExceededError, + CacheError, + CompressionError, + ContextError, + ContextNotFoundError, + FormattingError, + InvalidContextError, + ScoringError, + TokenCountError, +) + +# Types +from .types import ( + AssembledContext, + BaseContext, + ContextPriority, + ContextType, + ConversationContext, + KnowledgeContext, + MessageRole, + SystemContext, + TaskComplexity, + TaskContext, + TaskStatus, + ToolContext, + ToolResultStatus, +) + +__all__ = [ + # Configuration + "ContextSettings", + "get_context_settings", + "get_default_settings", + "reset_context_settings", + # Exceptions + "AssemblyTimeoutError", + "BudgetExceededError", + "CacheError", + "CompressionError", + "ContextError", + "ContextNotFoundError", + "FormattingError", + "InvalidContextError", + "ScoringError", + "TokenCountError", + # Types - Base + "AssembledContext", + "BaseContext", + "ContextPriority", + "ContextType", + # Types - Conversation + "ConversationContext", + "MessageRole", + # Types - Knowledge + "KnowledgeContext", + # Types - System + "SystemContext", + # Types - Task + "TaskComplexity", + "TaskContext", + "TaskStatus", + # Types - Tool + "ToolContext", + "ToolResultStatus", +] diff --git a/backend/app/services/context/adapters/__init__.py b/backend/app/services/context/adapters/__init__.py new file mode 100644 index 0000000..fbf30a4 --- 
/dev/null +++ b/backend/app/services/context/adapters/__init__.py @@ -0,0 +1,5 @@ +""" +Model Adapters Module. + +Provides model-specific context formatting. +""" diff --git a/backend/app/services/context/assembly/__init__.py b/backend/app/services/context/assembly/__init__.py new file mode 100644 index 0000000..ae869ea --- /dev/null +++ b/backend/app/services/context/assembly/__init__.py @@ -0,0 +1,5 @@ +""" +Context Assembly Module. + +Provides the assembly pipeline and formatting. +""" diff --git a/backend/app/services/context/budget/__init__.py b/backend/app/services/context/budget/__init__.py new file mode 100644 index 0000000..f3a675b --- /dev/null +++ b/backend/app/services/context/budget/__init__.py @@ -0,0 +1,5 @@ +""" +Token Budget Management Module. + +Provides token counting and budget allocation. +""" diff --git a/backend/app/services/context/cache/__init__.py b/backend/app/services/context/cache/__init__.py new file mode 100644 index 0000000..075e014 --- /dev/null +++ b/backend/app/services/context/cache/__init__.py @@ -0,0 +1,5 @@ +""" +Context Cache Module. + +Provides Redis-based caching for assembled contexts. +""" diff --git a/backend/app/services/context/compression/__init__.py b/backend/app/services/context/compression/__init__.py new file mode 100644 index 0000000..28cb5e9 --- /dev/null +++ b/backend/app/services/context/compression/__init__.py @@ -0,0 +1,5 @@ +""" +Context Compression Module. + +Provides truncation and compression strategies. +""" diff --git a/backend/app/services/context/config.py b/backend/app/services/context/config.py new file mode 100644 index 0000000..7b82cd5 --- /dev/null +++ b/backend/app/services/context/config.py @@ -0,0 +1,328 @@ +""" +Context Management Engine Configuration. + +Provides Pydantic settings for context assembly, +token budget allocation, and caching. 
import threading
from functools import lru_cache
from typing import Any

from pydantic import Field, field_validator, model_validator
from pydantic_settings import BaseSettings


class ContextSettings(BaseSettings):
    """
    Configuration for the Context Management Engine.

    Every field can be overridden through an environment variable carrying
    the ``CTX_`` prefix (see ``model_config``). The ``budget_*`` fields and
    the ``scoring_*_weight`` fields are fractions in ``[0.0, 1.0]``; each of
    the two groups is validated to sum to 1.0 (within 0.001).
    """

    # Budget allocation fractions (must sum to 1.0)
    budget_system: float = Field(
        default=0.05,
        ge=0.0,
        le=1.0,
        description="Percentage of budget for system prompts (5%)",
    )
    budget_task: float = Field(
        default=0.10,
        ge=0.0,
        le=1.0,
        description="Percentage of budget for task context (10%)",
    )
    budget_knowledge: float = Field(
        default=0.40,
        ge=0.0,
        le=1.0,
        description="Percentage of budget for RAG/knowledge (40%)",
    )
    budget_conversation: float = Field(
        default=0.20,
        ge=0.0,
        le=1.0,
        description="Percentage of budget for conversation history (20%)",
    )
    budget_tools: float = Field(
        default=0.05,
        ge=0.0,
        le=1.0,
        description="Percentage of budget for tool descriptions (5%)",
    )
    budget_response: float = Field(
        default=0.15,
        ge=0.0,
        le=1.0,
        description="Percentage reserved for response (15%)",
    )
    budget_buffer: float = Field(
        default=0.05,
        ge=0.0,
        le=1.0,
        description="Percentage buffer for safety margin (5%)",
    )

    # Scoring weights (must sum to 1.0)
    scoring_relevance_weight: float = Field(
        default=0.5,
        ge=0.0,
        le=1.0,
        description="Weight for relevance scoring",
    )
    scoring_recency_weight: float = Field(
        default=0.3,
        ge=0.0,
        le=1.0,
        description="Weight for recency scoring",
    )
    scoring_priority_weight: float = Field(
        default=0.2,
        ge=0.0,
        le=1.0,
        description="Weight for priority scoring",
    )

    # Recency decay settings
    recency_decay_hours: float = Field(
        default=24.0,
        gt=0.0,
        description="Hours until recency score decays to 50%",
    )
    recency_max_age_hours: float = Field(
        default=168.0,
        gt=0.0,
        description="Hours until context is considered stale (7 days)",
    )

    # Compression settings
    compression_threshold: float = Field(
        default=0.8,
        ge=0.0,
        le=1.0,
        description="Compress when budget usage exceeds this percentage",
    )
    truncation_suffix: str = Field(
        default="... [truncated]",
        description="Suffix to add when truncating content",
    )
    summary_model_group: str = Field(
        default="fast",
        description="Model group to use for summarization",
    )

    # Caching settings
    cache_enabled: bool = Field(
        default=True,
        description="Enable Redis caching for assembled contexts",
    )
    cache_ttl_seconds: int = Field(
        default=3600,
        ge=60,
        le=86400,
        description="Cache TTL in seconds (1 hour default, max 24 hours)",
    )
    cache_prefix: str = Field(
        default="ctx",
        description="Redis key prefix for context cache",
    )

    # Performance settings
    max_assembly_time_ms: int = Field(
        default=100,
        ge=10,
        le=5000,
        description="Maximum time for context assembly in milliseconds",
    )
    parallel_scoring: bool = Field(
        default=True,
        description="Score contexts in parallel for better performance",
    )
    max_parallel_scores: int = Field(
        default=10,
        ge=1,
        le=50,
        description="Maximum number of contexts to score in parallel",
    )

    # Knowledge retrieval settings
    knowledge_search_type: str = Field(
        default="hybrid",
        description="Default search type for knowledge retrieval",
    )
    knowledge_max_results: int = Field(
        default=10,
        ge=1,
        le=50,
        description="Maximum knowledge chunks to retrieve",
    )
    knowledge_min_score: float = Field(
        default=0.5,
        ge=0.0,
        le=1.0,
        description="Minimum relevance score for knowledge",
    )

    # Conversation history settings
    conversation_max_turns: int = Field(
        default=20,
        ge=1,
        le=100,
        description="Maximum conversation turns to include",
    )
    conversation_recent_priority: bool = Field(
        default=True,
        description="Prioritize recent conversation turns",
    )

    @field_validator("knowledge_search_type")
    @classmethod
    def validate_search_type(cls, v: str) -> str:
        """
        Reject any search type other than semantic, keyword or hybrid.

        Raises:
            ValueError: If ``v`` is not one of the accepted search types
        """
        # A fixed, sorted tuple keeps the error message identical between
        # runs; interpolating a set of strings yields an order that depends
        # on per-process hash randomization.
        valid_types = ("hybrid", "keyword", "semantic")
        if v not in valid_types:
            raise ValueError(
                f"search_type must be one of: {', '.join(valid_types)}"
            )
        return v

    @model_validator(mode="after")
    def validate_budget_allocation(self) -> "ContextSettings":
        """
        Validate that the seven budget fractions sum to 1.0.

        Raises:
            ValueError: If the sum differs from 1.0 by more than 0.001
        """
        allocation = self.get_budget_allocation()
        total = sum(allocation.values())
        # Allow small floating point error
        if abs(total - 1.0) > 0.001:
            breakdown = ", ".join(
                f"{name}={share}" for name, share in allocation.items()
            )
            raise ValueError(
                f"Budget percentages must sum to 1.0, got {total:.3f}. "
                f"Current allocation: {breakdown}"
            )
        return self

    @model_validator(mode="after")
    def validate_scoring_weights(self) -> "ContextSettings":
        """
        Validate that the three scoring weights sum to 1.0.

        Raises:
            ValueError: If the sum differs from 1.0 by more than 0.001
        """
        weights = self.get_scoring_weights()
        total = sum(weights.values())
        # Allow small floating point error
        if abs(total - 1.0) > 0.001:
            breakdown = ", ".join(
                f"{name}={weight}" for name, weight in weights.items()
            )
            raise ValueError(
                f"Scoring weights must sum to 1.0, got {total:.3f}. "
                f"Current weights: {breakdown}"
            )
        return self

    def get_budget_allocation(self) -> dict[str, float]:
        """Get the budget fractions keyed by context category."""
        return {
            "system": self.budget_system,
            "task": self.budget_task,
            "knowledge": self.budget_knowledge,
            "conversation": self.budget_conversation,
            "tools": self.budget_tools,
            "response": self.budget_response,
            "buffer": self.budget_buffer,
        }

    def get_scoring_weights(self) -> dict[str, float]:
        """Get the scoring weights keyed by scorer name."""
        return {
            "relevance": self.scoring_relevance_weight,
            "recency": self.scoring_recency_weight,
            "priority": self.scoring_priority_weight,
        }

    def to_dict(self) -> dict[str, Any]:
        """
        Convert settings to a nested dictionary for logging/debugging.

        Note: this is a summary, not a full dump -- the recency settings and
        ``truncation_suffix`` are not included.
        """
        return {
            "budget": self.get_budget_allocation(),
            "scoring": self.get_scoring_weights(),
            "compression": {
                "threshold": self.compression_threshold,
                "summary_model_group": self.summary_model_group,
            },
            "cache": {
                "enabled": self.cache_enabled,
                "ttl_seconds": self.cache_ttl_seconds,
                "prefix": self.cache_prefix,
            },
            "performance": {
                "max_assembly_time_ms": self.max_assembly_time_ms,
                "parallel_scoring": self.parallel_scoring,
                "max_parallel_scores": self.max_parallel_scores,
            },
            "knowledge": {
                "search_type": self.knowledge_search_type,
                "max_results": self.knowledge_max_results,
                "min_score": self.knowledge_min_score,
            },
            "conversation": {
                "max_turns": self.conversation_max_turns,
                "recent_priority": self.conversation_recent_priority,
            },
        }

    model_config = {
        "env_prefix": "CTX_",
        "env_file": "../.env",
        "env_file_encoding": "utf-8",
        "case_sensitive": False,
        "extra": "ignore",
    }


# Lazily created process-wide settings instance and the lock guarding it.
_settings: ContextSettings | None = None
_settings_lock = threading.Lock()


def get_context_settings() -> ContextSettings:
    """
    Get the global ContextSettings instance.

    The instance is created on first use. Creation is guarded by a lock,
    and the ``None`` check is repeated inside the lock so that concurrent
    first callers end up sharing one instance.

    Returns:
        ContextSettings instance
    """
    global _settings
    if _settings is None:
        with _settings_lock:
            if _settings is None:
                _settings = ContextSettings()
    return _settings


def reset_context_settings() -> None:
    """
    Reset the global settings instance.

    The next call to get_context_settings() builds a fresh instance.
    Primarily used for testing. The instance cached by
    get_default_settings() is not affected.
    """
    global _settings
    with _settings_lock:
        _settings = None


@lru_cache(maxsize=1)
def get_default_settings() -> ContextSettings:
    """
    Get a cached ContextSettings instance built with no explicit arguments.

    Use this for read-only access to defaults.
    For mutable access, use get_context_settings().
    """
    return ContextSettings()
Get the global ContextSettings instance. + + Thread-safe with double-checked locking pattern. + + Returns: + ContextSettings instance + """ + global _settings + if _settings is None: + with _settings_lock: + if _settings is None: + _settings = ContextSettings() + return _settings + + +def reset_context_settings() -> None: + """ + Reset the global settings instance. + + Primarily used for testing. + """ + global _settings + with _settings_lock: + _settings = None + + +@lru_cache(maxsize=1) +def get_default_settings() -> ContextSettings: + """ + Get default settings (cached). + + Use this for read-only access to defaults. + For mutable access, use get_context_settings(). + """ + return ContextSettings() diff --git a/backend/app/services/context/exceptions.py b/backend/app/services/context/exceptions.py new file mode 100644 index 0000000..18f7910 --- /dev/null +++ b/backend/app/services/context/exceptions.py @@ -0,0 +1,354 @@ +""" +Context Management Engine Exceptions. + +Provides a hierarchy of exceptions for context assembly, +token budget management, and related operations. +""" + +from typing import Any + + +class ContextError(Exception): + """ + Base exception for all context management errors. + + All context-related exceptions should inherit from this class + to allow for catch-all handling when needed. + """ + + def __init__(self, message: str, details: dict[str, Any] | None = None) -> None: + """ + Initialize context error. + + Args: + message: Human-readable error message + details: Optional dict with additional error context + """ + self.message = message + self.details = details or {} + super().__init__(message) + + def to_dict(self) -> dict[str, Any]: + """Convert exception to dictionary for logging/serialization.""" + return { + "error_type": self.__class__.__name__, + "message": self.message, + "details": self.details, + } + + +class BudgetExceededError(ContextError): + """ + Raised when token budget is exceeded. 
+ + This occurs when the assembled context would exceed the + allocated token budget for a specific context type or total. + """ + + def __init__( + self, + message: str = "Token budget exceeded", + allocated: int = 0, + requested: int = 0, + context_type: str | None = None, + ) -> None: + """ + Initialize budget exceeded error. + + Args: + message: Error message + allocated: Tokens allocated for this context type + requested: Tokens requested + context_type: Type of context that exceeded budget + """ + details = { + "allocated": allocated, + "requested": requested, + "overage": requested - allocated, + } + if context_type: + details["context_type"] = context_type + + super().__init__(message, details) + self.allocated = allocated + self.requested = requested + self.context_type = context_type + + +class TokenCountError(ContextError): + """ + Raised when token counting fails. + + This typically occurs when the LLM Gateway token counting + service is unavailable or returns an error. + """ + + def __init__( + self, + message: str = "Failed to count tokens", + model: str | None = None, + text_length: int | None = None, + ) -> None: + """ + Initialize token count error. + + Args: + message: Error message + model: Model for which counting was attempted + text_length: Length of text that failed to count + """ + details: dict[str, Any] = {} + if model: + details["model"] = model + if text_length is not None: + details["text_length"] = text_length + + super().__init__(message, details) + self.model = model + self.text_length = text_length + + +class CompressionError(ContextError): + """ + Raised when context compression fails. + + This can occur when summarization or truncation cannot + reduce content to fit within the budget. + """ + + def __init__( + self, + message: str = "Failed to compress context", + original_tokens: int | None = None, + target_tokens: int | None = None, + achieved_tokens: int | None = None, + ) -> None: + """ + Initialize compression error. 
+ + Args: + message: Error message + original_tokens: Tokens before compression + target_tokens: Target token count + achieved_tokens: Tokens achieved after compression attempt + """ + details: dict[str, Any] = {} + if original_tokens is not None: + details["original_tokens"] = original_tokens + if target_tokens is not None: + details["target_tokens"] = target_tokens + if achieved_tokens is not None: + details["achieved_tokens"] = achieved_tokens + + super().__init__(message, details) + self.original_tokens = original_tokens + self.target_tokens = target_tokens + self.achieved_tokens = achieved_tokens + + +class AssemblyTimeoutError(ContextError): + """ + Raised when context assembly exceeds time limit. + + Context assembly must complete within a configurable + time limit to maintain responsiveness. + """ + + def __init__( + self, + message: str = "Context assembly timed out", + timeout_ms: int = 0, + elapsed_ms: float = 0.0, + stage: str | None = None, + ) -> None: + """ + Initialize assembly timeout error. + + Args: + message: Error message + timeout_ms: Configured timeout in milliseconds + elapsed_ms: Actual elapsed time in milliseconds + stage: Pipeline stage where timeout occurred + """ + details = { + "timeout_ms": timeout_ms, + "elapsed_ms": round(elapsed_ms, 2), + } + if stage: + details["stage"] = stage + + super().__init__(message, details) + self.timeout_ms = timeout_ms + self.elapsed_ms = elapsed_ms + self.stage = stage + + +class ScoringError(ContextError): + """ + Raised when context scoring fails. + + This occurs when relevance, recency, or priority scoring + encounters an error. + """ + + def __init__( + self, + message: str = "Failed to score context", + scorer_type: str | None = None, + context_id: str | None = None, + ) -> None: + """ + Initialize scoring error. 
+ + Args: + message: Error message + scorer_type: Type of scorer that failed + context_id: ID of context being scored + """ + details: dict[str, Any] = {} + if scorer_type: + details["scorer_type"] = scorer_type + if context_id: + details["context_id"] = context_id + + super().__init__(message, details) + self.scorer_type = scorer_type + self.context_id = context_id + + +class FormattingError(ContextError): + """ + Raised when context formatting fails. + + This occurs when converting assembled context to + model-specific format fails. + """ + + def __init__( + self, + message: str = "Failed to format context", + model: str | None = None, + adapter: str | None = None, + ) -> None: + """ + Initialize formatting error. + + Args: + message: Error message + model: Target model + adapter: Adapter that failed + """ + details: dict[str, Any] = {} + if model: + details["model"] = model + if adapter: + details["adapter"] = adapter + + super().__init__(message, details) + self.model = model + self.adapter = adapter + + +class CacheError(ContextError): + """ + Raised when cache operations fail. + + This is typically non-fatal and should be handled + gracefully by falling back to recomputation. + """ + + def __init__( + self, + message: str = "Cache operation failed", + operation: str | None = None, + cache_key: str | None = None, + ) -> None: + """ + Initialize cache error. + + Args: + message: Error message + operation: Cache operation that failed (get, set, delete) + cache_key: Key involved in the failed operation + """ + details: dict[str, Any] = {} + if operation: + details["operation"] = operation + if cache_key: + details["cache_key"] = cache_key + + super().__init__(message, details) + self.operation = operation + self.cache_key = cache_key + + +class ContextNotFoundError(ContextError): + """ + Raised when expected context is not found. + + This occurs when required context sources return + no results or are unavailable. 
+ """ + + def __init__( + self, + message: str = "Required context not found", + source: str | None = None, + query: str | None = None, + ) -> None: + """ + Initialize context not found error. + + Args: + message: Error message + source: Source that returned no results + query: Query used to search + """ + details: dict[str, Any] = {} + if source: + details["source"] = source + if query: + details["query"] = query + + super().__init__(message, details) + self.source = source + self.query = query + + +class InvalidContextError(ContextError): + """ + Raised when context data is invalid. + + This occurs when context content or metadata + fails validation. + """ + + def __init__( + self, + message: str = "Invalid context data", + field: str | None = None, + value: Any | None = None, + reason: str | None = None, + ) -> None: + """ + Initialize invalid context error. + + Args: + message: Error message + field: Field that is invalid + value: Invalid value (may be redacted for security) + reason: Reason for invalidity + """ + details: dict[str, Any] = {} + if field: + details["field"] = field + if value is not None: + # Avoid logging potentially sensitive values + details["value_type"] = type(value).__name__ + if reason: + details["reason"] = reason + + super().__init__(message, details) + self.field = field + self.value = value + self.reason = reason diff --git a/backend/app/services/context/prioritization/__init__.py b/backend/app/services/context/prioritization/__init__.py new file mode 100644 index 0000000..66f586f --- /dev/null +++ b/backend/app/services/context/prioritization/__init__.py @@ -0,0 +1,5 @@ +""" +Context Prioritization Module. + +Provides context ranking and selection. +""" diff --git a/backend/app/services/context/scoring/__init__.py b/backend/app/services/context/scoring/__init__.py new file mode 100644 index 0000000..f0b7218 --- /dev/null +++ b/backend/app/services/context/scoring/__init__.py @@ -0,0 +1,5 @@ +""" +Context Scoring Module. 
+ +Provides relevance, recency, and composite scoring. +""" diff --git a/backend/app/services/context/types/__init__.py b/backend/app/services/context/types/__init__.py new file mode 100644 index 0000000..d247bfb --- /dev/null +++ b/backend/app/services/context/types/__init__.py @@ -0,0 +1,49 @@ +""" +Context Types Module. + +Provides all context types used in the Context Management Engine. +""" + +from .base import ( + AssembledContext, + BaseContext, + ContextPriority, + ContextType, +) +from .conversation import ( + ConversationContext, + MessageRole, +) +from .knowledge import KnowledgeContext +from .system import SystemContext +from .task import ( + TaskComplexity, + TaskContext, + TaskStatus, +) +from .tool import ( + ToolContext, + ToolResultStatus, +) + +__all__ = [ + # Base types + "AssembledContext", + "BaseContext", + "ContextPriority", + "ContextType", + # Conversation + "ConversationContext", + "MessageRole", + # Knowledge + "KnowledgeContext", + # System + "SystemContext", + # Task + "TaskComplexity", + "TaskContext", + "TaskStatus", + # Tool + "ToolContext", + "ToolResultStatus", +] diff --git a/backend/app/services/context/types/base.py b/backend/app/services/context/types/base.py new file mode 100644 index 0000000..4b01ed2 --- /dev/null +++ b/backend/app/services/context/types/base.py @@ -0,0 +1,320 @@ +""" +Base Context Types and Enums. + +Provides the foundation for all context types used in +the Context Management Engine. +""" + +from abc import ABC, abstractmethod +from dataclasses import dataclass, field +from datetime import UTC, datetime +from enum import Enum +from typing import Any +from uuid import uuid4 + + +class ContextType(str, Enum): + """ + Types of context that can be assembled. + + Each type has specific handling, formatting, and + budget allocation rules. 
+ """ + + SYSTEM = "system" + TASK = "task" + KNOWLEDGE = "knowledge" + CONVERSATION = "conversation" + TOOL = "tool" + + @classmethod + def from_string(cls, value: str) -> "ContextType": + """ + Convert string to ContextType. + + Args: + value: String value + + Returns: + ContextType enum value + + Raises: + ValueError: If value is not a valid context type + """ + try: + return cls(value.lower()) + except ValueError: + valid = ", ".join(t.value for t in cls) + raise ValueError(f"Invalid context type '{value}'. Valid types: {valid}") + + +class ContextPriority(int, Enum): + """ + Priority levels for context ordering. + + Higher values indicate higher priority. + """ + + LOWEST = 0 + LOW = 25 + NORMAL = 50 + HIGH = 75 + HIGHEST = 100 + CRITICAL = 150 # Never omit + + @classmethod + def from_int(cls, value: int) -> "ContextPriority": + """ + Get closest priority level for an integer. + + Args: + value: Integer priority value + + Returns: + Closest ContextPriority enum value + """ + priorities = sorted(cls, key=lambda p: p.value) + for priority in reversed(priorities): + if value >= priority.value: + return priority + return cls.LOWEST + + +@dataclass(eq=False) +class BaseContext(ABC): + """ + Abstract base class for all context types. + + Provides common fields and methods for context handling, + scoring, and serialization. 
+ """ + + # Required fields + content: str + source: str + + # Optional fields with defaults + id: str = field(default_factory=lambda: str(uuid4())) + timestamp: datetime = field(default_factory=lambda: datetime.now(UTC)) + priority: int = field(default=ContextPriority.NORMAL.value) + metadata: dict[str, Any] = field(default_factory=dict) + + # Computed/cached fields + _token_count: int | None = field(default=None, repr=False) + _score: float | None = field(default=None, repr=False) + + @property + def token_count(self) -> int | None: + """Get cached token count (None if not counted yet).""" + return self._token_count + + @token_count.setter + def token_count(self, value: int) -> None: + """Set token count.""" + self._token_count = value + + @property + def score(self) -> float | None: + """Get cached score (None if not scored yet).""" + return self._score + + @score.setter + def score(self, value: float) -> None: + """Set score (clamped to 0.0-1.0).""" + self._score = max(0.0, min(1.0, value)) + + @abstractmethod + def get_type(self) -> ContextType: + """ + Get the type of this context. + + Returns: + ContextType enum value + """ + ... + + def get_age_seconds(self) -> float: + """ + Get age of context in seconds. + + Returns: + Age in seconds since creation + """ + now = datetime.now(UTC) + delta = now - self.timestamp + return delta.total_seconds() + + def get_age_hours(self) -> float: + """ + Get age of context in hours. + + Returns: + Age in hours since creation + """ + return self.get_age_seconds() / 3600 + + def is_stale(self, max_age_hours: float = 168.0) -> bool: + """ + Check if context is stale. + + Args: + max_age_hours: Maximum age before considered stale (default 7 days) + + Returns: + True if context is older than max_age_hours + """ + return self.get_age_hours() > max_age_hours + + def to_dict(self) -> dict[str, Any]: + """ + Convert context to dictionary for serialization. 
+ + Returns: + Dictionary representation + """ + return { + "id": self.id, + "type": self.get_type().value, + "content": self.content, + "source": self.source, + "timestamp": self.timestamp.isoformat(), + "priority": self.priority, + "metadata": self.metadata, + "token_count": self._token_count, + "score": self._score, + } + + @classmethod + def from_dict(cls, data: dict[str, Any]) -> "BaseContext": + """ + Create context from dictionary. + + Note: Subclasses should override this to return correct type. + + Args: + data: Dictionary with context data + + Returns: + Context instance + """ + raise NotImplementedError("Subclasses must implement from_dict") + + def truncate(self, max_tokens: int, suffix: str = "... [truncated]") -> str: + """ + Truncate content to fit within token limit. + + This is a rough estimation based on characters. + For accurate truncation, use the TokenCalculator. + + Args: + max_tokens: Maximum tokens allowed + suffix: Suffix to append when truncated + + Returns: + Truncated content + """ + if self._token_count is None or self._token_count <= max_tokens: + return self.content + + # Rough estimation: 4 chars per token on average + estimated_chars = max_tokens * 4 + suffix_chars = len(suffix) + + if len(self.content) <= estimated_chars: + return self.content + + truncated = self.content[: estimated_chars - suffix_chars] + # Try to break at word boundary + last_space = truncated.rfind(" ") + if last_space > estimated_chars * 0.8: + truncated = truncated[:last_space] + + return truncated + suffix + + def __hash__(self) -> int: + """Hash based on ID for set/dict usage.""" + return hash(self.id) + + def __eq__(self, other: object) -> bool: + """Equality based on ID.""" + if not isinstance(other, BaseContext): + return False + return self.id == other.id + + +@dataclass +class AssembledContext: + """ + Result of context assembly. + + Contains the final formatted context ready for LLM consumption, + along with metadata about the assembly process. 
def to_dict(self) -> dict[str, Any]:
    """
    Serialize the assembly result into a plain dictionary.

    The assembly time is rounded to 2 decimal places and the derived
    ``budget_utilization`` value to 3; all other fields are passed
    through unchanged.

    Returns:
        Dictionary representation
    """
    elapsed_ms = round(self.assembly_time_ms, 2)
    utilization = round(self.budget_utilization, 3)
    return dict(
        content=self.content,
        token_count=self.token_count,
        contexts_included=self.contexts_included,
        contexts_excluded=self.contexts_excluded,
        assembly_time_ms=elapsed_ms,
        budget_total=self.budget_total,
        budget_used=self.budget_used,
        budget_utilization=utilization,
        by_type=self.by_type,
        cache_hit=self.cache_hit,
        cache_key=self.cache_key,
    )
class MessageRole(str, Enum):
    """Who authored a conversation message."""

    USER = "user"
    ASSISTANT = "assistant"
    SYSTEM = "system"
    TOOL = "tool"

    @classmethod
    def from_string(cls, value: str) -> "MessageRole":
        """
        Look up a role by its case-insensitive string value.

        Unrecognised role names fall back to ``USER`` instead of raising.
        """
        wanted = value.lower()
        for member in cls:
            if member.value == wanted:
                return member
        # Default to user for unknown roles
        return cls.USER
timestamp=datetime.fromisoformat(data["timestamp"]) + if isinstance(data.get("timestamp"), str) + else data.get("timestamp", datetime.now(UTC)), + priority=data.get("priority", ContextPriority.NORMAL.value), + metadata=data.get("metadata", {}), + role=role, + turn_index=data.get("turn_index", 0), + session_id=data.get("session_id"), + parent_message_id=data.get("parent_message_id"), + ) + + @classmethod + def from_message( + cls, + content: str, + role: str | MessageRole, + turn_index: int = 0, + session_id: str | None = None, + timestamp: datetime | None = None, + ) -> "ConversationContext": + """ + Create ConversationContext from a message. + + Args: + content: Message content + role: Message role (user, assistant, system, tool) + turn_index: Position in conversation + session_id: Session identifier + timestamp: Message timestamp + + Returns: + ConversationContext instance + """ + if isinstance(role, str): + role = MessageRole.from_string(role) + + # Recent messages have higher priority + priority = ContextPriority.NORMAL.value + + return cls( + content=content, + source="conversation", + role=role, + turn_index=turn_index, + session_id=session_id, + timestamp=timestamp or datetime.now(UTC), + priority=priority, + ) + + @classmethod + def from_history( + cls, + messages: list[dict[str, Any]], + session_id: str | None = None, + ) -> list["ConversationContext"]: + """ + Create multiple ConversationContexts from message history. 
+ + Args: + messages: List of message dicts with 'role' and 'content' + session_id: Session identifier + + Returns: + List of ConversationContext instances + """ + contexts = [] + for i, msg in enumerate(messages): + ctx = cls.from_message( + content=msg.get("content", ""), + role=msg.get("role", "user"), + turn_index=i, + session_id=session_id, + timestamp=datetime.fromisoformat(msg["timestamp"]) + if "timestamp" in msg + else None, + ) + contexts.append(ctx) + return contexts + + def is_user_message(self) -> bool: + """Check if this is a user message.""" + return self.role == MessageRole.USER + + def is_assistant_message(self) -> bool: + """Check if this is an assistant message.""" + return self.role == MessageRole.ASSISTANT + + def is_tool_result(self) -> bool: + """Check if this is a tool result.""" + return self.role == MessageRole.TOOL + + def format_for_prompt(self) -> str: + """ + Format message for inclusion in prompt. + + Returns: + Formatted message string + """ + role_labels = { + MessageRole.USER: "User", + MessageRole.ASSISTANT: "Assistant", + MessageRole.SYSTEM: "System", + MessageRole.TOOL: "Tool Result", + } + label = role_labels.get(self.role, "Unknown") + return f"{label}: {self.content}" diff --git a/backend/app/services/context/types/knowledge.py b/backend/app/services/context/types/knowledge.py new file mode 100644 index 0000000..9e66819 --- /dev/null +++ b/backend/app/services/context/types/knowledge.py @@ -0,0 +1,143 @@ +""" +Knowledge Context Type. + +Represents RAG results from the Knowledge Base MCP server. +""" + +from dataclasses import dataclass, field +from datetime import UTC, datetime +from typing import Any + +from .base import BaseContext, ContextPriority, ContextType + + +@dataclass(eq=False) +class KnowledgeContext(BaseContext): + """ + Context from knowledge base / RAG retrieval. 
+ + Knowledge context represents chunks retrieved from the + Knowledge Base MCP server, including: + - Code snippets + - Documentation + - Previous conversations + - External knowledge + + Each chunk includes relevance scoring from the search. + """ + + # Knowledge-specific fields + collection: str = field(default="default") + file_type: str | None = field(default=None) + chunk_index: int = field(default=0) + relevance_score: float = field(default=0.0) + search_query: str = field(default="") + + def get_type(self) -> ContextType: + """Return KNOWLEDGE context type.""" + return ContextType.KNOWLEDGE + + def to_dict(self) -> dict[str, Any]: + """Convert to dictionary with knowledge-specific fields.""" + base = super().to_dict() + base.update( + { + "collection": self.collection, + "file_type": self.file_type, + "chunk_index": self.chunk_index, + "relevance_score": self.relevance_score, + "search_query": self.search_query, + } + ) + return base + + @classmethod + def from_dict(cls, data: dict[str, Any]) -> "KnowledgeContext": + """Create KnowledgeContext from dictionary.""" + return cls( + id=data.get("id", ""), + content=data["content"], + source=data["source"], + timestamp=datetime.fromisoformat(data["timestamp"]) + if isinstance(data.get("timestamp"), str) + else data.get("timestamp", datetime.now(UTC)), + priority=data.get("priority", ContextPriority.NORMAL.value), + metadata=data.get("metadata", {}), + collection=data.get("collection", "default"), + file_type=data.get("file_type"), + chunk_index=data.get("chunk_index", 0), + relevance_score=data.get("relevance_score", 0.0), + search_query=data.get("search_query", ""), + ) + + @classmethod + def from_search_result( + cls, + result: dict[str, Any], + query: str, + ) -> "KnowledgeContext": + """ + Create KnowledgeContext from a Knowledge Base search result. 
def get_formatted_source(self) -> str:
    """
    Build a human-readable label for where this chunk came from.

    The label is the source path, preceded by ``[collection]`` when the
    collection is not ``"default"`` and followed by ``(file_type)`` when
    a file type is set.

    Returns:
        Formatted source string
    """
    leading = [] if self.collection == "default" else [f"[{self.collection}]"]
    trailing = [f"({self.file_type})"] if self.file_type else []
    return " ".join([*leading, self.source, *trailing])
+ +Represents system prompts, instructions, and agent personas. +""" + +from dataclasses import dataclass, field +from datetime import UTC, datetime +from typing import Any + +from .base import BaseContext, ContextPriority, ContextType + + +@dataclass(eq=False) +class SystemContext(BaseContext): + """ + Context for system prompts and instructions. + + System context typically includes: + - Agent persona and role definitions + - Behavioral instructions + - Safety guidelines + - Output format requirements + + System context is usually high priority and should + rarely be truncated or omitted. + """ + + # System context specific fields + role: str = field(default="assistant") + instructions_type: str = field(default="general") + + def __post_init__(self) -> None: + """Set high priority for system context.""" + # System context defaults to high priority + if self.priority == ContextPriority.NORMAL.value: + self.priority = ContextPriority.HIGH.value + + def get_type(self) -> ContextType: + """Return SYSTEM context type.""" + return ContextType.SYSTEM + + def to_dict(self) -> dict[str, Any]: + """Convert to dictionary with system-specific fields.""" + base = super().to_dict() + base.update( + { + "role": self.role, + "instructions_type": self.instructions_type, + } + ) + return base + + @classmethod + def from_dict(cls, data: dict[str, Any]) -> "SystemContext": + """Create SystemContext from dictionary.""" + return cls( + id=data.get("id", ""), + content=data["content"], + source=data["source"], + timestamp=datetime.fromisoformat(data["timestamp"]) + if isinstance(data.get("timestamp"), str) + else data.get("timestamp", datetime.now(UTC)), + priority=data.get("priority", ContextPriority.HIGH.value), + metadata=data.get("metadata", {}), + role=data.get("role", "assistant"), + instructions_type=data.get("instructions_type", "general"), + ) + + @classmethod + def create_persona( + cls, + name: str, + description: str, + capabilities: list[str] | None = None, + constraints: 
@classmethod
def create_instructions(
    cls,
    instructions: str | list[str],
    source: str = "instructions",
) -> "SystemContext":
    """
    Build a system context that carries behavioural instructions.

    A list of instructions is rendered as one ``- item`` bullet per line;
    a plain string is used verbatim.

    Args:
        instructions: Instructions string or list of instruction strings
        source: Source identifier

    Returns:
        SystemContext with instructions
    """
    body = (
        "\n".join(f"- {item}" for item in instructions)
        if isinstance(instructions, list)
        else instructions
    )
    return cls(
        content=body,
        source=source,
        instructions_type="instructions",
        priority=ContextPriority.HIGH.value,
    )
class TaskStatus(str, Enum):
    """Lifecycle states a task can be in."""

    # Not finished yet
    PENDING = "pending"
    IN_PROGRESS = "in_progress"
    BLOCKED = "blocked"

    # Terminal outcomes
    COMPLETED = "completed"
    FAILED = "failed"
self.acceptance_criteria, + "constraints": self.constraints, + "parent_task_id": self.parent_task_id, + } + ) + return base + + @classmethod + def from_dict(cls, data: dict[str, Any]) -> "TaskContext": + """Create TaskContext from dictionary.""" + status = data.get("status", "pending") + if isinstance(status, str): + status = TaskStatus(status) + + complexity = data.get("complexity", "moderate") + if isinstance(complexity, str): + complexity = TaskComplexity(complexity) + + return cls( + id=data.get("id", ""), + content=data["content"], + source=data.get("source", "task"), + timestamp=datetime.fromisoformat(data["timestamp"]) + if isinstance(data.get("timestamp"), str) + else data.get("timestamp", datetime.now(UTC)), + priority=data.get("priority", ContextPriority.HIGH.value), + metadata=data.get("metadata", {}), + title=data.get("title", ""), + status=status, + complexity=complexity, + issue_id=data.get("issue_id"), + project_id=data.get("project_id"), + acceptance_criteria=data.get("acceptance_criteria", []), + constraints=data.get("constraints", []), + parent_task_id=data.get("parent_task_id"), + ) + + @classmethod + def create( + cls, + title: str, + description: str, + acceptance_criteria: list[str] | None = None, + constraints: list[str] | None = None, + issue_id: str | None = None, + project_id: str | None = None, + complexity: TaskComplexity | str = TaskComplexity.MODERATE, + ) -> "TaskContext": + """ + Create a task context. 
def format_for_prompt(self) -> str:
    """
    Render the task as a plain-text section for a prompt.

    Layout: an optional ``Task: <title>`` heading, the task content, then
    bulleted "Acceptance Criteria:" and "Constraints:" sections for
    whichever of those lists are non-empty, each separated by a blank line.

    Returns:
        Formatted task string
    """
    lines: list[str] = [f"Task: {self.title}", ""] if self.title else []
    lines.append(self.content)

    bullet_sections = (
        ("Acceptance Criteria:", self.acceptance_criteria),
        ("Constraints:", self.constraints),
    )
    for heading, entries in bullet_sections:
        if entries:
            lines += ["", heading]
            lines += [f"- {entry}" for entry in entries]

    return "\n".join(lines)
+""" + +from dataclasses import dataclass, field +from datetime import UTC, datetime +from enum import Enum +from typing import Any + +from .base import BaseContext, ContextPriority, ContextType + + +class ToolResultStatus(str, Enum): + """Status of a tool execution result.""" + + SUCCESS = "success" + ERROR = "error" + TIMEOUT = "timeout" + CANCELLED = "cancelled" + + +@dataclass(eq=False) +class ToolContext(BaseContext): + """ + Context for tools and tool execution results. + + Tool context includes: + - Tool descriptions and parameters + - Recent tool execution results + - Tool availability information + + This helps the LLM understand what tools are available + and what results previous tool calls produced. + """ + + # Tool-specific fields + tool_name: str = field(default="") + tool_description: str = field(default="") + is_result: bool = field(default=False) + result_status: ToolResultStatus | None = field(default=None) + execution_time_ms: float | None = field(default=None) + parameters: dict[str, Any] = field(default_factory=dict) + server_name: str | None = field(default=None) + + def get_type(self) -> ContextType: + """Return TOOL context type.""" + return ContextType.TOOL + + def to_dict(self) -> dict[str, Any]: + """Convert to dictionary with tool-specific fields.""" + base = super().to_dict() + base.update( + { + "tool_name": self.tool_name, + "tool_description": self.tool_description, + "is_result": self.is_result, + "result_status": self.result_status.value if self.result_status else None, + "execution_time_ms": self.execution_time_ms, + "parameters": self.parameters, + "server_name": self.server_name, + } + ) + return base + + @classmethod + def from_dict(cls, data: dict[str, Any]) -> "ToolContext": + """Create ToolContext from dictionary.""" + result_status = data.get("result_status") + if isinstance(result_status, str): + result_status = ToolResultStatus(result_status) + + return cls( + id=data.get("id", ""), + content=data["content"], + 
source=data.get("source", "tool"), + timestamp=datetime.fromisoformat(data["timestamp"]) + if isinstance(data.get("timestamp"), str) + else data.get("timestamp", datetime.now(UTC)), + priority=data.get("priority", ContextPriority.NORMAL.value), + metadata=data.get("metadata", {}), + tool_name=data.get("tool_name", ""), + tool_description=data.get("tool_description", ""), + is_result=data.get("is_result", False), + result_status=result_status, + execution_time_ms=data.get("execution_time_ms"), + parameters=data.get("parameters", {}), + server_name=data.get("server_name"), + ) + + @classmethod + def from_tool_definition( + cls, + name: str, + description: str, + parameters: dict[str, Any] | None = None, + server_name: str | None = None, + ) -> "ToolContext": + """ + Create a ToolContext from a tool definition. + + Args: + name: Tool name + description: Tool description + parameters: Tool parameter schema + server_name: MCP server name + + Returns: + ToolContext instance + """ + # Format content as tool documentation + content_parts = [f"Tool: {name}", "", description] + + if parameters: + content_parts.append("") + content_parts.append("Parameters:") + for param_name, param_info in parameters.items(): + param_type = param_info.get("type", "any") + param_desc = param_info.get("description", "") + required = param_info.get("required", False) + req_marker = " (required)" if required else "" + content_parts.append(f" - {param_name}: {param_type}{req_marker}") + if param_desc: + content_parts.append(f" {param_desc}") + + return cls( + content="\n".join(content_parts), + source=f"tool:{server_name}:{name}" if server_name else f"tool:{name}", + tool_name=name, + tool_description=description, + is_result=False, + parameters=parameters or {}, + server_name=server_name, + priority=ContextPriority.LOW.value, + ) + + @classmethod + def from_tool_result( + cls, + tool_name: str, + result: Any, + status: ToolResultStatus = ToolResultStatus.SUCCESS, + execution_time_ms: float | 
def format_for_prompt(self) -> str:
    """
    Render this tool context as prompt text.

    Tool definitions are emitted as their content unchanged. Execution
    results get a ``Tool Result (<tool>, <status>):`` header line, where
    the status reads ``unknown`` if none was recorded.

    Returns:
        Formatted tool string
    """
    if not self.is_result:
        return self.content

    if self.result_status:
        status_label = self.result_status.value
    else:
        status_label = "unknown"

    heading = f"Tool Result ({self.tool_name}, {status_label}):"
    return f"{heading}\n{self.content}"
def test_to_dict(self) -> None:
    """Exported settings dict exposes every top-level section."""
    exported = ContextSettings().to_dict()

    expected_sections = (
        "budget",
        "scoring",
        "compression",
        "cache",
        "performance",
        "knowledge",
        "conversation",
    )
    for section in expected_sections:
        assert section in exported
def test_search_type_validation(self) -> None:
    """Known search types are accepted; anything else raises ValueError."""
    # Valid types should work
    for accepted in ("semantic", "keyword", "hybrid"):
        ContextSettings(knowledge_search_type=accepted)

    # Invalid type should fail
    with pytest.raises(ValueError):
        ContextSettings(knowledge_search_type="invalid")
class TestEnvironmentOverrides:
    """Tests for environment variable overrides."""

    def setup_method(self) -> None:
        """Reset settings before each test."""
        reset_context_settings()

    def teardown_method(self) -> None:
        """Clean up after each test."""
        # Every test sets its variables through patch.dict, which restores
        # os.environ on exit. No manual cleanup is needed here; deleting
        # all CTX_* variables would also wipe ones that were already set
        # in the environment before the test run.
        reset_context_settings()

    def test_env_override_cache_enabled(self) -> None:
        """Test that CTX_CACHE_ENABLED env var works."""
        with patch.dict(os.environ, {"CTX_CACHE_ENABLED": "false"}):
            reset_context_settings()
            settings = ContextSettings()
            assert settings.cache_enabled is False

    def test_env_override_cache_ttl(self) -> None:
        """Test that CTX_CACHE_TTL_SECONDS env var works."""
        with patch.dict(os.environ, {"CTX_CACHE_TTL_SECONDS": "7200"}):
            reset_context_settings()
            settings = ContextSettings()
            assert settings.cache_ttl_seconds == 7200

    def test_env_override_max_assembly_time(self) -> None:
        """Test that CTX_MAX_ASSEMBLY_TIME_MS env var works."""
        with patch.dict(os.environ, {"CTX_MAX_ASSEMBLY_TIME_MS": "200"}):
            reset_context_settings()
            settings = ContextSettings()
            assert settings.max_assembly_time_ms == 200
a/backend/tests/services/context/test_exceptions.py b/backend/tests/services/context/test_exceptions.py new file mode 100644 index 0000000..f987f76 --- /dev/null +++ b/backend/tests/services/context/test_exceptions.py @@ -0,0 +1,252 @@ +"""Tests for context management exceptions.""" + +import pytest + +from app.services.context.exceptions import ( + AssemblyTimeoutError, + BudgetExceededError, + CacheError, + CompressionError, + ContextError, + ContextNotFoundError, + FormattingError, + InvalidContextError, + ScoringError, + TokenCountError, +) + + +class TestContextError: + """Tests for base ContextError.""" + + def test_basic_initialization(self) -> None: + """Test basic error initialization.""" + error = ContextError("Test error") + assert error.message == "Test error" + assert error.details == {} + assert str(error) == "Test error" + + def test_with_details(self) -> None: + """Test error with details.""" + error = ContextError("Test error", {"key": "value", "count": 42}) + assert error.details == {"key": "value", "count": 42} + + def test_to_dict(self) -> None: + """Test conversion to dictionary.""" + error = ContextError("Test error", {"key": "value"}) + result = error.to_dict() + + assert result["error_type"] == "ContextError" + assert result["message"] == "Test error" + assert result["details"] == {"key": "value"} + + def test_inheritance(self) -> None: + """Test that ContextError inherits from Exception.""" + error = ContextError("Test") + assert isinstance(error, Exception) + + +class TestBudgetExceededError: + """Tests for BudgetExceededError.""" + + def test_default_message(self) -> None: + """Test default error message.""" + error = BudgetExceededError() + assert "exceeded" in error.message.lower() + + def test_with_budget_info(self) -> None: + """Test with budget information.""" + error = BudgetExceededError( + allocated=1000, + requested=1500, + context_type="knowledge", + ) + + assert error.allocated == 1000 + assert error.requested == 1500 + assert 
error.context_type == "knowledge" + assert error.details["overage"] == 500 + + def test_to_dict_includes_budget_info(self) -> None: + """Test that to_dict includes budget info.""" + error = BudgetExceededError( + allocated=1000, + requested=1500, + ) + result = error.to_dict() + + assert result["details"]["allocated"] == 1000 + assert result["details"]["requested"] == 1500 + assert result["details"]["overage"] == 500 + + +class TestTokenCountError: + """Tests for TokenCountError.""" + + def test_basic_error(self) -> None: + """Test basic token count error.""" + error = TokenCountError() + assert "token" in error.message.lower() + + def test_with_model_info(self) -> None: + """Test with model information.""" + error = TokenCountError( + message="Failed to count", + model="claude-3-sonnet", + text_length=5000, + ) + + assert error.model == "claude-3-sonnet" + assert error.text_length == 5000 + assert error.details["model"] == "claude-3-sonnet" + + +class TestCompressionError: + """Tests for CompressionError.""" + + def test_basic_error(self) -> None: + """Test basic compression error.""" + error = CompressionError() + assert "compress" in error.message.lower() + + def test_with_token_info(self) -> None: + """Test with token information.""" + error = CompressionError( + original_tokens=2000, + target_tokens=1000, + achieved_tokens=1500, + ) + + assert error.original_tokens == 2000 + assert error.target_tokens == 1000 + assert error.achieved_tokens == 1500 + + +class TestAssemblyTimeoutError: + """Tests for AssemblyTimeoutError.""" + + def test_basic_error(self) -> None: + """Test basic timeout error.""" + error = AssemblyTimeoutError() + assert "timed out" in error.message.lower() + + def test_with_timing_info(self) -> None: + """Test with timing information.""" + error = AssemblyTimeoutError( + timeout_ms=100, + elapsed_ms=150.5, + stage="scoring", + ) + + assert error.timeout_ms == 100 + assert error.elapsed_ms == 150.5 + assert error.stage == "scoring" + assert 
error.details["stage"] == "scoring" + + +class TestScoringError: + """Tests for ScoringError.""" + + def test_basic_error(self) -> None: + """Test basic scoring error.""" + error = ScoringError() + assert "score" in error.message.lower() + + def test_with_scorer_info(self) -> None: + """Test with scorer information.""" + error = ScoringError( + scorer_type="relevance", + context_id="ctx-123", + ) + + assert error.scorer_type == "relevance" + assert error.context_id == "ctx-123" + + +class TestFormattingError: + """Tests for FormattingError.""" + + def test_basic_error(self) -> None: + """Test basic formatting error.""" + error = FormattingError() + assert "format" in error.message.lower() + + def test_with_model_info(self) -> None: + """Test with model information.""" + error = FormattingError( + model="claude-3-opus", + adapter="ClaudeAdapter", + ) + + assert error.model == "claude-3-opus" + assert error.adapter == "ClaudeAdapter" + + +class TestCacheError: + """Tests for CacheError.""" + + def test_basic_error(self) -> None: + """Test basic cache error.""" + error = CacheError() + assert "cache" in error.message.lower() + + def test_with_operation_info(self) -> None: + """Test with operation information.""" + error = CacheError( + operation="get", + cache_key="ctx:abc123", + ) + + assert error.operation == "get" + assert error.cache_key == "ctx:abc123" + + +class TestContextNotFoundError: + """Tests for ContextNotFoundError.""" + + def test_basic_error(self) -> None: + """Test basic not found error.""" + error = ContextNotFoundError() + assert "not found" in error.message.lower() + + def test_with_source_info(self) -> None: + """Test with source information.""" + error = ContextNotFoundError( + source="knowledge-base", + query="authentication flow", + ) + + assert error.source == "knowledge-base" + assert error.query == "authentication flow" + + +class TestInvalidContextError: + """Tests for InvalidContextError.""" + + def test_basic_error(self) -> None: + 
"""Test basic invalid error.""" + error = InvalidContextError() + assert "invalid" in error.message.lower() + + def test_with_field_info(self) -> None: + """Test with field information.""" + error = InvalidContextError( + field="content", + value="", + reason="Content cannot be empty", + ) + + assert error.field == "content" + assert error.value == "" + assert error.reason == "Content cannot be empty" + + def test_value_type_only_in_details(self) -> None: + """Test that only value type is included in details (not actual value).""" + error = InvalidContextError( + field="api_key", + value="secret-key-here", + ) + + # Actual value should not be in details + assert "secret-key-here" not in str(error.details) + assert error.details["value_type"] == "str" diff --git a/backend/tests/services/context/test_types.py b/backend/tests/services/context/test_types.py new file mode 100644 index 0000000..0db53cc --- /dev/null +++ b/backend/tests/services/context/test_types.py @@ -0,0 +1,579 @@ +"""Tests for context types.""" + +import json +from datetime import UTC, datetime, timedelta + +import pytest + +from app.services.context.types import ( + AssembledContext, + BaseContext, + ContextPriority, + ContextType, + ConversationContext, + KnowledgeContext, + MessageRole, + SystemContext, + TaskComplexity, + TaskContext, + TaskStatus, + ToolContext, + ToolResultStatus, +) + + +class TestContextType: + """Tests for ContextType enum.""" + + def test_all_types_exist(self) -> None: + """Test that all expected context types exist.""" + assert ContextType.SYSTEM + assert ContextType.TASK + assert ContextType.KNOWLEDGE + assert ContextType.CONVERSATION + assert ContextType.TOOL + + def test_from_string_valid(self) -> None: + """Test from_string with valid values.""" + assert ContextType.from_string("system") == ContextType.SYSTEM + assert ContextType.from_string("KNOWLEDGE") == ContextType.KNOWLEDGE + assert ContextType.from_string("Task") == ContextType.TASK + + def 
test_from_string_invalid(self) -> None: + """Test from_string with invalid value.""" + with pytest.raises(ValueError) as exc_info: + ContextType.from_string("invalid") + assert "Invalid context type" in str(exc_info.value) + + +class TestContextPriority: + """Tests for ContextPriority enum.""" + + def test_priority_ordering(self) -> None: + """Test that priorities are ordered correctly.""" + assert ContextPriority.LOWEST.value < ContextPriority.LOW.value + assert ContextPriority.LOW.value < ContextPriority.NORMAL.value + assert ContextPriority.NORMAL.value < ContextPriority.HIGH.value + assert ContextPriority.HIGH.value < ContextPriority.HIGHEST.value + assert ContextPriority.HIGHEST.value < ContextPriority.CRITICAL.value + + def test_from_int(self) -> None: + """Test from_int conversion.""" + assert ContextPriority.from_int(0) == ContextPriority.LOWEST + assert ContextPriority.from_int(50) == ContextPriority.NORMAL + assert ContextPriority.from_int(100) == ContextPriority.HIGHEST + assert ContextPriority.from_int(200) == ContextPriority.CRITICAL + + def test_from_int_intermediate(self) -> None: + """Test from_int with intermediate values.""" + # Should return closest lower priority + assert ContextPriority.from_int(30) == ContextPriority.LOW + assert ContextPriority.from_int(60) == ContextPriority.NORMAL + + +class TestSystemContext: + """Tests for SystemContext.""" + + def test_creation(self) -> None: + """Test basic creation.""" + ctx = SystemContext( + content="You are a helpful assistant.", + source="system_prompt", + ) + + assert ctx.content == "You are a helpful assistant." 
+ assert ctx.source == "system_prompt" + assert ctx.get_type() == ContextType.SYSTEM + + def test_default_high_priority(self) -> None: + """Test that system context defaults to high priority.""" + ctx = SystemContext(content="Test", source="test") + assert ctx.priority == ContextPriority.HIGH.value + + def test_create_persona(self) -> None: + """Test create_persona factory method.""" + ctx = SystemContext.create_persona( + name="Code Assistant", + description="A helpful coding assistant.", + capabilities=["Write code", "Debug"], + constraints=["Never expose secrets"], + ) + + assert "Code Assistant" in ctx.content + assert "helpful coding assistant" in ctx.content + assert "Write code" in ctx.content + assert "Never expose secrets" in ctx.content + assert ctx.priority == ContextPriority.HIGHEST.value + + def test_create_instructions(self) -> None: + """Test create_instructions factory method.""" + ctx = SystemContext.create_instructions( + ["Always be helpful", "Be concise"], + source="rules", + ) + + assert "Always be helpful" in ctx.content + assert "Be concise" in ctx.content + + def test_to_dict(self) -> None: + """Test serialization to dict.""" + ctx = SystemContext( + content="Test", + source="test", + role="assistant", + instructions_type="general", + ) + + data = ctx.to_dict() + assert data["role"] == "assistant" + assert data["instructions_type"] == "general" + assert data["type"] == "system" + + +class TestKnowledgeContext: + """Tests for KnowledgeContext.""" + + def test_creation(self) -> None: + """Test basic creation.""" + ctx = KnowledgeContext( + content="def authenticate(user): ...", + source="/src/auth.py", + collection="code", + file_type="python", + ) + + assert ctx.content == "def authenticate(user): ..." 
+ assert ctx.source == "/src/auth.py" + assert ctx.collection == "code" + assert ctx.get_type() == ContextType.KNOWLEDGE + + def test_from_search_result(self) -> None: + """Test from_search_result factory method.""" + result = { + "content": "Test content", + "source_path": "/test/file.py", + "collection": "code", + "file_type": "python", + "chunk_index": 2, + "score": 0.85, + "id": "chunk-123", + } + + ctx = KnowledgeContext.from_search_result(result, "test query") + + assert ctx.content == "Test content" + assert ctx.source == "/test/file.py" + assert ctx.relevance_score == 0.85 + assert ctx.search_query == "test query" + + def test_from_search_results(self) -> None: + """Test from_search_results factory method.""" + results = [ + {"content": "Content 1", "source_path": "/a.py", "score": 0.9}, + {"content": "Content 2", "source_path": "/b.py", "score": 0.8}, + ] + + contexts = KnowledgeContext.from_search_results(results, "query") + + assert len(contexts) == 2 + assert contexts[0].relevance_score == 0.9 + assert contexts[1].source == "/b.py" + + def test_is_code(self) -> None: + """Test is_code method.""" + code_ctx = KnowledgeContext( + content="code", source="test", file_type="python" + ) + doc_ctx = KnowledgeContext( + content="docs", source="test", file_type="markdown" + ) + + assert code_ctx.is_code() is True + assert doc_ctx.is_code() is False + + def test_is_documentation(self) -> None: + """Test is_documentation method.""" + doc_ctx = KnowledgeContext( + content="docs", source="test", file_type="markdown" + ) + code_ctx = KnowledgeContext( + content="code", source="test", file_type="python" + ) + + assert doc_ctx.is_documentation() is True + assert code_ctx.is_documentation() is False + + +class TestConversationContext: + """Tests for ConversationContext.""" + + def test_creation(self) -> None: + """Test basic creation.""" + ctx = ConversationContext( + content="Hello, how can I help?", + source="conversation", + role=MessageRole.ASSISTANT, + 
turn_index=1, + ) + + assert ctx.content == "Hello, how can I help?" + assert ctx.role == MessageRole.ASSISTANT + assert ctx.get_type() == ContextType.CONVERSATION + + def test_from_message(self) -> None: + """Test from_message factory method.""" + ctx = ConversationContext.from_message( + content="What is Python?", + role="user", + turn_index=0, + ) + + assert ctx.content == "What is Python?" + assert ctx.role == MessageRole.USER + assert ctx.turn_index == 0 + + def test_from_history(self) -> None: + """Test from_history factory method.""" + messages = [ + {"role": "user", "content": "Hello"}, + {"role": "assistant", "content": "Hi there!"}, + {"role": "user", "content": "Help me"}, + ] + + contexts = ConversationContext.from_history(messages) + + assert len(contexts) == 3 + assert contexts[0].role == MessageRole.USER + assert contexts[1].role == MessageRole.ASSISTANT + assert contexts[2].turn_index == 2 + + def test_is_user_message(self) -> None: + """Test is_user_message method.""" + user_ctx = ConversationContext( + content="test", source="test", role=MessageRole.USER + ) + assistant_ctx = ConversationContext( + content="test", source="test", role=MessageRole.ASSISTANT + ) + + assert user_ctx.is_user_message() is True + assert assistant_ctx.is_user_message() is False + + def test_format_for_prompt(self) -> None: + """Test format_for_prompt method.""" + ctx = ConversationContext.from_message( + content="What is 2+2?", + role="user", + ) + + formatted = ctx.format_for_prompt() + assert "User:" in formatted + assert "What is 2+2?" 
in formatted + + +class TestTaskContext: + """Tests for TaskContext.""" + + def test_creation(self) -> None: + """Test basic creation.""" + ctx = TaskContext( + content="Implement login feature", + source="task", + title="Login Feature", + ) + + assert ctx.content == "Implement login feature" + assert ctx.title == "Login Feature" + assert ctx.get_type() == ContextType.TASK + + def test_default_high_priority(self) -> None: + """Test that task context defaults to high priority.""" + ctx = TaskContext(content="Test", source="test") + assert ctx.priority == ContextPriority.HIGH.value + + def test_create_factory(self) -> None: + """Test create factory method.""" + ctx = TaskContext.create( + title="Add Auth", + description="Implement authentication", + acceptance_criteria=["Tests pass", "Code reviewed"], + constraints=["Use JWT"], + issue_id="123", + ) + + assert ctx.title == "Add Auth" + assert ctx.content == "Implement authentication" + assert len(ctx.acceptance_criteria) == 2 + assert "Use JWT" in ctx.constraints + assert ctx.status == TaskStatus.IN_PROGRESS + + def test_format_for_prompt(self) -> None: + """Test format_for_prompt method.""" + ctx = TaskContext.create( + title="Test Task", + description="Do something", + acceptance_criteria=["Works correctly"], + ) + + formatted = ctx.format_for_prompt() + assert "Task: Test Task" in formatted + assert "Do something" in formatted + assert "Works correctly" in formatted + + def test_status_checks(self) -> None: + """Test status check methods.""" + pending = TaskContext( + content="test", source="test", status=TaskStatus.PENDING + ) + completed = TaskContext( + content="test", source="test", status=TaskStatus.COMPLETED + ) + blocked = TaskContext( + content="test", source="test", status=TaskStatus.BLOCKED + ) + + assert pending.is_active() is True + assert completed.is_complete() is True + assert blocked.is_blocked() is True + + +class TestToolContext: + """Tests for ToolContext.""" + + def test_creation(self) -> None: 
+ """Test basic creation.""" + ctx = ToolContext( + content="Tool result here", + source="tool:search", + tool_name="search", + ) + + assert ctx.tool_name == "search" + assert ctx.get_type() == ContextType.TOOL + + def test_from_tool_definition(self) -> None: + """Test from_tool_definition factory method.""" + ctx = ToolContext.from_tool_definition( + name="search_knowledge", + description="Search the knowledge base", + parameters={ + "query": {"type": "string", "required": True}, + "limit": {"type": "integer", "required": False}, + }, + server_name="knowledge-base", + ) + + assert ctx.tool_name == "search_knowledge" + assert "Search the knowledge base" in ctx.content + assert ctx.is_result is False + assert ctx.server_name == "knowledge-base" + + def test_from_tool_result(self) -> None: + """Test from_tool_result factory method.""" + ctx = ToolContext.from_tool_result( + tool_name="search", + result={"found": 5, "items": ["a", "b"]}, + status=ToolResultStatus.SUCCESS, + execution_time_ms=150.5, + ) + + assert ctx.tool_name == "search" + assert ctx.is_result is True + assert ctx.result_status == ToolResultStatus.SUCCESS + assert "found" in ctx.content + + def test_is_successful(self) -> None: + """Test is_successful method.""" + success = ToolContext.from_tool_result( + "test", "ok", ToolResultStatus.SUCCESS + ) + error = ToolContext.from_tool_result( + "test", "error", ToolResultStatus.ERROR + ) + + assert success.is_successful() is True + assert error.is_successful() is False + + def test_format_for_prompt(self) -> None: + """Test format_for_prompt method.""" + ctx = ToolContext.from_tool_result( + "search", + "Found 3 results", + ToolResultStatus.SUCCESS, + ) + + formatted = ctx.format_for_prompt() + assert "Tool Result" in formatted + assert "search" in formatted + assert "success" in formatted + + +class TestAssembledContext: + """Tests for AssembledContext.""" + + def test_creation(self) -> None: + """Test basic creation.""" + ctx = AssembledContext( + 
content="Assembled content here", + token_count=500, + contexts_included=5, + ) + + assert ctx.content == "Assembled content here" + assert ctx.token_count == 500 + assert ctx.contexts_included == 5 + + def test_budget_utilization(self) -> None: + """Test budget_utilization property.""" + ctx = AssembledContext( + content="test", + token_count=800, + contexts_included=5, + budget_total=1000, + budget_used=800, + ) + + assert ctx.budget_utilization == 0.8 + + def test_budget_utilization_zero_budget(self) -> None: + """Test budget_utilization with zero budget.""" + ctx = AssembledContext( + content="test", + token_count=0, + contexts_included=0, + budget_total=0, + budget_used=0, + ) + + assert ctx.budget_utilization == 0.0 + + def test_to_dict(self) -> None: + """Test to_dict method.""" + ctx = AssembledContext( + content="test", + token_count=100, + contexts_included=2, + assembly_time_ms=50.123, + ) + + data = ctx.to_dict() + assert data["content"] == "test" + assert data["token_count"] == 100 + assert data["assembly_time_ms"] == 50.12 # Rounded + + def test_to_json_and_from_json(self) -> None: + """Test JSON serialization round-trip.""" + original = AssembledContext( + content="Test content", + token_count=100, + contexts_included=3, + contexts_excluded=2, + assembly_time_ms=45.5, + budget_total=1000, + budget_used=100, + by_type={"system": 20, "knowledge": 80}, + cache_hit=True, + cache_key="abc123", + ) + + json_str = original.to_json() + restored = AssembledContext.from_json(json_str) + + assert restored.content == original.content + assert restored.token_count == original.token_count + assert restored.contexts_included == original.contexts_included + assert restored.cache_hit == original.cache_hit + assert restored.cache_key == original.cache_key + + +class TestBaseContextMethods: + """Tests for BaseContext methods.""" + + def test_get_age_seconds(self) -> None: + """Test get_age_seconds method.""" + old_time = datetime.now(UTC) - timedelta(hours=2) + ctx = 
SystemContext( + content="test", source="test", timestamp=old_time + ) + + age = ctx.get_age_seconds() + # Should be approximately 2 hours in seconds + assert 7100 < age < 7300 + + def test_get_age_hours(self) -> None: + """Test get_age_hours method.""" + old_time = datetime.now(UTC) - timedelta(hours=5) + ctx = SystemContext( + content="test", source="test", timestamp=old_time + ) + + age = ctx.get_age_hours() + assert 4.9 < age < 5.1 + + def test_is_stale(self) -> None: + """Test is_stale method.""" + old_time = datetime.now(UTC) - timedelta(days=10) + new_time = datetime.now(UTC) - timedelta(hours=1) + + old_ctx = SystemContext( + content="test", source="test", timestamp=old_time + ) + new_ctx = SystemContext( + content="test", source="test", timestamp=new_time + ) + + # Default max_age is 168 hours (7 days) + assert old_ctx.is_stale() is True + assert new_ctx.is_stale() is False + + def test_token_count_property(self) -> None: + """Test token_count property.""" + ctx = SystemContext(content="test", source="test") + + # Initially None + assert ctx.token_count is None + + # Can be set + ctx.token_count = 100 + assert ctx.token_count == 100 + + def test_score_property_clamping(self) -> None: + """Test that score is clamped to 0.0-1.0.""" + ctx = SystemContext(content="test", source="test") + + ctx.score = 1.5 + assert ctx.score == 1.0 + + ctx.score = -0.5 + assert ctx.score == 0.0 + + ctx.score = 0.75 + assert ctx.score == 0.75 + + def test_hash_and_equality(self) -> None: + """Test hash and equality based on ID.""" + ctx1 = SystemContext(content="test", source="test") + ctx2 = SystemContext(content="test", source="test") + ctx3 = SystemContext(content="test", source="test") + ctx3.id = ctx1.id # Same ID as ctx1 + + # Different IDs = not equal + assert ctx1 != ctx2 + + # Same ID = equal + assert ctx1 == ctx3 + + # Can be used in sets + context_set = {ctx1, ctx2, ctx3} + assert len(context_set) == 2 # ctx1 and ctx3 are same + + def test_truncate(self) -> None: + 
"""Test truncate method.""" + long_content = "word " * 1000 # Long content + ctx = SystemContext(content=long_content, source="test") + ctx.token_count = 1000 + + truncated = ctx.truncate(100) + + assert len(truncated) < len(long_content) + assert "[truncated]" in truncated