forked from cardosofelipe/fast-next-template
feat(context): Phase 1 - Foundation types, config and exceptions (#79)
Implements the foundation for the Context Management Engine.

Types (backend/app/services/context/types/):
- BaseContext: Abstract base with ID, content, priority, scoring
- SystemContext: System prompts, personas, instructions
- KnowledgeContext: RAG results from Knowledge Base MCP
- ConversationContext: Chat history with role support
- TaskContext: Task/issue context with acceptance criteria
- ToolContext: Tool definitions and execution results
- AssembledContext: Final assembled context result

Configuration (config.py):
- Token budget allocation (system 5%, task 10%, knowledge 40%, etc.)
- Scoring weights (relevance 50%, recency 30%, priority 20%)
- Cache settings (TTL, prefix)
- Performance settings (max assembly time, parallel scoring)
- Environment variable overrides with CTX_ prefix

Exceptions (exceptions.py):
- ContextError: Base exception
- BudgetExceededError: Token budget violations
- TokenCountError: Token counting failures
- CompressionError: Compression failures
- AssemblyTimeoutError: Assembly timeout
- ScoringError, FormattingError, CacheError
- ContextNotFoundError, InvalidContextError

All 86 tests pass.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
105
backend/app/services/context/__init__.py
Normal file
@@ -0,0 +1,105 @@
"""
Context Management Engine

Sophisticated context assembly and optimization for LLM requests.
Provides intelligent context selection, token budget management,
and model-specific formatting.

Usage:
    from app.services.context import (
        ContextSettings,
        get_context_settings,
        SystemContext,
        KnowledgeContext,
        ConversationContext,
        TaskContext,
        ToolContext,
    )

    # Get settings
    settings = get_context_settings()

    # Create context instances
    system_ctx = SystemContext.create_persona(
        name="Code Assistant",
        description="You are a helpful code assistant.",
        capabilities=["Write code", "Debug issues"],
    )
"""

# Configuration
from .config import (
    ContextSettings,
    get_context_settings,
    get_default_settings,
    reset_context_settings,
)

# Exceptions
from .exceptions import (
    AssemblyTimeoutError,
    BudgetExceededError,
    CacheError,
    CompressionError,
    ContextError,
    ContextNotFoundError,
    FormattingError,
    InvalidContextError,
    ScoringError,
    TokenCountError,
)

# Types
from .types import (
    AssembledContext,
    BaseContext,
    ContextPriority,
    ContextType,
    ConversationContext,
    KnowledgeContext,
    MessageRole,
    SystemContext,
    TaskComplexity,
    TaskContext,
    TaskStatus,
    ToolContext,
    ToolResultStatus,
)

__all__ = [
    # Configuration
    "ContextSettings",
    "get_context_settings",
    "get_default_settings",
    "reset_context_settings",
    # Exceptions
    "AssemblyTimeoutError",
    "BudgetExceededError",
    "CacheError",
    "CompressionError",
    "ContextError",
    "ContextNotFoundError",
    "FormattingError",
    "InvalidContextError",
    "ScoringError",
    "TokenCountError",
    # Types - Base
    "AssembledContext",
    "BaseContext",
    "ContextPriority",
    "ContextType",
    # Types - Conversation
    "ConversationContext",
    "MessageRole",
    # Types - Knowledge
    "KnowledgeContext",
    # Types - System
    "SystemContext",
    # Types - Task
    "TaskComplexity",
    "TaskContext",
    "TaskStatus",
    # Types - Tool
    "ToolContext",
    "ToolResultStatus",
]
5
backend/app/services/context/adapters/__init__.py
Normal file
@@ -0,0 +1,5 @@
"""
Model Adapters Module.

Provides model-specific context formatting.
"""
5
backend/app/services/context/assembly/__init__.py
Normal file
@@ -0,0 +1,5 @@
"""
Context Assembly Module.

Provides the assembly pipeline and formatting.
"""
5
backend/app/services/context/budget/__init__.py
Normal file
@@ -0,0 +1,5 @@
"""
Token Budget Management Module.

Provides token counting and budget allocation.
"""
5
backend/app/services/context/cache/__init__.py
vendored
Normal file
@@ -0,0 +1,5 @@
"""
Context Cache Module.

Provides Redis-based caching for assembled contexts.
"""
5
backend/app/services/context/compression/__init__.py
Normal file
@@ -0,0 +1,5 @@
"""
Context Compression Module.

Provides truncation and compression strategies.
"""
328
backend/app/services/context/config.py
Normal file
@@ -0,0 +1,328 @@
"""
Context Management Engine Configuration.

Provides Pydantic settings for context assembly,
token budget allocation, and caching.
"""

import threading
from functools import lru_cache
from typing import Any

from pydantic import Field, field_validator, model_validator
from pydantic_settings import BaseSettings


class ContextSettings(BaseSettings):
    """
    Configuration for the Context Management Engine.

    All settings can be overridden via environment variables
    with the CTX_ prefix.
    """

    # Budget allocation percentages (must sum to 1.0)
    budget_system: float = Field(
        default=0.05,
        ge=0.0,
        le=1.0,
        description="Percentage of budget for system prompts (5%)",
    )
    budget_task: float = Field(
        default=0.10,
        ge=0.0,
        le=1.0,
        description="Percentage of budget for task context (10%)",
    )
    budget_knowledge: float = Field(
        default=0.40,
        ge=0.0,
        le=1.0,
        description="Percentage of budget for RAG/knowledge (40%)",
    )
    budget_conversation: float = Field(
        default=0.20,
        ge=0.0,
        le=1.0,
        description="Percentage of budget for conversation history (20%)",
    )
    budget_tools: float = Field(
        default=0.05,
        ge=0.0,
        le=1.0,
        description="Percentage of budget for tool descriptions (5%)",
    )
    budget_response: float = Field(
        default=0.15,
        ge=0.0,
        le=1.0,
        description="Percentage reserved for response (15%)",
    )
    budget_buffer: float = Field(
        default=0.05,
        ge=0.0,
        le=1.0,
        description="Percentage buffer for safety margin (5%)",
    )

    # Scoring weights
    scoring_relevance_weight: float = Field(
        default=0.5,
        ge=0.0,
        le=1.0,
        description="Weight for relevance scoring",
    )
    scoring_recency_weight: float = Field(
        default=0.3,
        ge=0.0,
        le=1.0,
        description="Weight for recency scoring",
    )
    scoring_priority_weight: float = Field(
        default=0.2,
        ge=0.0,
        le=1.0,
        description="Weight for priority scoring",
    )

    # Recency decay settings
    recency_decay_hours: float = Field(
        default=24.0,
        gt=0.0,
        description="Hours until recency score decays to 50%",
    )
    recency_max_age_hours: float = Field(
        default=168.0,
        gt=0.0,
        description="Hours until context is considered stale (7 days)",
    )

    # Compression settings
    compression_threshold: float = Field(
        default=0.8,
        ge=0.0,
        le=1.0,
        description="Compress when budget usage exceeds this percentage",
    )
    truncation_suffix: str = Field(
        default="... [truncated]",
        description="Suffix to add when truncating content",
    )
    summary_model_group: str = Field(
        default="fast",
        description="Model group to use for summarization",
    )

    # Caching settings
    cache_enabled: bool = Field(
        default=True,
        description="Enable Redis caching for assembled contexts",
    )
    cache_ttl_seconds: int = Field(
        default=3600,
        ge=60,
        le=86400,
        description="Cache TTL in seconds (1 hour default, max 24 hours)",
    )
    cache_prefix: str = Field(
        default="ctx",
        description="Redis key prefix for context cache",
    )

    # Performance settings
    max_assembly_time_ms: int = Field(
        default=100,
        ge=10,
        le=5000,
        description="Maximum time for context assembly in milliseconds",
    )
    parallel_scoring: bool = Field(
        default=True,
        description="Score contexts in parallel for better performance",
    )
    max_parallel_scores: int = Field(
        default=10,
        ge=1,
        le=50,
        description="Maximum number of contexts to score in parallel",
    )

    # Knowledge retrieval settings
    knowledge_search_type: str = Field(
        default="hybrid",
        description="Default search type for knowledge retrieval",
    )
    knowledge_max_results: int = Field(
        default=10,
        ge=1,
        le=50,
        description="Maximum knowledge chunks to retrieve",
    )
    knowledge_min_score: float = Field(
        default=0.5,
        ge=0.0,
        le=1.0,
        description="Minimum relevance score for knowledge",
    )

    # Conversation history settings
    conversation_max_turns: int = Field(
        default=20,
        ge=1,
        le=100,
        description="Maximum conversation turns to include",
    )
    conversation_recent_priority: bool = Field(
        default=True,
        description="Prioritize recent conversation turns",
    )

    @field_validator("knowledge_search_type")
    @classmethod
    def validate_search_type(cls, v: str) -> str:
        """Validate search type is valid."""
        valid_types = {"semantic", "keyword", "hybrid"}
        if v not in valid_types:
            raise ValueError(f"search_type must be one of: {valid_types}")
        return v

    @model_validator(mode="after")
    def validate_budget_allocation(self) -> "ContextSettings":
        """Validate that budget percentages sum to 1.0."""
        total = (
            self.budget_system
            + self.budget_task
            + self.budget_knowledge
            + self.budget_conversation
            + self.budget_tools
            + self.budget_response
            + self.budget_buffer
        )
        # Allow small floating point error
        if abs(total - 1.0) > 0.001:
            raise ValueError(
                f"Budget percentages must sum to 1.0, got {total:.3f}. "
                f"Current allocation: system={self.budget_system}, task={self.budget_task}, "
                f"knowledge={self.budget_knowledge}, conversation={self.budget_conversation}, "
                f"tools={self.budget_tools}, response={self.budget_response}, buffer={self.budget_buffer}"
            )
        return self

    @model_validator(mode="after")
    def validate_scoring_weights(self) -> "ContextSettings":
        """Validate that scoring weights sum to 1.0."""
        total = (
            self.scoring_relevance_weight
            + self.scoring_recency_weight
            + self.scoring_priority_weight
        )
        # Allow small floating point error
        if abs(total - 1.0) > 0.001:
            raise ValueError(
                f"Scoring weights must sum to 1.0, got {total:.3f}. "
                f"Current weights: relevance={self.scoring_relevance_weight}, "
                f"recency={self.scoring_recency_weight}, priority={self.scoring_priority_weight}"
            )
        return self

    def get_budget_allocation(self) -> dict[str, float]:
        """Get budget allocation as a dictionary."""
        return {
            "system": self.budget_system,
            "task": self.budget_task,
            "knowledge": self.budget_knowledge,
            "conversation": self.budget_conversation,
            "tools": self.budget_tools,
            "response": self.budget_response,
            "buffer": self.budget_buffer,
        }

    def get_scoring_weights(self) -> dict[str, float]:
        """Get scoring weights as a dictionary."""
        return {
            "relevance": self.scoring_relevance_weight,
            "recency": self.scoring_recency_weight,
            "priority": self.scoring_priority_weight,
        }

    def to_dict(self) -> dict[str, Any]:
        """Convert settings to dictionary for logging/debugging."""
        return {
            "budget": self.get_budget_allocation(),
            "scoring": self.get_scoring_weights(),
            "compression": {
                "threshold": self.compression_threshold,
                "summary_model_group": self.summary_model_group,
            },
            "cache": {
                "enabled": self.cache_enabled,
                "ttl_seconds": self.cache_ttl_seconds,
                "prefix": self.cache_prefix,
            },
            "performance": {
                "max_assembly_time_ms": self.max_assembly_time_ms,
                "parallel_scoring": self.parallel_scoring,
                "max_parallel_scores": self.max_parallel_scores,
            },
            "knowledge": {
                "search_type": self.knowledge_search_type,
                "max_results": self.knowledge_max_results,
                "min_score": self.knowledge_min_score,
            },
            "conversation": {
                "max_turns": self.conversation_max_turns,
                "recent_priority": self.conversation_recent_priority,
            },
        }

    model_config = {
        "env_prefix": "CTX_",
        "env_file": "../.env",
        "env_file_encoding": "utf-8",
        "case_sensitive": False,
        "extra": "ignore",
    }


# Thread-safe singleton pattern
_settings: ContextSettings | None = None
_settings_lock = threading.Lock()


def get_context_settings() -> ContextSettings:
    """
    Get the global ContextSettings instance.

    Thread-safe with double-checked locking pattern.

    Returns:
        ContextSettings instance
    """
    global _settings
    if _settings is None:
        with _settings_lock:
            if _settings is None:
                _settings = ContextSettings()
    return _settings


def reset_context_settings() -> None:
    """
    Reset the global settings instance.

    Primarily used for testing.
    """
    global _settings
    with _settings_lock:
        _settings = None


@lru_cache(maxsize=1)
def get_default_settings() -> ContextSettings:
    """
    Get default settings (cached).

    Use this for read-only access to defaults.
    For mutable access, use get_context_settings().
    """
    return ContextSettings()
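The settings above behave like any other pydantic-settings model: every field can be overridden through a CTX_-prefixed environment variable, and the two model validators re-check the budget and scoring sums each time the model is instantiated. A minimal sketch of that flow (the override values below are illustrative, not defaults taken from the code):

```python
import os

from app.services.context import get_context_settings, reset_context_settings

# Shift 5% of the budget from knowledge to conversation; the total must stay at 1.0,
# otherwise validate_budget_allocation() raises a ValueError on the next instantiation.
os.environ["CTX_BUDGET_KNOWLEDGE"] = "0.35"
os.environ["CTX_BUDGET_CONVERSATION"] = "0.25"

reset_context_settings()           # drop the cached singleton so the override is picked up
settings = get_context_settings()  # thread-safe, rebuilt from env vars plus defaults

print(settings.get_budget_allocation())
# {'system': 0.05, 'task': 0.1, 'knowledge': 0.35, 'conversation': 0.25,
#  'tools': 0.05, 'response': 0.15, 'buffer': 0.05}
```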
354
backend/app/services/context/exceptions.py
Normal file
@@ -0,0 +1,354 @@
"""
Context Management Engine Exceptions.

Provides a hierarchy of exceptions for context assembly,
token budget management, and related operations.
"""

from typing import Any


class ContextError(Exception):
    """
    Base exception for all context management errors.

    All context-related exceptions should inherit from this class
    to allow for catch-all handling when needed.
    """

    def __init__(self, message: str, details: dict[str, Any] | None = None) -> None:
        """
        Initialize context error.

        Args:
            message: Human-readable error message
            details: Optional dict with additional error context
        """
        self.message = message
        self.details = details or {}
        super().__init__(message)

    def to_dict(self) -> dict[str, Any]:
        """Convert exception to dictionary for logging/serialization."""
        return {
            "error_type": self.__class__.__name__,
            "message": self.message,
            "details": self.details,
        }


class BudgetExceededError(ContextError):
    """
    Raised when token budget is exceeded.

    This occurs when the assembled context would exceed the
    allocated token budget for a specific context type or total.
    """

    def __init__(
        self,
        message: str = "Token budget exceeded",
        allocated: int = 0,
        requested: int = 0,
        context_type: str | None = None,
    ) -> None:
        """
        Initialize budget exceeded error.

        Args:
            message: Error message
            allocated: Tokens allocated for this context type
            requested: Tokens requested
            context_type: Type of context that exceeded budget
        """
        details = {
            "allocated": allocated,
            "requested": requested,
            "overage": requested - allocated,
        }
        if context_type:
            details["context_type"] = context_type

        super().__init__(message, details)
        self.allocated = allocated
        self.requested = requested
        self.context_type = context_type


class TokenCountError(ContextError):
    """
    Raised when token counting fails.

    This typically occurs when the LLM Gateway token counting
    service is unavailable or returns an error.
    """

    def __init__(
        self,
        message: str = "Failed to count tokens",
        model: str | None = None,
        text_length: int | None = None,
    ) -> None:
        """
        Initialize token count error.

        Args:
            message: Error message
            model: Model for which counting was attempted
            text_length: Length of text that failed to count
        """
        details: dict[str, Any] = {}
        if model:
            details["model"] = model
        if text_length is not None:
            details["text_length"] = text_length

        super().__init__(message, details)
        self.model = model
        self.text_length = text_length


class CompressionError(ContextError):
    """
    Raised when context compression fails.

    This can occur when summarization or truncation cannot
    reduce content to fit within the budget.
    """

    def __init__(
        self,
        message: str = "Failed to compress context",
        original_tokens: int | None = None,
        target_tokens: int | None = None,
        achieved_tokens: int | None = None,
    ) -> None:
        """
        Initialize compression error.

        Args:
            message: Error message
            original_tokens: Tokens before compression
            target_tokens: Target token count
            achieved_tokens: Tokens achieved after compression attempt
        """
        details: dict[str, Any] = {}
        if original_tokens is not None:
            details["original_tokens"] = original_tokens
        if target_tokens is not None:
            details["target_tokens"] = target_tokens
        if achieved_tokens is not None:
            details["achieved_tokens"] = achieved_tokens

        super().__init__(message, details)
        self.original_tokens = original_tokens
        self.target_tokens = target_tokens
        self.achieved_tokens = achieved_tokens


class AssemblyTimeoutError(ContextError):
    """
    Raised when context assembly exceeds time limit.

    Context assembly must complete within a configurable
    time limit to maintain responsiveness.
    """

    def __init__(
        self,
        message: str = "Context assembly timed out",
        timeout_ms: int = 0,
        elapsed_ms: float = 0.0,
        stage: str | None = None,
    ) -> None:
        """
        Initialize assembly timeout error.

        Args:
            message: Error message
            timeout_ms: Configured timeout in milliseconds
            elapsed_ms: Actual elapsed time in milliseconds
            stage: Pipeline stage where timeout occurred
        """
        details = {
            "timeout_ms": timeout_ms,
            "elapsed_ms": round(elapsed_ms, 2),
        }
        if stage:
            details["stage"] = stage

        super().__init__(message, details)
        self.timeout_ms = timeout_ms
        self.elapsed_ms = elapsed_ms
        self.stage = stage


class ScoringError(ContextError):
    """
    Raised when context scoring fails.

    This occurs when relevance, recency, or priority scoring
    encounters an error.
    """

    def __init__(
        self,
        message: str = "Failed to score context",
        scorer_type: str | None = None,
        context_id: str | None = None,
    ) -> None:
        """
        Initialize scoring error.

        Args:
            message: Error message
            scorer_type: Type of scorer that failed
            context_id: ID of context being scored
        """
        details: dict[str, Any] = {}
        if scorer_type:
            details["scorer_type"] = scorer_type
        if context_id:
            details["context_id"] = context_id

        super().__init__(message, details)
        self.scorer_type = scorer_type
        self.context_id = context_id


class FormattingError(ContextError):
    """
    Raised when context formatting fails.

    This occurs when converting assembled context to
    model-specific format fails.
    """

    def __init__(
        self,
        message: str = "Failed to format context",
        model: str | None = None,
        adapter: str | None = None,
    ) -> None:
        """
        Initialize formatting error.

        Args:
            message: Error message
            model: Target model
            adapter: Adapter that failed
        """
        details: dict[str, Any] = {}
        if model:
            details["model"] = model
        if adapter:
            details["adapter"] = adapter

        super().__init__(message, details)
        self.model = model
        self.adapter = adapter


class CacheError(ContextError):
    """
    Raised when cache operations fail.

    This is typically non-fatal and should be handled
    gracefully by falling back to recomputation.
    """

    def __init__(
        self,
        message: str = "Cache operation failed",
        operation: str | None = None,
        cache_key: str | None = None,
    ) -> None:
        """
        Initialize cache error.

        Args:
            message: Error message
            operation: Cache operation that failed (get, set, delete)
            cache_key: Key involved in the failed operation
        """
        details: dict[str, Any] = {}
        if operation:
            details["operation"] = operation
        if cache_key:
            details["cache_key"] = cache_key

        super().__init__(message, details)
        self.operation = operation
        self.cache_key = cache_key


class ContextNotFoundError(ContextError):
    """
    Raised when expected context is not found.

    This occurs when required context sources return
    no results or are unavailable.
    """

    def __init__(
        self,
        message: str = "Required context not found",
        source: str | None = None,
        query: str | None = None,
    ) -> None:
        """
        Initialize context not found error.

        Args:
            message: Error message
            source: Source that returned no results
            query: Query used to search
        """
        details: dict[str, Any] = {}
        if source:
            details["source"] = source
        if query:
            details["query"] = query

        super().__init__(message, details)
        self.source = source
        self.query = query


class InvalidContextError(ContextError):
    """
    Raised when context data is invalid.

    This occurs when context content or metadata
    fails validation.
    """

    def __init__(
        self,
        message: str = "Invalid context data",
        field: str | None = None,
        value: Any | None = None,
        reason: str | None = None,
    ) -> None:
        """
        Initialize invalid context error.

        Args:
            message: Error message
            field: Field that is invalid
            value: Invalid value (may be redacted for security)
            reason: Reason for invalidity
        """
        details: dict[str, Any] = {}
        if field:
            details["field"] = field
        if value is not None:
            # Avoid logging potentially sensitive values
            details["value_type"] = type(value).__name__
        if reason:
            details["reason"] = reason

        super().__init__(message, details)
        self.field = field
        self.value = value
        self.reason = reason
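Because every exception above derives from ContextError and carries a structured details dict, callers can catch the narrow type they care about and still log a uniform payload via to_dict(). A hedged sketch of that pattern; the assemble callable is a stand-in for whatever assembly code ends up raising these errors:

```python
import logging

from app.services.context import BudgetExceededError, CacheError, ContextError

logger = logging.getLogger(__name__)


def assemble_with_fallback(assemble):
    """Run a context-assembly callable and translate its errors (assemble is a placeholder)."""
    try:
        return assemble()
    except BudgetExceededError as exc:
        # Structured fields are available both as attributes and via to_dict()
        logger.warning("over budget by %d tokens: %s", exc.requested - exc.allocated, exc.to_dict())
        raise
    except CacheError as exc:
        # Cache failures are documented as non-fatal: log and retry without the cache path
        logger.info("cache failure, recomputing: %s", exc.to_dict())
        return assemble()
    except ContextError as exc:
        # Catch-all for anything else in the hierarchy
        logger.error("context assembly failed: %s", exc.to_dict())
        raise
```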
5
backend/app/services/context/prioritization/__init__.py
Normal file
@@ -0,0 +1,5 @@
"""
Context Prioritization Module.

Provides context ranking and selection.
"""
5
backend/app/services/context/scoring/__init__.py
Normal file
@@ -0,0 +1,5 @@
"""
Context Scoring Module.

Provides relevance, recency, and composite scoring.
"""
49
backend/app/services/context/types/__init__.py
Normal file
@@ -0,0 +1,49 @@
"""
Context Types Module.

Provides all context types used in the Context Management Engine.
"""

from .base import (
    AssembledContext,
    BaseContext,
    ContextPriority,
    ContextType,
)
from .conversation import (
    ConversationContext,
    MessageRole,
)
from .knowledge import KnowledgeContext
from .system import SystemContext
from .task import (
    TaskComplexity,
    TaskContext,
    TaskStatus,
)
from .tool import (
    ToolContext,
    ToolResultStatus,
)

__all__ = [
    # Base types
    "AssembledContext",
    "BaseContext",
    "ContextPriority",
    "ContextType",
    # Conversation
    "ConversationContext",
    "MessageRole",
    # Knowledge
    "KnowledgeContext",
    # System
    "SystemContext",
    # Task
    "TaskComplexity",
    "TaskContext",
    "TaskStatus",
    # Tool
    "ToolContext",
    "ToolResultStatus",
]
320
backend/app/services/context/types/base.py
Normal file
@@ -0,0 +1,320 @@
"""
Base Context Types and Enums.

Provides the foundation for all context types used in
the Context Management Engine.
"""

from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from datetime import UTC, datetime
from enum import Enum
from typing import Any
from uuid import uuid4


class ContextType(str, Enum):
    """
    Types of context that can be assembled.

    Each type has specific handling, formatting, and
    budget allocation rules.
    """

    SYSTEM = "system"
    TASK = "task"
    KNOWLEDGE = "knowledge"
    CONVERSATION = "conversation"
    TOOL = "tool"

    @classmethod
    def from_string(cls, value: str) -> "ContextType":
        """
        Convert string to ContextType.

        Args:
            value: String value

        Returns:
            ContextType enum value

        Raises:
            ValueError: If value is not a valid context type
        """
        try:
            return cls(value.lower())
        except ValueError:
            valid = ", ".join(t.value for t in cls)
            raise ValueError(f"Invalid context type '{value}'. Valid types: {valid}")


class ContextPriority(int, Enum):
    """
    Priority levels for context ordering.

    Higher values indicate higher priority.
    """

    LOWEST = 0
    LOW = 25
    NORMAL = 50
    HIGH = 75
    HIGHEST = 100
    CRITICAL = 150  # Never omit

    @classmethod
    def from_int(cls, value: int) -> "ContextPriority":
        """
        Get closest priority level for an integer.

        Args:
            value: Integer priority value

        Returns:
            Closest ContextPriority enum value
        """
        priorities = sorted(cls, key=lambda p: p.value)
        for priority in reversed(priorities):
            if value >= priority.value:
                return priority
        return cls.LOWEST


@dataclass(eq=False)
class BaseContext(ABC):
    """
    Abstract base class for all context types.

    Provides common fields and methods for context handling,
    scoring, and serialization.
    """

    # Required fields
    content: str
    source: str

    # Optional fields with defaults
    id: str = field(default_factory=lambda: str(uuid4()))
    timestamp: datetime = field(default_factory=lambda: datetime.now(UTC))
    priority: int = field(default=ContextPriority.NORMAL.value)
    metadata: dict[str, Any] = field(default_factory=dict)

    # Computed/cached fields
    _token_count: int | None = field(default=None, repr=False)
    _score: float | None = field(default=None, repr=False)

    @property
    def token_count(self) -> int | None:
        """Get cached token count (None if not counted yet)."""
        return self._token_count

    @token_count.setter
    def token_count(self, value: int) -> None:
        """Set token count."""
        self._token_count = value

    @property
    def score(self) -> float | None:
        """Get cached score (None if not scored yet)."""
        return self._score

    @score.setter
    def score(self, value: float) -> None:
        """Set score (clamped to 0.0-1.0)."""
        self._score = max(0.0, min(1.0, value))

    @abstractmethod
    def get_type(self) -> ContextType:
        """
        Get the type of this context.

        Returns:
            ContextType enum value
        """
        ...

    def get_age_seconds(self) -> float:
        """
        Get age of context in seconds.

        Returns:
            Age in seconds since creation
        """
        now = datetime.now(UTC)
        delta = now - self.timestamp
        return delta.total_seconds()

    def get_age_hours(self) -> float:
        """
        Get age of context in hours.

        Returns:
            Age in hours since creation
        """
        return self.get_age_seconds() / 3600

    def is_stale(self, max_age_hours: float = 168.0) -> bool:
        """
        Check if context is stale.

        Args:
            max_age_hours: Maximum age before considered stale (default 7 days)

        Returns:
            True if context is older than max_age_hours
        """
        return self.get_age_hours() > max_age_hours

    def to_dict(self) -> dict[str, Any]:
        """
        Convert context to dictionary for serialization.

        Returns:
            Dictionary representation
        """
        return {
            "id": self.id,
            "type": self.get_type().value,
            "content": self.content,
            "source": self.source,
            "timestamp": self.timestamp.isoformat(),
            "priority": self.priority,
            "metadata": self.metadata,
            "token_count": self._token_count,
            "score": self._score,
        }

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "BaseContext":
        """
        Create context from dictionary.

        Note: Subclasses should override this to return correct type.

        Args:
            data: Dictionary with context data

        Returns:
            Context instance
        """
        raise NotImplementedError("Subclasses must implement from_dict")

    def truncate(self, max_tokens: int, suffix: str = "... [truncated]") -> str:
        """
        Truncate content to fit within token limit.

        This is a rough estimation based on characters.
        For accurate truncation, use the TokenCalculator.

        Args:
            max_tokens: Maximum tokens allowed
            suffix: Suffix to append when truncated

        Returns:
            Truncated content
        """
        if self._token_count is None or self._token_count <= max_tokens:
            return self.content

        # Rough estimation: 4 chars per token on average
        estimated_chars = max_tokens * 4
        suffix_chars = len(suffix)

        if len(self.content) <= estimated_chars:
            return self.content

        truncated = self.content[: estimated_chars - suffix_chars]
        # Try to break at word boundary
        last_space = truncated.rfind(" ")
        if last_space > estimated_chars * 0.8:
            truncated = truncated[:last_space]

        return truncated + suffix

    def __hash__(self) -> int:
        """Hash based on ID for set/dict usage."""
        return hash(self.id)

    def __eq__(self, other: object) -> bool:
        """Equality based on ID."""
        if not isinstance(other, BaseContext):
            return False
        return self.id == other.id


@dataclass
class AssembledContext:
    """
    Result of context assembly.

    Contains the final formatted context ready for LLM consumption,
    along with metadata about the assembly process.
    """

    # Main content
    content: str
    token_count: int

    # Assembly metadata
    contexts_included: int
    contexts_excluded: int = 0
    assembly_time_ms: float = 0.0

    # Budget tracking
    budget_total: int = 0
    budget_used: int = 0

    # Context breakdown
    by_type: dict[str, int] = field(default_factory=dict)

    # Cache info
    cache_hit: bool = False
    cache_key: str | None = None

    @property
    def budget_utilization(self) -> float:
        """Get budget utilization percentage."""
        if self.budget_total == 0:
            return 0.0
        return self.budget_used / self.budget_total

    def to_dict(self) -> dict[str, Any]:
        """Convert to dictionary."""
        return {
            "content": self.content,
            "token_count": self.token_count,
            "contexts_included": self.contexts_included,
            "contexts_excluded": self.contexts_excluded,
            "assembly_time_ms": round(self.assembly_time_ms, 2),
            "budget_total": self.budget_total,
            "budget_used": self.budget_used,
            "budget_utilization": round(self.budget_utilization, 3),
            "by_type": self.by_type,
            "cache_hit": self.cache_hit,
            "cache_key": self.cache_key,
        }

    def to_json(self) -> str:
        """Convert to JSON string."""
        import json

        return json.dumps(self.to_dict())

    @classmethod
    def from_json(cls, json_str: str) -> "AssembledContext":
        """Create from JSON string."""
        import json

        data = json.loads(json_str)
        return cls(
            content=data["content"],
            token_count=data["token_count"],
            contexts_included=data["contexts_included"],
            contexts_excluded=data.get("contexts_excluded", 0),
            assembly_time_ms=data.get("assembly_time_ms", 0.0),
            budget_total=data.get("budget_total", 0),
            budget_used=data.get("budget_used", 0),
            by_type=data.get("by_type", {}),
            cache_hit=data.get("cache_hit", False),
            cache_key=data.get("cache_key"),
        )
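BaseContext is abstract, so a concrete context only has to supply get_type(); IDs, timestamps, score clamping, staleness checks, and rough truncation are all inherited. A small self-contained sketch using a hypothetical NoteContext subclass (the class exists only for this illustration):

```python
from dataclasses import dataclass

from app.services.context import BaseContext, ContextPriority, ContextType


@dataclass(eq=False)
class NoteContext(BaseContext):
    """Hypothetical concrete context, defined here purely for illustration."""

    def get_type(self) -> ContextType:
        # Reuse an existing enum member; nothing in the base class depends on the choice here.
        return ContextType.KNOWLEDGE


note = NoteContext(content="Prefer pathlib over os.path in new code.", source="style_guide")
note.token_count = 12      # normally set by a token counter
note.score = 1.7           # the setter clamps scores into the 0.0-1.0 range

print(note.score)                                # 1.0
print(note.is_stale())                           # False: just created
print(ContextPriority.from_int(note.priority))   # ContextPriority.NORMAL
print(note.truncate(max_tokens=5))               # character-based estimate, may end in "... [truncated]"
```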
182
backend/app/services/context/types/conversation.py
Normal file
@@ -0,0 +1,182 @@
"""
Conversation Context Type.

Represents conversation history for context continuity.
"""

from dataclasses import dataclass, field
from datetime import UTC, datetime
from enum import Enum
from typing import Any

from .base import BaseContext, ContextPriority, ContextType


class MessageRole(str, Enum):
    """Roles for conversation messages."""

    USER = "user"
    ASSISTANT = "assistant"
    SYSTEM = "system"
    TOOL = "tool"

    @classmethod
    def from_string(cls, value: str) -> "MessageRole":
        """Convert string to MessageRole."""
        try:
            return cls(value.lower())
        except ValueError:
            # Default to user for unknown roles
            return cls.USER


@dataclass(eq=False)
class ConversationContext(BaseContext):
    """
    Context from conversation history.

    Represents a single turn in the conversation,
    including user messages, assistant responses,
    and tool results.
    """

    # Conversation-specific fields
    role: MessageRole = field(default=MessageRole.USER)
    turn_index: int = field(default=0)
    session_id: str | None = field(default=None)
    parent_message_id: str | None = field(default=None)

    def get_type(self) -> ContextType:
        """Return CONVERSATION context type."""
        return ContextType.CONVERSATION

    def to_dict(self) -> dict[str, Any]:
        """Convert to dictionary with conversation-specific fields."""
        base = super().to_dict()
        base.update(
            {
                "role": self.role.value,
                "turn_index": self.turn_index,
                "session_id": self.session_id,
                "parent_message_id": self.parent_message_id,
            }
        )
        return base

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "ConversationContext":
        """Create ConversationContext from dictionary."""
        role = data.get("role", "user")
        if isinstance(role, str):
            role = MessageRole.from_string(role)

        return cls(
            id=data.get("id", ""),
            content=data["content"],
            source=data.get("source", "conversation"),
            timestamp=datetime.fromisoformat(data["timestamp"])
            if isinstance(data.get("timestamp"), str)
            else data.get("timestamp", datetime.now(UTC)),
            priority=data.get("priority", ContextPriority.NORMAL.value),
            metadata=data.get("metadata", {}),
            role=role,
            turn_index=data.get("turn_index", 0),
            session_id=data.get("session_id"),
            parent_message_id=data.get("parent_message_id"),
        )

    @classmethod
    def from_message(
        cls,
        content: str,
        role: str | MessageRole,
        turn_index: int = 0,
        session_id: str | None = None,
        timestamp: datetime | None = None,
    ) -> "ConversationContext":
        """
        Create ConversationContext from a message.

        Args:
            content: Message content
            role: Message role (user, assistant, system, tool)
            turn_index: Position in conversation
            session_id: Session identifier
            timestamp: Message timestamp

        Returns:
            ConversationContext instance
        """
        if isinstance(role, str):
            role = MessageRole.from_string(role)

        # Recent messages have higher priority
        priority = ContextPriority.NORMAL.value

        return cls(
            content=content,
            source="conversation",
            role=role,
            turn_index=turn_index,
            session_id=session_id,
            timestamp=timestamp or datetime.now(UTC),
            priority=priority,
        )

    @classmethod
    def from_history(
        cls,
        messages: list[dict[str, Any]],
        session_id: str | None = None,
    ) -> list["ConversationContext"]:
        """
        Create multiple ConversationContexts from message history.

        Args:
            messages: List of message dicts with 'role' and 'content'
            session_id: Session identifier

        Returns:
            List of ConversationContext instances
        """
        contexts = []
        for i, msg in enumerate(messages):
            ctx = cls.from_message(
                content=msg.get("content", ""),
                role=msg.get("role", "user"),
                turn_index=i,
                session_id=session_id,
                timestamp=datetime.fromisoformat(msg["timestamp"])
                if "timestamp" in msg
                else None,
            )
            contexts.append(ctx)
        return contexts

    def is_user_message(self) -> bool:
        """Check if this is a user message."""
        return self.role == MessageRole.USER

    def is_assistant_message(self) -> bool:
        """Check if this is an assistant message."""
        return self.role == MessageRole.ASSISTANT

    def is_tool_result(self) -> bool:
        """Check if this is a tool result."""
        return self.role == MessageRole.TOOL

    def format_for_prompt(self) -> str:
        """
        Format message for inclusion in prompt.

        Returns:
            Formatted message string
        """
        role_labels = {
            MessageRole.USER: "User",
            MessageRole.ASSISTANT: "Assistant",
            MessageRole.SYSTEM: "System",
            MessageRole.TOOL: "Tool Result",
        }
        label = role_labels.get(self.role, "Unknown")
        return f"{label}: {self.content}"
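The two classmethods above are the main entry points: from_message wraps a single turn and from_history maps a whole transcript onto ConversationContext objects, numbering the turns as it goes. A short sketch with made-up messages and an invented session id:

```python
from app.services.context import ConversationContext

history = [
    {"role": "user", "content": "How do I reset my settings?"},
    {"role": "assistant", "content": "Call reset_context_settings() and re-read them."},
    {"role": "tool", "content": '{"status": "ok"}', "timestamp": "2025-01-01T12:00:00+00:00"},
]

turns = ConversationContext.from_history(history, session_id="session-123")

for turn in turns:
    print(turn.turn_index, turn.role.value, turn.is_tool_result())
    print(turn.format_for_prompt())   # e.g. "User: How do I reset my settings?"
```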
143
backend/app/services/context/types/knowledge.py
Normal file
@@ -0,0 +1,143 @@
"""
Knowledge Context Type.

Represents RAG results from the Knowledge Base MCP server.
"""

from dataclasses import dataclass, field
from datetime import UTC, datetime
from typing import Any

from .base import BaseContext, ContextPriority, ContextType


@dataclass(eq=False)
class KnowledgeContext(BaseContext):
    """
    Context from knowledge base / RAG retrieval.

    Knowledge context represents chunks retrieved from the
    Knowledge Base MCP server, including:
    - Code snippets
    - Documentation
    - Previous conversations
    - External knowledge

    Each chunk includes relevance scoring from the search.
    """

    # Knowledge-specific fields
    collection: str = field(default="default")
    file_type: str | None = field(default=None)
    chunk_index: int = field(default=0)
    relevance_score: float = field(default=0.0)
    search_query: str = field(default="")

    def get_type(self) -> ContextType:
        """Return KNOWLEDGE context type."""
        return ContextType.KNOWLEDGE

    def to_dict(self) -> dict[str, Any]:
        """Convert to dictionary with knowledge-specific fields."""
        base = super().to_dict()
        base.update(
            {
                "collection": self.collection,
                "file_type": self.file_type,
                "chunk_index": self.chunk_index,
                "relevance_score": self.relevance_score,
                "search_query": self.search_query,
            }
        )
        return base

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "KnowledgeContext":
        """Create KnowledgeContext from dictionary."""
        return cls(
            id=data.get("id", ""),
            content=data["content"],
            source=data["source"],
            timestamp=datetime.fromisoformat(data["timestamp"])
            if isinstance(data.get("timestamp"), str)
            else data.get("timestamp", datetime.now(UTC)),
            priority=data.get("priority", ContextPriority.NORMAL.value),
            metadata=data.get("metadata", {}),
            collection=data.get("collection", "default"),
            file_type=data.get("file_type"),
            chunk_index=data.get("chunk_index", 0),
            relevance_score=data.get("relevance_score", 0.0),
            search_query=data.get("search_query", ""),
        )

    @classmethod
    def from_search_result(
        cls,
        result: dict[str, Any],
        query: str,
    ) -> "KnowledgeContext":
        """
        Create KnowledgeContext from a Knowledge Base search result.

        Args:
            result: Search result from Knowledge Base MCP
            query: Search query used

        Returns:
            KnowledgeContext instance
        """
        return cls(
            content=result.get("content", ""),
            source=result.get("source_path", "unknown"),
            collection=result.get("collection", "default"),
            file_type=result.get("file_type"),
            chunk_index=result.get("chunk_index", 0),
            relevance_score=result.get("score", 0.0),
            search_query=query,
            metadata={
                "chunk_id": result.get("id"),
                "content_hash": result.get("content_hash"),
            },
        )

    @classmethod
    def from_search_results(
        cls,
        results: list[dict[str, Any]],
        query: str,
    ) -> list["KnowledgeContext"]:
        """
        Create multiple KnowledgeContexts from search results.

        Args:
            results: List of search results
            query: Search query used

        Returns:
            List of KnowledgeContext instances
        """
        return [cls.from_search_result(r, query) for r in results]

    def is_code(self) -> bool:
        """Check if this is code content."""
        code_types = {"python", "javascript", "typescript", "go", "rust", "java", "c", "cpp"}
        return self.file_type is not None and self.file_type.lower() in code_types

    def is_documentation(self) -> bool:
        """Check if this is documentation content."""
        doc_types = {"markdown", "rst", "txt", "md"}
        return self.file_type is not None and self.file_type.lower() in doc_types

    def get_formatted_source(self) -> str:
        """
        Get a formatted source string for display.

        Returns:
            Formatted source string
        """
        parts = [self.source]
        if self.file_type:
            parts.append(f"({self.file_type})")
        if self.collection != "default":
            parts.insert(0, f"[{self.collection}]")
        return " ".join(parts)
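from_search_result maps one Knowledge Base MCP hit onto a KnowledgeContext; the keys it reads (content, source_path, score, and so on) follow the .get() calls above. A sketch with a fabricated result payload, since the exact MCP response shape is an assumption here:

```python
from app.services.context import KnowledgeContext

# Shape mirrors the keys from_search_result() looks up; the values are invented.
results = [
    {
        "id": "chunk-42",
        "content": "def get_context_settings() -> ContextSettings: ...",
        "source_path": "backend/app/services/context/config.py",
        "collection": "codebase",
        "file_type": "python",
        "chunk_index": 3,
        "score": 0.87,
    }
]

chunks = KnowledgeContext.from_search_results(results, query="context settings singleton")

for chunk in chunks:
    print(chunk.relevance_score, chunk.is_code(), chunk.get_formatted_source())
    # 0.87 True [codebase] backend/app/services/context/config.py (python)
```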
138
backend/app/services/context/types/system.py
Normal file
@@ -0,0 +1,138 @@
"""
System Context Type.

Represents system prompts, instructions, and agent personas.
"""

from dataclasses import dataclass, field
from datetime import UTC, datetime
from typing import Any

from .base import BaseContext, ContextPriority, ContextType


@dataclass(eq=False)
class SystemContext(BaseContext):
    """
    Context for system prompts and instructions.

    System context typically includes:
    - Agent persona and role definitions
    - Behavioral instructions
    - Safety guidelines
    - Output format requirements

    System context is usually high priority and should
    rarely be truncated or omitted.
    """

    # System context specific fields
    role: str = field(default="assistant")
    instructions_type: str = field(default="general")

    def __post_init__(self) -> None:
        """Set high priority for system context."""
        # System context defaults to high priority
        if self.priority == ContextPriority.NORMAL.value:
            self.priority = ContextPriority.HIGH.value

    def get_type(self) -> ContextType:
        """Return SYSTEM context type."""
        return ContextType.SYSTEM

    def to_dict(self) -> dict[str, Any]:
        """Convert to dictionary with system-specific fields."""
        base = super().to_dict()
        base.update(
            {
                "role": self.role,
                "instructions_type": self.instructions_type,
            }
        )
        return base

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "SystemContext":
        """Create SystemContext from dictionary."""
        return cls(
            id=data.get("id", ""),
            content=data["content"],
            source=data["source"],
            timestamp=datetime.fromisoformat(data["timestamp"])
            if isinstance(data.get("timestamp"), str)
            else data.get("timestamp", datetime.now(UTC)),
            priority=data.get("priority", ContextPriority.HIGH.value),
            metadata=data.get("metadata", {}),
            role=data.get("role", "assistant"),
            instructions_type=data.get("instructions_type", "general"),
        )

    @classmethod
    def create_persona(
        cls,
        name: str,
        description: str,
        capabilities: list[str] | None = None,
        constraints: list[str] | None = None,
    ) -> "SystemContext":
        """
        Create a persona system context.

        Args:
            name: Agent name/role
            description: Role description
            capabilities: List of things the agent can do
            constraints: List of limitations

        Returns:
            SystemContext with formatted persona
        """
        parts = [f"You are {name}.", "", description]

        if capabilities:
            parts.append("")
            parts.append("You can:")
            for cap in capabilities:
                parts.append(f"- {cap}")

        if constraints:
            parts.append("")
            parts.append("You must not:")
            for constraint in constraints:
                parts.append(f"- {constraint}")

        return cls(
            content="\n".join(parts),
            source="persona_builder",
            role=name.lower().replace(" ", "_"),
            instructions_type="persona",
            priority=ContextPriority.HIGHEST.value,
        )

    @classmethod
    def create_instructions(
        cls,
        instructions: str | list[str],
        source: str = "instructions",
    ) -> "SystemContext":
        """
        Create an instructions system context.

        Args:
            instructions: Instructions string or list of instruction strings
            source: Source identifier

        Returns:
            SystemContext with instructions
        """
        if isinstance(instructions, list):
            content = "\n".join(f"- {inst}" for inst in instructions)
        else:
            content = instructions

        return cls(
            content=content,
            source=source,
            instructions_type="instructions",
            priority=ContextPriority.HIGH.value,
        )
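create_persona and create_instructions are thin builders over the same dataclass; the second one simply bullets a list and pins HIGH priority. A quick sketch with invented guideline text (the persona variant is already shown in the package docstring):

```python
from app.services.context import ContextPriority, SystemContext

guidelines = SystemContext.create_instructions(
    [
        "Answer in English.",
        "Cite the file path for every code snippet you quote.",
    ],
    source="project_guidelines",
)

print(guidelines.priority == ContextPriority.HIGH.value)  # True: set by create_instructions
print(guidelines.instructions_type)                       # "instructions"
print(guidelines.content)
# - Answer in English.
# - Cite the file path for every code snippet you quote.
```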
195
backend/app/services/context/types/task.py
Normal file
@@ -0,0 +1,195 @@
|
||||
"""
|
||||
Task Context Type.
|
||||
|
||||
Represents the current task or objective for the agent.
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import UTC, datetime
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
from .base import BaseContext, ContextPriority, ContextType
|
||||
|
||||
|
||||
class TaskStatus(str, Enum):
|
||||
"""Status of a task."""
|
||||
|
||||
PENDING = "pending"
|
||||
IN_PROGRESS = "in_progress"
|
||||
BLOCKED = "blocked"
|
||||
COMPLETED = "completed"
|
||||
FAILED = "failed"
|
||||
|
||||
|
||||
class TaskComplexity(str, Enum):
|
||||
"""Complexity level of a task."""
|
||||
|
||||
TRIVIAL = "trivial"
|
||||
SIMPLE = "simple"
|
||||
MODERATE = "moderate"
|
||||
COMPLEX = "complex"
|
||||
VERY_COMPLEX = "very_complex"
|
||||
|
||||
|
||||
@dataclass(eq=False)
|
||||
class TaskContext(BaseContext):
|
||||
"""
|
||||
Context for the current task or objective.
|
||||
|
||||
Task context provides information about what the agent
|
||||
should accomplish, including:
|
||||
- Task description and goals
|
||||
- Acceptance criteria
|
||||
- Constraints and requirements
|
||||
- Related issue/ticket information
|
||||
"""
|
||||
|
||||
# Task-specific fields
|
||||
title: str = field(default="")
|
||||
status: TaskStatus = field(default=TaskStatus.PENDING)
|
||||
complexity: TaskComplexity = field(default=TaskComplexity.MODERATE)
|
||||
issue_id: str | None = field(default=None)
|
||||
project_id: str | None = field(default=None)
|
||||
acceptance_criteria: list[str] = field(default_factory=list)
|
||||
constraints: list[str] = field(default_factory=list)
|
||||
parent_task_id: str | None = field(default=None)
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
"""Set high priority for task context."""
|
||||
# Task context defaults to high priority
|
||||
if self.priority == ContextPriority.NORMAL.value:
|
||||
self.priority = ContextPriority.HIGH.value
|
||||
|
||||
def get_type(self) -> ContextType:
|
||||
"""Return TASK context type."""
|
||||
return ContextType.TASK
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
"""Convert to dictionary with task-specific fields."""
|
||||
base = super().to_dict()
|
||||
base.update(
|
||||
{
|
||||
"title": self.title,
|
||||
"status": self.status.value,
|
||||
"complexity": self.complexity.value,
|
||||
"issue_id": self.issue_id,
|
||||
"project_id": self.project_id,
|
||||
"acceptance_criteria": self.acceptance_criteria,
|
||||
"constraints": self.constraints,
|
||||
"parent_task_id": self.parent_task_id,
|
||||
}
|
||||
)
|
||||
return base
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: dict[str, Any]) -> "TaskContext":
|
||||
"""Create TaskContext from dictionary."""
|
||||
status = data.get("status", "pending")
|
||||
if isinstance(status, str):
|
||||
status = TaskStatus(status)
|
||||
|
||||
complexity = data.get("complexity", "moderate")
|
||||
if isinstance(complexity, str):
|
||||
complexity = TaskComplexity(complexity)
|
||||
|
||||
return cls(
|
||||
id=data.get("id", ""),
|
||||
content=data["content"],
|
||||
source=data.get("source", "task"),
|
||||
timestamp=datetime.fromisoformat(data["timestamp"])
|
||||
if isinstance(data.get("timestamp"), str)
|
||||
else data.get("timestamp", datetime.now(UTC)),
|
||||
priority=data.get("priority", ContextPriority.HIGH.value),
|
||||
metadata=data.get("metadata", {}),
|
||||
title=data.get("title", ""),
|
||||
status=status,
|
||||
complexity=complexity,
|
||||
issue_id=data.get("issue_id"),
|
||||
project_id=data.get("project_id"),
|
||||
acceptance_criteria=data.get("acceptance_criteria", []),
|
||||
constraints=data.get("constraints", []),
|
||||
parent_task_id=data.get("parent_task_id"),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def create(
|
||||
cls,
|
||||
title: str,
|
||||
description: str,
|
||||
acceptance_criteria: list[str] | None = None,
|
||||
constraints: list[str] | None = None,
|
||||
issue_id: str | None = None,
|
||||
project_id: str | None = None,
|
||||
complexity: TaskComplexity | str = TaskComplexity.MODERATE,
|
||||
) -> "TaskContext":
|
||||
"""
|
||||
Create a task context.
|
||||
|
||||
Args:
|
||||
title: Task title
|
||||
description: Task description
|
||||
acceptance_criteria: List of acceptance criteria
|
||||
constraints: List of constraints
|
||||
issue_id: Related issue ID
|
||||
project_id: Project ID
|
||||
complexity: Task complexity
|
||||
|
||||
Returns:
|
||||
TaskContext instance
|
||||
"""
|
||||
if isinstance(complexity, str):
|
||||
complexity = TaskComplexity(complexity)
|
||||
|
||||
return cls(
|
||||
content=description,
|
||||
source=f"task:{issue_id}" if issue_id else "task",
|
||||
title=title,
|
||||
status=TaskStatus.IN_PROGRESS,
|
||||
complexity=complexity,
|
||||
issue_id=issue_id,
|
||||
project_id=project_id,
|
||||
acceptance_criteria=acceptance_criteria or [],
|
||||
constraints=constraints or [],
|
||||
)
|
||||
|
||||
def format_for_prompt(self) -> str:
|
||||
"""
|
||||
Format task for inclusion in prompt.
|
||||
|
||||
Returns:
|
||||
Formatted task string
|
||||
"""
|
||||
parts = []
|
||||
|
||||
if self.title:
|
||||
parts.append(f"Task: {self.title}")
|
||||
parts.append("")
|
||||
|
||||
parts.append(self.content)
|
||||
|
||||
if self.acceptance_criteria:
|
||||
parts.append("")
|
||||
parts.append("Acceptance Criteria:")
|
||||
for criterion in self.acceptance_criteria:
|
||||
parts.append(f"- {criterion}")
|
||||
|
||||
if self.constraints:
|
||||
parts.append("")
|
||||
parts.append("Constraints:")
|
||||
for constraint in self.constraints:
|
||||
parts.append(f"- {constraint}")
|
||||
|
||||
return "\n".join(parts)
|
||||
|
||||
def is_active(self) -> bool:
|
||||
"""Check if task is currently active."""
|
||||
return self.status in (TaskStatus.PENDING, TaskStatus.IN_PROGRESS)
|
||||
|
||||
def is_complete(self) -> bool:
|
||||
"""Check if task is complete."""
|
||||
return self.status == TaskStatus.COMPLETED
|
||||
|
||||
def is_blocked(self) -> bool:
|
||||
"""Check if task is blocked."""
|
||||
return self.status == TaskStatus.BLOCKED
|
||||
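A minimal usage sketch of the TaskContext API defined above, restricted to methods introduced in this diff; the issue ID and field values are illustrative placeholders, not real data:

from app.services.context.types import TaskContext, TaskStatus

# Build a task context from a hypothetical issue (IDs and criteria are made up)
task = TaskContext.create(
    title="Add Auth",
    description="Implement authentication",
    acceptance_criteria=["Tests pass", "Code reviewed"],
    constraints=["Use JWT"],
    issue_id="123",
)

# The factory marks the task as in progress
assert task.status == TaskStatus.IN_PROGRESS

# Human-readable block for the assembled prompt
print(task.format_for_prompt())

# Dict round-trip, as used by serialization / caching layers
restored = TaskContext.from_dict(task.to_dict())
assert restored.title == task.title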
207
backend/app/services/context/types/tool.py
Normal file
@@ -0,0 +1,207 @@
"""
Tool Context Type.

Represents available tools and recent tool execution results.
"""

from dataclasses import dataclass, field
from datetime import UTC, datetime
from enum import Enum
from typing import Any

from .base import BaseContext, ContextPriority, ContextType


class ToolResultStatus(str, Enum):
    """Status of a tool execution result."""

    SUCCESS = "success"
    ERROR = "error"
    TIMEOUT = "timeout"
    CANCELLED = "cancelled"


@dataclass(eq=False)
class ToolContext(BaseContext):
    """
    Context for tools and tool execution results.

    Tool context includes:
    - Tool descriptions and parameters
    - Recent tool execution results
    - Tool availability information

    This helps the LLM understand what tools are available
    and what results previous tool calls produced.
    """

    # Tool-specific fields
    tool_name: str = field(default="")
    tool_description: str = field(default="")
    is_result: bool = field(default=False)
    result_status: ToolResultStatus | None = field(default=None)
    execution_time_ms: float | None = field(default=None)
    parameters: dict[str, Any] = field(default_factory=dict)
    server_name: str | None = field(default=None)

    def get_type(self) -> ContextType:
        """Return TOOL context type."""
        return ContextType.TOOL

    def to_dict(self) -> dict[str, Any]:
        """Convert to dictionary with tool-specific fields."""
        base = super().to_dict()
        base.update(
            {
                "tool_name": self.tool_name,
                "tool_description": self.tool_description,
                "is_result": self.is_result,
                "result_status": self.result_status.value if self.result_status else None,
                "execution_time_ms": self.execution_time_ms,
                "parameters": self.parameters,
                "server_name": self.server_name,
            }
        )
        return base

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "ToolContext":
        """Create ToolContext from dictionary."""
        result_status = data.get("result_status")
        if isinstance(result_status, str):
            result_status = ToolResultStatus(result_status)

        return cls(
            id=data.get("id", ""),
            content=data["content"],
            source=data.get("source", "tool"),
            timestamp=datetime.fromisoformat(data["timestamp"])
            if isinstance(data.get("timestamp"), str)
            else data.get("timestamp", datetime.now(UTC)),
            priority=data.get("priority", ContextPriority.NORMAL.value),
            metadata=data.get("metadata", {}),
            tool_name=data.get("tool_name", ""),
            tool_description=data.get("tool_description", ""),
            is_result=data.get("is_result", False),
            result_status=result_status,
            execution_time_ms=data.get("execution_time_ms"),
            parameters=data.get("parameters", {}),
            server_name=data.get("server_name"),
        )

    @classmethod
    def from_tool_definition(
        cls,
        name: str,
        description: str,
        parameters: dict[str, Any] | None = None,
        server_name: str | None = None,
    ) -> "ToolContext":
        """
        Create a ToolContext from a tool definition.

        Args:
            name: Tool name
            description: Tool description
            parameters: Tool parameter schema
            server_name: MCP server name

        Returns:
            ToolContext instance
        """
        # Format content as tool documentation
        content_parts = [f"Tool: {name}", "", description]

        if parameters:
            content_parts.append("")
            content_parts.append("Parameters:")
            for param_name, param_info in parameters.items():
                param_type = param_info.get("type", "any")
                param_desc = param_info.get("description", "")
                required = param_info.get("required", False)
                req_marker = " (required)" if required else ""
                content_parts.append(f"  - {param_name}: {param_type}{req_marker}")
                if param_desc:
                    content_parts.append(f"    {param_desc}")

        return cls(
            content="\n".join(content_parts),
            source=f"tool:{server_name}:{name}" if server_name else f"tool:{name}",
            tool_name=name,
            tool_description=description,
            is_result=False,
            parameters=parameters or {},
            server_name=server_name,
            priority=ContextPriority.LOW.value,
        )

    @classmethod
    def from_tool_result(
        cls,
        tool_name: str,
        result: Any,
        status: ToolResultStatus = ToolResultStatus.SUCCESS,
        execution_time_ms: float | None = None,
        parameters: dict[str, Any] | None = None,
        server_name: str | None = None,
    ) -> "ToolContext":
        """
        Create a ToolContext from a tool execution result.

        Args:
            tool_name: Name of the tool that was executed
            result: Result content (will be converted to string)
            status: Execution status
            execution_time_ms: Execution time in milliseconds
            parameters: Parameters that were passed to the tool
            server_name: MCP server name

        Returns:
            ToolContext instance
        """
        # Convert result to string content
        if isinstance(result, str):
            content = result
        elif isinstance(result, dict):
            import json

            try:
                content = json.dumps(result, indent=2)
            except (TypeError, ValueError):
                content = str(result)
        else:
            content = str(result)

        return cls(
            content=content,
            source=f"tool_result:{server_name}:{tool_name}" if server_name else f"tool_result:{tool_name}",
            tool_name=tool_name,
            is_result=True,
            result_status=status,
            execution_time_ms=execution_time_ms,
            parameters=parameters or {},
            server_name=server_name,
            priority=ContextPriority.HIGH.value,  # Recent results are high priority
        )

    def is_successful(self) -> bool:
        """Check if this is a successful tool result."""
        return self.is_result and self.result_status == ToolResultStatus.SUCCESS

    def is_error(self) -> bool:
        """Check if this is an error result."""
        return self.is_result and self.result_status == ToolResultStatus.ERROR

    def format_for_prompt(self) -> str:
        """
        Format tool context for inclusion in prompt.

        Returns:
            Formatted tool string
        """
        if self.is_result:
            status_str = self.result_status.value if self.result_status else "unknown"
            header = f"Tool Result ({self.tool_name}, {status_str}):"
            return f"{header}\n{self.content}"
        else:
            return self.content
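A short sketch of how the two ToolContext factories above are intended to be combined, using only methods defined in this file; the tool name and result payload are made up for illustration:

from app.services.context.types import ToolContext, ToolResultStatus

# Advertise a tool to the model (low priority, not a result)
definition = ToolContext.from_tool_definition(
    name="search_knowledge",
    description="Search the knowledge base",
    parameters={"query": {"type": "string", "required": True}},
    server_name="knowledge-base",
)

# Record the outcome of running that tool (high priority, is_result=True)
result = ToolContext.from_tool_result(
    tool_name="search_knowledge",
    result={"found": 2, "items": ["a", "b"]},  # dict results are JSON-formatted
    status=ToolResultStatus.SUCCESS,
    execution_time_ms=42.0,
    server_name="knowledge-base",
)

assert result.is_successful()
print(definition.format_for_prompt())
print(result.format_for_prompt())  # "Tool Result (search_knowledge, success): ..."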
1
backend/tests/services/context/__init__.py
Normal file
@@ -0,0 +1 @@
"""Tests for Context Management Engine."""
243
backend/tests/services/context/test_config.py
Normal file
@@ -0,0 +1,243 @@
"""Tests for context management configuration."""

import os
from unittest.mock import patch

import pytest

from app.services.context.config import (
    ContextSettings,
    get_context_settings,
    get_default_settings,
    reset_context_settings,
)


class TestContextSettings:
    """Tests for ContextSettings."""

    def test_default_values(self) -> None:
        """Test default settings values."""
        settings = ContextSettings()

        # Budget defaults should sum to 1.0
        total = (
            settings.budget_system
            + settings.budget_task
            + settings.budget_knowledge
            + settings.budget_conversation
            + settings.budget_tools
            + settings.budget_response
            + settings.budget_buffer
        )
        assert abs(total - 1.0) < 0.001

        # Scoring weights should sum to 1.0
        weights_total = (
            settings.scoring_relevance_weight
            + settings.scoring_recency_weight
            + settings.scoring_priority_weight
        )
        assert abs(weights_total - 1.0) < 0.001

    def test_budget_allocation_values(self) -> None:
        """Test specific budget allocation values."""
        settings = ContextSettings()

        assert settings.budget_system == 0.05
        assert settings.budget_task == 0.10
        assert settings.budget_knowledge == 0.40
        assert settings.budget_conversation == 0.20
        assert settings.budget_tools == 0.05
        assert settings.budget_response == 0.15
        assert settings.budget_buffer == 0.05

    def test_scoring_weights(self) -> None:
        """Test scoring weights."""
        settings = ContextSettings()

        assert settings.scoring_relevance_weight == 0.5
        assert settings.scoring_recency_weight == 0.3
        assert settings.scoring_priority_weight == 0.2

    def test_cache_settings(self) -> None:
        """Test cache settings."""
        settings = ContextSettings()

        assert settings.cache_enabled is True
        assert settings.cache_ttl_seconds == 3600
        assert settings.cache_prefix == "ctx"

    def test_performance_settings(self) -> None:
        """Test performance settings."""
        settings = ContextSettings()

        assert settings.max_assembly_time_ms == 100
        assert settings.parallel_scoring is True
        assert settings.max_parallel_scores == 10

    def test_get_budget_allocation(self) -> None:
        """Test get_budget_allocation method."""
        settings = ContextSettings()
        allocation = settings.get_budget_allocation()

        assert isinstance(allocation, dict)
        assert "system" in allocation
        assert "knowledge" in allocation
        assert allocation["system"] == 0.05
        assert allocation["knowledge"] == 0.40

    def test_get_scoring_weights(self) -> None:
        """Test get_scoring_weights method."""
        settings = ContextSettings()
        weights = settings.get_scoring_weights()

        assert isinstance(weights, dict)
        assert "relevance" in weights
        assert "recency" in weights
        assert "priority" in weights
        assert weights["relevance"] == 0.5

    def test_to_dict(self) -> None:
        """Test to_dict method."""
        settings = ContextSettings()
        result = settings.to_dict()

        assert "budget" in result
        assert "scoring" in result
        assert "compression" in result
        assert "cache" in result
        assert "performance" in result
        assert "knowledge" in result
        assert "conversation" in result

    def test_budget_validation_fails_on_wrong_sum(self) -> None:
        """Test that budget validation fails when sum != 1.0."""
        with pytest.raises(ValueError) as exc_info:
            ContextSettings(
                budget_system=0.5,
                budget_task=0.5,
                # Other budgets default to non-zero, so total > 1.0
            )

        assert "sum to 1.0" in str(exc_info.value)

    def test_scoring_validation_fails_on_wrong_sum(self) -> None:
        """Test that scoring validation fails when sum != 1.0."""
        with pytest.raises(ValueError) as exc_info:
            ContextSettings(
                scoring_relevance_weight=0.8,
                scoring_recency_weight=0.8,
                scoring_priority_weight=0.8,
            )

        assert "sum to 1.0" in str(exc_info.value)

    def test_search_type_validation(self) -> None:
        """Test search type validation."""
        # Valid types should work
        ContextSettings(knowledge_search_type="semantic")
        ContextSettings(knowledge_search_type="keyword")
        ContextSettings(knowledge_search_type="hybrid")

        # Invalid type should fail
        with pytest.raises(ValueError):
            ContextSettings(knowledge_search_type="invalid")

    def test_custom_budget_allocation(self) -> None:
        """Test custom budget allocation that sums to 1.0."""
        settings = ContextSettings(
            budget_system=0.10,
            budget_task=0.10,
            budget_knowledge=0.30,
            budget_conversation=0.25,
            budget_tools=0.05,
            budget_response=0.15,
            budget_buffer=0.05,
        )

        total = (
            settings.budget_system
            + settings.budget_task
            + settings.budget_knowledge
            + settings.budget_conversation
            + settings.budget_tools
            + settings.budget_response
            + settings.budget_buffer
        )
        assert abs(total - 1.0) < 0.001


class TestSettingsSingleton:
    """Tests for settings singleton pattern."""

    def setup_method(self) -> None:
        """Reset settings before each test."""
        reset_context_settings()

    def teardown_method(self) -> None:
        """Clean up after each test."""
        reset_context_settings()

    def test_get_context_settings_returns_instance(self) -> None:
        """Test that get_context_settings returns a settings instance."""
        settings = get_context_settings()
        assert isinstance(settings, ContextSettings)

    def test_get_context_settings_returns_same_instance(self) -> None:
        """Test that get_context_settings returns the same instance."""
        settings1 = get_context_settings()
        settings2 = get_context_settings()
        assert settings1 is settings2

    def test_reset_creates_new_instance(self) -> None:
        """Test that reset creates a new instance."""
        settings1 = get_context_settings()
        reset_context_settings()
        settings2 = get_context_settings()

        # Should be different instances
        assert settings1 is not settings2

    def test_get_default_settings_cached(self) -> None:
        """Test that get_default_settings is cached."""
        settings1 = get_default_settings()
        settings2 = get_default_settings()
        assert settings1 is settings2


class TestEnvironmentOverrides:
    """Tests for environment variable overrides."""

    def setup_method(self) -> None:
        """Reset settings before each test."""
        reset_context_settings()

    def teardown_method(self) -> None:
        """Clean up after each test."""
        reset_context_settings()
        # Clean up any env vars we set
        for key in list(os.environ.keys()):
            if key.startswith("CTX_"):
                del os.environ[key]

    def test_env_override_cache_enabled(self) -> None:
        """Test that CTX_CACHE_ENABLED env var works."""
        with patch.dict(os.environ, {"CTX_CACHE_ENABLED": "false"}):
            reset_context_settings()
            settings = ContextSettings()
            assert settings.cache_enabled is False

    def test_env_override_cache_ttl(self) -> None:
        """Test that CTX_CACHE_TTL_SECONDS env var works."""
        with patch.dict(os.environ, {"CTX_CACHE_TTL_SECONDS": "7200"}):
            reset_context_settings()
            settings = ContextSettings()
            assert settings.cache_ttl_seconds == 7200

    def test_env_override_max_assembly_time(self) -> None:
        """Test that CTX_MAX_ASSEMBLY_TIME_MS env var works."""
        with patch.dict(os.environ, {"CTX_MAX_ASSEMBLY_TIME_MS": "200"}):
            reset_context_settings()
            settings = ContextSettings()
            assert settings.max_assembly_time_ms == 200
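A sketch of how the CTX_-prefixed overrides exercised by these tests might be applied at runtime; the override values are illustrative, and the reset-then-construct pattern mirrors the tests above rather than a documented deployment recipe:

import os

from app.services.context.config import ContextSettings, reset_context_settings

# Values are illustrative; in deployment these would come from the environment.
os.environ["CTX_CACHE_TTL_SECONDS"] = "7200"
os.environ["CTX_MAX_ASSEMBLY_TIME_MS"] = "200"

reset_context_settings()        # drop any cached singleton so new env values apply
settings = ContextSettings()    # same construction path the tests above use

print(settings.cache_ttl_seconds)        # 7200
print(settings.get_budget_allocation())  # {"system": 0.05, ..., "knowledge": 0.40, ...}
print(settings.get_scoring_weights())    # {"relevance": 0.5, "recency": 0.3, "priority": 0.2}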
252
backend/tests/services/context/test_exceptions.py
Normal file
@@ -0,0 +1,252 @@
"""Tests for context management exceptions."""

import pytest

from app.services.context.exceptions import (
    AssemblyTimeoutError,
    BudgetExceededError,
    CacheError,
    CompressionError,
    ContextError,
    ContextNotFoundError,
    FormattingError,
    InvalidContextError,
    ScoringError,
    TokenCountError,
)


class TestContextError:
    """Tests for base ContextError."""

    def test_basic_initialization(self) -> None:
        """Test basic error initialization."""
        error = ContextError("Test error")
        assert error.message == "Test error"
        assert error.details == {}
        assert str(error) == "Test error"

    def test_with_details(self) -> None:
        """Test error with details."""
        error = ContextError("Test error", {"key": "value", "count": 42})
        assert error.details == {"key": "value", "count": 42}

    def test_to_dict(self) -> None:
        """Test conversion to dictionary."""
        error = ContextError("Test error", {"key": "value"})
        result = error.to_dict()

        assert result["error_type"] == "ContextError"
        assert result["message"] == "Test error"
        assert result["details"] == {"key": "value"}

    def test_inheritance(self) -> None:
        """Test that ContextError inherits from Exception."""
        error = ContextError("Test")
        assert isinstance(error, Exception)


class TestBudgetExceededError:
    """Tests for BudgetExceededError."""

    def test_default_message(self) -> None:
        """Test default error message."""
        error = BudgetExceededError()
        assert "exceeded" in error.message.lower()

    def test_with_budget_info(self) -> None:
        """Test with budget information."""
        error = BudgetExceededError(
            allocated=1000,
            requested=1500,
            context_type="knowledge",
        )

        assert error.allocated == 1000
        assert error.requested == 1500
        assert error.context_type == "knowledge"
        assert error.details["overage"] == 500

    def test_to_dict_includes_budget_info(self) -> None:
        """Test that to_dict includes budget info."""
        error = BudgetExceededError(
            allocated=1000,
            requested=1500,
        )
        result = error.to_dict()

        assert result["details"]["allocated"] == 1000
        assert result["details"]["requested"] == 1500
        assert result["details"]["overage"] == 500


class TestTokenCountError:
    """Tests for TokenCountError."""

    def test_basic_error(self) -> None:
        """Test basic token count error."""
        error = TokenCountError()
        assert "token" in error.message.lower()

    def test_with_model_info(self) -> None:
        """Test with model information."""
        error = TokenCountError(
            message="Failed to count",
            model="claude-3-sonnet",
            text_length=5000,
        )

        assert error.model == "claude-3-sonnet"
        assert error.text_length == 5000
        assert error.details["model"] == "claude-3-sonnet"


class TestCompressionError:
    """Tests for CompressionError."""

    def test_basic_error(self) -> None:
        """Test basic compression error."""
        error = CompressionError()
        assert "compress" in error.message.lower()

    def test_with_token_info(self) -> None:
        """Test with token information."""
        error = CompressionError(
            original_tokens=2000,
            target_tokens=1000,
            achieved_tokens=1500,
        )

        assert error.original_tokens == 2000
        assert error.target_tokens == 1000
        assert error.achieved_tokens == 1500


class TestAssemblyTimeoutError:
    """Tests for AssemblyTimeoutError."""

    def test_basic_error(self) -> None:
        """Test basic timeout error."""
        error = AssemblyTimeoutError()
        assert "timed out" in error.message.lower()

    def test_with_timing_info(self) -> None:
        """Test with timing information."""
        error = AssemblyTimeoutError(
            timeout_ms=100,
            elapsed_ms=150.5,
            stage="scoring",
        )

        assert error.timeout_ms == 100
        assert error.elapsed_ms == 150.5
        assert error.stage == "scoring"
        assert error.details["stage"] == "scoring"


class TestScoringError:
    """Tests for ScoringError."""

    def test_basic_error(self) -> None:
        """Test basic scoring error."""
        error = ScoringError()
        assert "score" in error.message.lower()

    def test_with_scorer_info(self) -> None:
        """Test with scorer information."""
        error = ScoringError(
            scorer_type="relevance",
            context_id="ctx-123",
        )

        assert error.scorer_type == "relevance"
        assert error.context_id == "ctx-123"


class TestFormattingError:
    """Tests for FormattingError."""

    def test_basic_error(self) -> None:
        """Test basic formatting error."""
        error = FormattingError()
        assert "format" in error.message.lower()

    def test_with_model_info(self) -> None:
        """Test with model information."""
        error = FormattingError(
            model="claude-3-opus",
            adapter="ClaudeAdapter",
        )

        assert error.model == "claude-3-opus"
        assert error.adapter == "ClaudeAdapter"


class TestCacheError:
    """Tests for CacheError."""

    def test_basic_error(self) -> None:
        """Test basic cache error."""
        error = CacheError()
        assert "cache" in error.message.lower()

    def test_with_operation_info(self) -> None:
        """Test with operation information."""
        error = CacheError(
            operation="get",
            cache_key="ctx:abc123",
        )

        assert error.operation == "get"
        assert error.cache_key == "ctx:abc123"


class TestContextNotFoundError:
    """Tests for ContextNotFoundError."""

    def test_basic_error(self) -> None:
        """Test basic not found error."""
        error = ContextNotFoundError()
        assert "not found" in error.message.lower()

    def test_with_source_info(self) -> None:
        """Test with source information."""
        error = ContextNotFoundError(
            source="knowledge-base",
            query="authentication flow",
        )

        assert error.source == "knowledge-base"
        assert error.query == "authentication flow"


class TestInvalidContextError:
    """Tests for InvalidContextError."""

    def test_basic_error(self) -> None:
        """Test basic invalid error."""
        error = InvalidContextError()
        assert "invalid" in error.message.lower()

    def test_with_field_info(self) -> None:
        """Test with field information."""
        error = InvalidContextError(
            field="content",
            value="",
            reason="Content cannot be empty",
        )

        assert error.field == "content"
        assert error.value == ""
        assert error.reason == "Content cannot be empty"

    def test_value_type_only_in_details(self) -> None:
        """Test that only value type is included in details (not actual value)."""
        error = InvalidContextError(
            field="api_key",
            value="secret-key-here",
        )

        # Actual value should not be in details
        assert "secret-key-here" not in str(error.details)
        assert error.details["value_type"] == "str"
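A sketch of how callers might surface these exceptions as structured errors; the try_fit helper is hypothetical, while BudgetExceededError, ContextError and their to_dict() come from this diff:

from app.services.context.exceptions import BudgetExceededError, ContextError


def try_fit(allocated: int, requested: int) -> None:
    """Hypothetical caller enforcing a per-type token budget."""
    if requested > allocated:
        raise BudgetExceededError(
            allocated=allocated,
            requested=requested,
            context_type="knowledge",
        )


try:
    try_fit(allocated=1000, requested=1500)
except ContextError as exc:            # all context errors share the base class
    payload = exc.to_dict()            # {"error_type": ..., "message": ..., "details": {...}}
    print(payload["details"]["overage"])  # 500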
579
backend/tests/services/context/test_types.py
Normal file
@@ -0,0 +1,579 @@
"""Tests for context types."""

import json
from datetime import UTC, datetime, timedelta

import pytest

from app.services.context.types import (
    AssembledContext,
    BaseContext,
    ContextPriority,
    ContextType,
    ConversationContext,
    KnowledgeContext,
    MessageRole,
    SystemContext,
    TaskComplexity,
    TaskContext,
    TaskStatus,
    ToolContext,
    ToolResultStatus,
)


class TestContextType:
    """Tests for ContextType enum."""

    def test_all_types_exist(self) -> None:
        """Test that all expected context types exist."""
        assert ContextType.SYSTEM
        assert ContextType.TASK
        assert ContextType.KNOWLEDGE
        assert ContextType.CONVERSATION
        assert ContextType.TOOL

    def test_from_string_valid(self) -> None:
        """Test from_string with valid values."""
        assert ContextType.from_string("system") == ContextType.SYSTEM
        assert ContextType.from_string("KNOWLEDGE") == ContextType.KNOWLEDGE
        assert ContextType.from_string("Task") == ContextType.TASK

    def test_from_string_invalid(self) -> None:
        """Test from_string with invalid value."""
        with pytest.raises(ValueError) as exc_info:
            ContextType.from_string("invalid")
        assert "Invalid context type" in str(exc_info.value)


class TestContextPriority:
    """Tests for ContextPriority enum."""

    def test_priority_ordering(self) -> None:
        """Test that priorities are ordered correctly."""
        assert ContextPriority.LOWEST.value < ContextPriority.LOW.value
        assert ContextPriority.LOW.value < ContextPriority.NORMAL.value
        assert ContextPriority.NORMAL.value < ContextPriority.HIGH.value
        assert ContextPriority.HIGH.value < ContextPriority.HIGHEST.value
        assert ContextPriority.HIGHEST.value < ContextPriority.CRITICAL.value

    def test_from_int(self) -> None:
        """Test from_int conversion."""
        assert ContextPriority.from_int(0) == ContextPriority.LOWEST
        assert ContextPriority.from_int(50) == ContextPriority.NORMAL
        assert ContextPriority.from_int(100) == ContextPriority.HIGHEST
        assert ContextPriority.from_int(200) == ContextPriority.CRITICAL

    def test_from_int_intermediate(self) -> None:
        """Test from_int with intermediate values."""
        # Should return closest lower priority
        assert ContextPriority.from_int(30) == ContextPriority.LOW
        assert ContextPriority.from_int(60) == ContextPriority.NORMAL


class TestSystemContext:
    """Tests for SystemContext."""

    def test_creation(self) -> None:
        """Test basic creation."""
        ctx = SystemContext(
            content="You are a helpful assistant.",
            source="system_prompt",
        )

        assert ctx.content == "You are a helpful assistant."
        assert ctx.source == "system_prompt"
        assert ctx.get_type() == ContextType.SYSTEM

    def test_default_high_priority(self) -> None:
        """Test that system context defaults to high priority."""
        ctx = SystemContext(content="Test", source="test")
        assert ctx.priority == ContextPriority.HIGH.value

    def test_create_persona(self) -> None:
        """Test create_persona factory method."""
        ctx = SystemContext.create_persona(
            name="Code Assistant",
            description="A helpful coding assistant.",
            capabilities=["Write code", "Debug"],
            constraints=["Never expose secrets"],
        )

        assert "Code Assistant" in ctx.content
        assert "helpful coding assistant" in ctx.content
        assert "Write code" in ctx.content
        assert "Never expose secrets" in ctx.content
        assert ctx.priority == ContextPriority.HIGHEST.value

    def test_create_instructions(self) -> None:
        """Test create_instructions factory method."""
        ctx = SystemContext.create_instructions(
            ["Always be helpful", "Be concise"],
            source="rules",
        )

        assert "Always be helpful" in ctx.content
        assert "Be concise" in ctx.content

    def test_to_dict(self) -> None:
        """Test serialization to dict."""
        ctx = SystemContext(
            content="Test",
            source="test",
            role="assistant",
            instructions_type="general",
        )

        data = ctx.to_dict()
        assert data["role"] == "assistant"
        assert data["instructions_type"] == "general"
        assert data["type"] == "system"


class TestKnowledgeContext:
    """Tests for KnowledgeContext."""

    def test_creation(self) -> None:
        """Test basic creation."""
        ctx = KnowledgeContext(
            content="def authenticate(user): ...",
            source="/src/auth.py",
            collection="code",
            file_type="python",
        )

        assert ctx.content == "def authenticate(user): ..."
        assert ctx.source == "/src/auth.py"
        assert ctx.collection == "code"
        assert ctx.get_type() == ContextType.KNOWLEDGE

    def test_from_search_result(self) -> None:
        """Test from_search_result factory method."""
        result = {
            "content": "Test content",
            "source_path": "/test/file.py",
            "collection": "code",
            "file_type": "python",
            "chunk_index": 2,
            "score": 0.85,
            "id": "chunk-123",
        }

        ctx = KnowledgeContext.from_search_result(result, "test query")

        assert ctx.content == "Test content"
        assert ctx.source == "/test/file.py"
        assert ctx.relevance_score == 0.85
        assert ctx.search_query == "test query"

    def test_from_search_results(self) -> None:
        """Test from_search_results factory method."""
        results = [
            {"content": "Content 1", "source_path": "/a.py", "score": 0.9},
            {"content": "Content 2", "source_path": "/b.py", "score": 0.8},
        ]

        contexts = KnowledgeContext.from_search_results(results, "query")

        assert len(contexts) == 2
        assert contexts[0].relevance_score == 0.9
        assert contexts[1].source == "/b.py"

    def test_is_code(self) -> None:
        """Test is_code method."""
        code_ctx = KnowledgeContext(
            content="code", source="test", file_type="python"
        )
        doc_ctx = KnowledgeContext(
            content="docs", source="test", file_type="markdown"
        )

        assert code_ctx.is_code() is True
        assert doc_ctx.is_code() is False

    def test_is_documentation(self) -> None:
        """Test is_documentation method."""
        doc_ctx = KnowledgeContext(
            content="docs", source="test", file_type="markdown"
        )
        code_ctx = KnowledgeContext(
            content="code", source="test", file_type="python"
        )

        assert doc_ctx.is_documentation() is True
        assert code_ctx.is_documentation() is False


class TestConversationContext:
    """Tests for ConversationContext."""

    def test_creation(self) -> None:
        """Test basic creation."""
        ctx = ConversationContext(
            content="Hello, how can I help?",
            source="conversation",
            role=MessageRole.ASSISTANT,
            turn_index=1,
        )

        assert ctx.content == "Hello, how can I help?"
        assert ctx.role == MessageRole.ASSISTANT
        assert ctx.get_type() == ContextType.CONVERSATION

    def test_from_message(self) -> None:
        """Test from_message factory method."""
        ctx = ConversationContext.from_message(
            content="What is Python?",
            role="user",
            turn_index=0,
        )

        assert ctx.content == "What is Python?"
        assert ctx.role == MessageRole.USER
        assert ctx.turn_index == 0

    def test_from_history(self) -> None:
        """Test from_history factory method."""
        messages = [
            {"role": "user", "content": "Hello"},
            {"role": "assistant", "content": "Hi there!"},
            {"role": "user", "content": "Help me"},
        ]

        contexts = ConversationContext.from_history(messages)

        assert len(contexts) == 3
        assert contexts[0].role == MessageRole.USER
        assert contexts[1].role == MessageRole.ASSISTANT
        assert contexts[2].turn_index == 2

    def test_is_user_message(self) -> None:
        """Test is_user_message method."""
        user_ctx = ConversationContext(
            content="test", source="test", role=MessageRole.USER
        )
        assistant_ctx = ConversationContext(
            content="test", source="test", role=MessageRole.ASSISTANT
        )

        assert user_ctx.is_user_message() is True
        assert assistant_ctx.is_user_message() is False

    def test_format_for_prompt(self) -> None:
        """Test format_for_prompt method."""
        ctx = ConversationContext.from_message(
            content="What is 2+2?",
            role="user",
        )

        formatted = ctx.format_for_prompt()
        assert "User:" in formatted
        assert "What is 2+2?" in formatted


class TestTaskContext:
    """Tests for TaskContext."""

    def test_creation(self) -> None:
        """Test basic creation."""
        ctx = TaskContext(
            content="Implement login feature",
            source="task",
            title="Login Feature",
        )

        assert ctx.content == "Implement login feature"
        assert ctx.title == "Login Feature"
        assert ctx.get_type() == ContextType.TASK

    def test_default_high_priority(self) -> None:
        """Test that task context defaults to high priority."""
        ctx = TaskContext(content="Test", source="test")
        assert ctx.priority == ContextPriority.HIGH.value

    def test_create_factory(self) -> None:
        """Test create factory method."""
        ctx = TaskContext.create(
            title="Add Auth",
            description="Implement authentication",
            acceptance_criteria=["Tests pass", "Code reviewed"],
            constraints=["Use JWT"],
            issue_id="123",
        )

        assert ctx.title == "Add Auth"
        assert ctx.content == "Implement authentication"
        assert len(ctx.acceptance_criteria) == 2
        assert "Use JWT" in ctx.constraints
        assert ctx.status == TaskStatus.IN_PROGRESS

    def test_format_for_prompt(self) -> None:
        """Test format_for_prompt method."""
        ctx = TaskContext.create(
            title="Test Task",
            description="Do something",
            acceptance_criteria=["Works correctly"],
        )

        formatted = ctx.format_for_prompt()
        assert "Task: Test Task" in formatted
        assert "Do something" in formatted
        assert "Works correctly" in formatted

    def test_status_checks(self) -> None:
        """Test status check methods."""
        pending = TaskContext(
            content="test", source="test", status=TaskStatus.PENDING
        )
        completed = TaskContext(
            content="test", source="test", status=TaskStatus.COMPLETED
        )
        blocked = TaskContext(
            content="test", source="test", status=TaskStatus.BLOCKED
        )

        assert pending.is_active() is True
        assert completed.is_complete() is True
        assert blocked.is_blocked() is True


class TestToolContext:
    """Tests for ToolContext."""

    def test_creation(self) -> None:
        """Test basic creation."""
        ctx = ToolContext(
            content="Tool result here",
            source="tool:search",
            tool_name="search",
        )

        assert ctx.tool_name == "search"
        assert ctx.get_type() == ContextType.TOOL

    def test_from_tool_definition(self) -> None:
        """Test from_tool_definition factory method."""
        ctx = ToolContext.from_tool_definition(
            name="search_knowledge",
            description="Search the knowledge base",
            parameters={
                "query": {"type": "string", "required": True},
                "limit": {"type": "integer", "required": False},
            },
            server_name="knowledge-base",
        )

        assert ctx.tool_name == "search_knowledge"
        assert "Search the knowledge base" in ctx.content
        assert ctx.is_result is False
        assert ctx.server_name == "knowledge-base"

    def test_from_tool_result(self) -> None:
        """Test from_tool_result factory method."""
        ctx = ToolContext.from_tool_result(
            tool_name="search",
            result={"found": 5, "items": ["a", "b"]},
            status=ToolResultStatus.SUCCESS,
            execution_time_ms=150.5,
        )

        assert ctx.tool_name == "search"
        assert ctx.is_result is True
        assert ctx.result_status == ToolResultStatus.SUCCESS
        assert "found" in ctx.content

    def test_is_successful(self) -> None:
        """Test is_successful method."""
        success = ToolContext.from_tool_result(
            "test", "ok", ToolResultStatus.SUCCESS
        )
        error = ToolContext.from_tool_result(
            "test", "error", ToolResultStatus.ERROR
        )

        assert success.is_successful() is True
        assert error.is_successful() is False

    def test_format_for_prompt(self) -> None:
        """Test format_for_prompt method."""
        ctx = ToolContext.from_tool_result(
            "search",
            "Found 3 results",
            ToolResultStatus.SUCCESS,
        )

        formatted = ctx.format_for_prompt()
        assert "Tool Result" in formatted
        assert "search" in formatted
        assert "success" in formatted


class TestAssembledContext:
    """Tests for AssembledContext."""

    def test_creation(self) -> None:
        """Test basic creation."""
        ctx = AssembledContext(
            content="Assembled content here",
            token_count=500,
            contexts_included=5,
        )

        assert ctx.content == "Assembled content here"
        assert ctx.token_count == 500
        assert ctx.contexts_included == 5

    def test_budget_utilization(self) -> None:
        """Test budget_utilization property."""
        ctx = AssembledContext(
            content="test",
            token_count=800,
            contexts_included=5,
            budget_total=1000,
            budget_used=800,
        )

        assert ctx.budget_utilization == 0.8

    def test_budget_utilization_zero_budget(self) -> None:
        """Test budget_utilization with zero budget."""
        ctx = AssembledContext(
            content="test",
            token_count=0,
            contexts_included=0,
            budget_total=0,
            budget_used=0,
        )

        assert ctx.budget_utilization == 0.0

    def test_to_dict(self) -> None:
        """Test to_dict method."""
        ctx = AssembledContext(
            content="test",
            token_count=100,
            contexts_included=2,
            assembly_time_ms=50.123,
        )

        data = ctx.to_dict()
        assert data["content"] == "test"
        assert data["token_count"] == 100
        assert data["assembly_time_ms"] == 50.12  # Rounded

    def test_to_json_and_from_json(self) -> None:
        """Test JSON serialization round-trip."""
        original = AssembledContext(
            content="Test content",
            token_count=100,
            contexts_included=3,
            contexts_excluded=2,
            assembly_time_ms=45.5,
            budget_total=1000,
            budget_used=100,
            by_type={"system": 20, "knowledge": 80},
            cache_hit=True,
            cache_key="abc123",
        )

        json_str = original.to_json()
        restored = AssembledContext.from_json(json_str)

        assert restored.content == original.content
        assert restored.token_count == original.token_count
        assert restored.contexts_included == original.contexts_included
        assert restored.cache_hit == original.cache_hit
        assert restored.cache_key == original.cache_key


class TestBaseContextMethods:
    """Tests for BaseContext methods."""

    def test_get_age_seconds(self) -> None:
        """Test get_age_seconds method."""
        old_time = datetime.now(UTC) - timedelta(hours=2)
        ctx = SystemContext(
            content="test", source="test", timestamp=old_time
        )

        age = ctx.get_age_seconds()
        # Should be approximately 2 hours in seconds
        assert 7100 < age < 7300

    def test_get_age_hours(self) -> None:
        """Test get_age_hours method."""
        old_time = datetime.now(UTC) - timedelta(hours=5)
        ctx = SystemContext(
            content="test", source="test", timestamp=old_time
        )

        age = ctx.get_age_hours()
        assert 4.9 < age < 5.1

    def test_is_stale(self) -> None:
        """Test is_stale method."""
        old_time = datetime.now(UTC) - timedelta(days=10)
        new_time = datetime.now(UTC) - timedelta(hours=1)

        old_ctx = SystemContext(
            content="test", source="test", timestamp=old_time
        )
        new_ctx = SystemContext(
            content="test", source="test", timestamp=new_time
        )

        # Default max_age is 168 hours (7 days)
        assert old_ctx.is_stale() is True
        assert new_ctx.is_stale() is False

    def test_token_count_property(self) -> None:
        """Test token_count property."""
        ctx = SystemContext(content="test", source="test")

        # Initially None
        assert ctx.token_count is None

        # Can be set
        ctx.token_count = 100
        assert ctx.token_count == 100

    def test_score_property_clamping(self) -> None:
        """Test that score is clamped to 0.0-1.0."""
        ctx = SystemContext(content="test", source="test")

        ctx.score = 1.5
        assert ctx.score == 1.0

        ctx.score = -0.5
        assert ctx.score == 0.0

        ctx.score = 0.75
        assert ctx.score == 0.75

    def test_hash_and_equality(self) -> None:
        """Test hash and equality based on ID."""
        ctx1 = SystemContext(content="test", source="test")
        ctx2 = SystemContext(content="test", source="test")
        ctx3 = SystemContext(content="test", source="test")
        ctx3.id = ctx1.id  # Same ID as ctx1

        # Different IDs = not equal
        assert ctx1 != ctx2

        # Same ID = equal
        assert ctx1 == ctx3

        # Can be used in sets
        context_set = {ctx1, ctx2, ctx3}
        assert len(context_set) == 2  # ctx1 and ctx3 are same

    def test_truncate(self) -> None:
        """Test truncate method."""
        long_content = "word " * 1000  # Long content
        ctx = SystemContext(content=long_content, source="test")
        ctx.token_count = 1000

        truncated = ctx.truncate(100)

        assert len(truncated) < len(long_content)
        assert "[truncated]" in truncated
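To tie the new types together, a hedged end-to-end sketch of how a later assembly phase might consume them; the manual string joining and the zero token count are placeholders, since the assembler itself is not part of this Phase 1 diff and only the types used below are defined here:

from app.services.context.types import (
    AssembledContext,
    ConversationContext,
    KnowledgeContext,
    SystemContext,
)

persona = SystemContext.create_persona(
    name="Code Assistant",
    description="You are a helpful code assistant.",
    capabilities=["Write code", "Debug issues"],
)
knowledge = KnowledgeContext.from_search_results(
    [{"content": "def login(): ...", "source_path": "/src/auth.py", "score": 0.9}],
    "how does login work",
)
history = ConversationContext.from_history(
    [{"role": "user", "content": "How does login work?"}]
)

# Placeholder assembly: the real engine would score, budget, compress and format.
pieces = [persona, *knowledge, *history]
prompt = "\n\n".join(ctx.content for ctx in pieces)

assembled = AssembledContext(
    content=prompt,
    token_count=0,  # a real assembler would fill this in from the token counter
    contexts_included=len(pieces),
)
print(assembled.to_json())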