feat(context): Phase 1 - Foundation types, config and exceptions (#79)

Implements the foundation for Context Management Engine:

Types (backend/app/services/context/types/):
- BaseContext: Abstract base with ID, content, priority, scoring
- SystemContext: System prompts, personas, instructions
- KnowledgeContext: RAG results from Knowledge Base MCP
- ConversationContext: Chat history with role support
- TaskContext: Task/issue context with acceptance criteria
- ToolContext: Tool definitions and execution results
- AssembledContext: Final assembled context result

Configuration (config.py):
- Token budget allocation (system 5%, task 10%, knowledge 40%, etc.)
- Scoring weights (relevance 50%, recency 30%, priority 20%)
- Cache settings (TTL, prefix)
- Performance settings (max assembly time, parallel scoring)
- Environment variable overrides with CTX_ prefix

Exceptions (exceptions.py):
- ContextError: Base exception
- BudgetExceededError: Token budget violations
- TokenCountError: Token counting failures
- CompressionError: Compression failures
- AssemblyTimeoutError: Assembly timeout
- ScoringError, FormattingError, CacheError
- ContextNotFoundError, InvalidContextError

All 86 tests pass.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
2026-01-04 02:07:39 +01:00
parent 2ab69f8561
commit 22ecb5e989
21 changed files with 3131 additions and 0 deletions

View File

@@ -0,0 +1,105 @@
"""
Context Management Engine
Sophisticated context assembly and optimization for LLM requests.
Provides intelligent context selection, token budget management,
and model-specific formatting.
Usage:
from app.services.context import (
ContextSettings,
get_context_settings,
SystemContext,
KnowledgeContext,
ConversationContext,
TaskContext,
ToolContext,
)
# Get settings
settings = get_context_settings()
# Create context instances
system_ctx = SystemContext.create_persona(
name="Code Assistant",
description="You are a helpful code assistant.",
capabilities=["Write code", "Debug issues"],
)
"""
# Configuration
from .config import (
ContextSettings,
get_context_settings,
get_default_settings,
reset_context_settings,
)
# Exceptions
from .exceptions import (
AssemblyTimeoutError,
BudgetExceededError,
CacheError,
CompressionError,
ContextError,
ContextNotFoundError,
FormattingError,
InvalidContextError,
ScoringError,
TokenCountError,
)
# Types
from .types import (
AssembledContext,
BaseContext,
ContextPriority,
ContextType,
ConversationContext,
KnowledgeContext,
MessageRole,
SystemContext,
TaskComplexity,
TaskContext,
TaskStatus,
ToolContext,
ToolResultStatus,
)
__all__ = [
# Configuration
"ContextSettings",
"get_context_settings",
"get_default_settings",
"reset_context_settings",
# Exceptions
"AssemblyTimeoutError",
"BudgetExceededError",
"CacheError",
"CompressionError",
"ContextError",
"ContextNotFoundError",
"FormattingError",
"InvalidContextError",
"ScoringError",
"TokenCountError",
# Types - Base
"AssembledContext",
"BaseContext",
"ContextPriority",
"ContextType",
# Types - Conversation
"ConversationContext",
"MessageRole",
# Types - Knowledge
"KnowledgeContext",
# Types - System
"SystemContext",
# Types - Task
"TaskComplexity",
"TaskContext",
"TaskStatus",
# Types - Tool
"ToolContext",
"ToolResultStatus",
]

View File

@@ -0,0 +1,5 @@
"""
Model Adapters Module.
Provides model-specific context formatting.
"""

View File

@@ -0,0 +1,5 @@
"""
Context Assembly Module.
Provides the assembly pipeline and formatting.
"""

View File

@@ -0,0 +1,5 @@
"""
Token Budget Management Module.
Provides token counting and budget allocation.
"""

View File

@@ -0,0 +1,5 @@
"""
Context Cache Module.
Provides Redis-based caching for assembled contexts.
"""

View File

@@ -0,0 +1,5 @@
"""
Context Compression Module.
Provides truncation and compression strategies.
"""

View File

@@ -0,0 +1,328 @@
"""
Context Management Engine Configuration.
Provides Pydantic settings for context assembly,
token budget allocation, and caching.
"""
import threading
from functools import lru_cache
from typing import Any
from pydantic import Field, field_validator, model_validator
from pydantic_settings import BaseSettings
class ContextSettings(BaseSettings):
"""
Configuration for the Context Management Engine.
All settings can be overridden via environment variables
with the CTX_ prefix.
"""
# Budget allocation percentages (must sum to 1.0)
budget_system: float = Field(
default=0.05,
ge=0.0,
le=1.0,
description="Percentage of budget for system prompts (5%)",
)
budget_task: float = Field(
default=0.10,
ge=0.0,
le=1.0,
description="Percentage of budget for task context (10%)",
)
budget_knowledge: float = Field(
default=0.40,
ge=0.0,
le=1.0,
description="Percentage of budget for RAG/knowledge (40%)",
)
budget_conversation: float = Field(
default=0.20,
ge=0.0,
le=1.0,
description="Percentage of budget for conversation history (20%)",
)
budget_tools: float = Field(
default=0.05,
ge=0.0,
le=1.0,
description="Percentage of budget for tool descriptions (5%)",
)
budget_response: float = Field(
default=0.15,
ge=0.0,
le=1.0,
description="Percentage reserved for response (15%)",
)
budget_buffer: float = Field(
default=0.05,
ge=0.0,
le=1.0,
description="Percentage buffer for safety margin (5%)",
)
# Scoring weights
scoring_relevance_weight: float = Field(
default=0.5,
ge=0.0,
le=1.0,
description="Weight for relevance scoring",
)
scoring_recency_weight: float = Field(
default=0.3,
ge=0.0,
le=1.0,
description="Weight for recency scoring",
)
scoring_priority_weight: float = Field(
default=0.2,
ge=0.0,
le=1.0,
description="Weight for priority scoring",
)
# Recency decay settings
recency_decay_hours: float = Field(
default=24.0,
gt=0.0,
description="Hours until recency score decays to 50%",
)
recency_max_age_hours: float = Field(
default=168.0,
gt=0.0,
description="Hours until context is considered stale (7 days)",
)
# Compression settings
compression_threshold: float = Field(
default=0.8,
ge=0.0,
le=1.0,
description="Compress when budget usage exceeds this percentage",
)
truncation_suffix: str = Field(
default="... [truncated]",
description="Suffix to add when truncating content",
)
summary_model_group: str = Field(
default="fast",
description="Model group to use for summarization",
)
# Caching settings
cache_enabled: bool = Field(
default=True,
description="Enable Redis caching for assembled contexts",
)
cache_ttl_seconds: int = Field(
default=3600,
ge=60,
le=86400,
description="Cache TTL in seconds (1 hour default, max 24 hours)",
)
cache_prefix: str = Field(
default="ctx",
description="Redis key prefix for context cache",
)
# Performance settings
max_assembly_time_ms: int = Field(
default=100,
ge=10,
le=5000,
description="Maximum time for context assembly in milliseconds",
)
parallel_scoring: bool = Field(
default=True,
description="Score contexts in parallel for better performance",
)
max_parallel_scores: int = Field(
default=10,
ge=1,
le=50,
description="Maximum number of contexts to score in parallel",
)
# Knowledge retrieval settings
knowledge_search_type: str = Field(
default="hybrid",
description="Default search type for knowledge retrieval",
)
knowledge_max_results: int = Field(
default=10,
ge=1,
le=50,
description="Maximum knowledge chunks to retrieve",
)
knowledge_min_score: float = Field(
default=0.5,
ge=0.0,
le=1.0,
description="Minimum relevance score for knowledge",
)
# Conversation history settings
conversation_max_turns: int = Field(
default=20,
ge=1,
le=100,
description="Maximum conversation turns to include",
)
conversation_recent_priority: bool = Field(
default=True,
description="Prioritize recent conversation turns",
)
@field_validator("knowledge_search_type")
@classmethod
def validate_search_type(cls, v: str) -> str:
"""Validate search type is valid."""
valid_types = {"semantic", "keyword", "hybrid"}
if v not in valid_types:
raise ValueError(f"search_type must be one of: {valid_types}")
return v
@model_validator(mode="after")
def validate_budget_allocation(self) -> "ContextSettings":
"""Validate that budget percentages sum to 1.0."""
total = (
self.budget_system
+ self.budget_task
+ self.budget_knowledge
+ self.budget_conversation
+ self.budget_tools
+ self.budget_response
+ self.budget_buffer
)
# Allow small floating point error
if abs(total - 1.0) > 0.001:
raise ValueError(
f"Budget percentages must sum to 1.0, got {total:.3f}. "
f"Current allocation: system={self.budget_system}, task={self.budget_task}, "
f"knowledge={self.budget_knowledge}, conversation={self.budget_conversation}, "
f"tools={self.budget_tools}, response={self.budget_response}, buffer={self.budget_buffer}"
)
return self
@model_validator(mode="after")
def validate_scoring_weights(self) -> "ContextSettings":
"""Validate that scoring weights sum to 1.0."""
total = (
self.scoring_relevance_weight
+ self.scoring_recency_weight
+ self.scoring_priority_weight
)
# Allow small floating point error
if abs(total - 1.0) > 0.001:
raise ValueError(
f"Scoring weights must sum to 1.0, got {total:.3f}. "
f"Current weights: relevance={self.scoring_relevance_weight}, "
f"recency={self.scoring_recency_weight}, priority={self.scoring_priority_weight}"
)
return self
def get_budget_allocation(self) -> dict[str, float]:
"""Get budget allocation as a dictionary."""
return {
"system": self.budget_system,
"task": self.budget_task,
"knowledge": self.budget_knowledge,
"conversation": self.budget_conversation,
"tools": self.budget_tools,
"response": self.budget_response,
"buffer": self.budget_buffer,
}
def get_scoring_weights(self) -> dict[str, float]:
"""Get scoring weights as a dictionary."""
return {
"relevance": self.scoring_relevance_weight,
"recency": self.scoring_recency_weight,
"priority": self.scoring_priority_weight,
}
def to_dict(self) -> dict[str, Any]:
"""Convert settings to dictionary for logging/debugging."""
return {
"budget": self.get_budget_allocation(),
"scoring": self.get_scoring_weights(),
"compression": {
"threshold": self.compression_threshold,
"summary_model_group": self.summary_model_group,
},
"cache": {
"enabled": self.cache_enabled,
"ttl_seconds": self.cache_ttl_seconds,
"prefix": self.cache_prefix,
},
"performance": {
"max_assembly_time_ms": self.max_assembly_time_ms,
"parallel_scoring": self.parallel_scoring,
"max_parallel_scores": self.max_parallel_scores,
},
"knowledge": {
"search_type": self.knowledge_search_type,
"max_results": self.knowledge_max_results,
"min_score": self.knowledge_min_score,
},
"conversation": {
"max_turns": self.conversation_max_turns,
"recent_priority": self.conversation_recent_priority,
},
}
model_config = {
"env_prefix": "CTX_",
"env_file": "../.env",
"env_file_encoding": "utf-8",
"case_sensitive": False,
"extra": "ignore",
}
# Thread-safe singleton pattern
_settings: ContextSettings | None = None
_settings_lock = threading.Lock()
def get_context_settings() -> ContextSettings:
"""
Get the global ContextSettings instance.
Thread-safe with double-checked locking pattern.
Returns:
ContextSettings instance
"""
global _settings
if _settings is None:
with _settings_lock:
if _settings is None:
_settings = ContextSettings()
return _settings
def reset_context_settings() -> None:
"""
Reset the global settings instance.
Primarily used for testing.
"""
global _settings
with _settings_lock:
_settings = None
@lru_cache(maxsize=1)
def get_default_settings() -> ContextSettings:
"""
Get default settings (cached).
Use this for read-only access to defaults.
For mutable access, use get_context_settings().
"""
return ContextSettings()

View File

@@ -0,0 +1,354 @@
"""
Context Management Engine Exceptions.
Provides a hierarchy of exceptions for context assembly,
token budget management, and related operations.
"""
from typing import Any
class ContextError(Exception):
"""
Base exception for all context management errors.
All context-related exceptions should inherit from this class
to allow for catch-all handling when needed.
"""
def __init__(self, message: str, details: dict[str, Any] | None = None) -> None:
"""
Initialize context error.
Args:
message: Human-readable error message
details: Optional dict with additional error context
"""
self.message = message
self.details = details or {}
super().__init__(message)
def to_dict(self) -> dict[str, Any]:
"""Convert exception to dictionary for logging/serialization."""
return {
"error_type": self.__class__.__name__,
"message": self.message,
"details": self.details,
}
class BudgetExceededError(ContextError):
"""
Raised when token budget is exceeded.
This occurs when the assembled context would exceed the
allocated token budget for a specific context type or total.
"""
def __init__(
self,
message: str = "Token budget exceeded",
allocated: int = 0,
requested: int = 0,
context_type: str | None = None,
) -> None:
"""
Initialize budget exceeded error.
Args:
message: Error message
allocated: Tokens allocated for this context type
requested: Tokens requested
context_type: Type of context that exceeded budget
"""
details = {
"allocated": allocated,
"requested": requested,
"overage": requested - allocated,
}
if context_type:
details["context_type"] = context_type
super().__init__(message, details)
self.allocated = allocated
self.requested = requested
self.context_type = context_type
class TokenCountError(ContextError):
"""
Raised when token counting fails.
This typically occurs when the LLM Gateway token counting
service is unavailable or returns an error.
"""
def __init__(
self,
message: str = "Failed to count tokens",
model: str | None = None,
text_length: int | None = None,
) -> None:
"""
Initialize token count error.
Args:
message: Error message
model: Model for which counting was attempted
text_length: Length of text that failed to count
"""
details: dict[str, Any] = {}
if model:
details["model"] = model
if text_length is not None:
details["text_length"] = text_length
super().__init__(message, details)
self.model = model
self.text_length = text_length
class CompressionError(ContextError):
"""
Raised when context compression fails.
This can occur when summarization or truncation cannot
reduce content to fit within the budget.
"""
def __init__(
self,
message: str = "Failed to compress context",
original_tokens: int | None = None,
target_tokens: int | None = None,
achieved_tokens: int | None = None,
) -> None:
"""
Initialize compression error.
Args:
message: Error message
original_tokens: Tokens before compression
target_tokens: Target token count
achieved_tokens: Tokens achieved after compression attempt
"""
details: dict[str, Any] = {}
if original_tokens is not None:
details["original_tokens"] = original_tokens
if target_tokens is not None:
details["target_tokens"] = target_tokens
if achieved_tokens is not None:
details["achieved_tokens"] = achieved_tokens
super().__init__(message, details)
self.original_tokens = original_tokens
self.target_tokens = target_tokens
self.achieved_tokens = achieved_tokens
class AssemblyTimeoutError(ContextError):
"""
Raised when context assembly exceeds time limit.
Context assembly must complete within a configurable
time limit to maintain responsiveness.
"""
def __init__(
self,
message: str = "Context assembly timed out",
timeout_ms: int = 0,
elapsed_ms: float = 0.0,
stage: str | None = None,
) -> None:
"""
Initialize assembly timeout error.
Args:
message: Error message
timeout_ms: Configured timeout in milliseconds
elapsed_ms: Actual elapsed time in milliseconds
stage: Pipeline stage where timeout occurred
"""
details = {
"timeout_ms": timeout_ms,
"elapsed_ms": round(elapsed_ms, 2),
}
if stage:
details["stage"] = stage
super().__init__(message, details)
self.timeout_ms = timeout_ms
self.elapsed_ms = elapsed_ms
self.stage = stage
class ScoringError(ContextError):
"""
Raised when context scoring fails.
This occurs when relevance, recency, or priority scoring
encounters an error.
"""
def __init__(
self,
message: str = "Failed to score context",
scorer_type: str | None = None,
context_id: str | None = None,
) -> None:
"""
Initialize scoring error.
Args:
message: Error message
scorer_type: Type of scorer that failed
context_id: ID of context being scored
"""
details: dict[str, Any] = {}
if scorer_type:
details["scorer_type"] = scorer_type
if context_id:
details["context_id"] = context_id
super().__init__(message, details)
self.scorer_type = scorer_type
self.context_id = context_id
class FormattingError(ContextError):
"""
Raised when context formatting fails.
This occurs when converting assembled context to
model-specific format fails.
"""
def __init__(
self,
message: str = "Failed to format context",
model: str | None = None,
adapter: str | None = None,
) -> None:
"""
Initialize formatting error.
Args:
message: Error message
model: Target model
adapter: Adapter that failed
"""
details: dict[str, Any] = {}
if model:
details["model"] = model
if adapter:
details["adapter"] = adapter
super().__init__(message, details)
self.model = model
self.adapter = adapter
class CacheError(ContextError):
"""
Raised when cache operations fail.
This is typically non-fatal and should be handled
gracefully by falling back to recomputation.
"""
def __init__(
self,
message: str = "Cache operation failed",
operation: str | None = None,
cache_key: str | None = None,
) -> None:
"""
Initialize cache error.
Args:
message: Error message
operation: Cache operation that failed (get, set, delete)
cache_key: Key involved in the failed operation
"""
details: dict[str, Any] = {}
if operation:
details["operation"] = operation
if cache_key:
details["cache_key"] = cache_key
super().__init__(message, details)
self.operation = operation
self.cache_key = cache_key
class ContextNotFoundError(ContextError):
"""
Raised when expected context is not found.
This occurs when required context sources return
no results or are unavailable.
"""
def __init__(
self,
message: str = "Required context not found",
source: str | None = None,
query: str | None = None,
) -> None:
"""
Initialize context not found error.
Args:
message: Error message
source: Source that returned no results
query: Query used to search
"""
details: dict[str, Any] = {}
if source:
details["source"] = source
if query:
details["query"] = query
super().__init__(message, details)
self.source = source
self.query = query
class InvalidContextError(ContextError):
"""
Raised when context data is invalid.
This occurs when context content or metadata
fails validation.
"""
def __init__(
self,
message: str = "Invalid context data",
field: str | None = None,
value: Any | None = None,
reason: str | None = None,
) -> None:
"""
Initialize invalid context error.
Args:
message: Error message
field: Field that is invalid
value: Invalid value (may be redacted for security)
reason: Reason for invalidity
"""
details: dict[str, Any] = {}
if field:
details["field"] = field
if value is not None:
# Avoid logging potentially sensitive values
details["value_type"] = type(value).__name__
if reason:
details["reason"] = reason
super().__init__(message, details)
self.field = field
self.value = value
self.reason = reason

View File

@@ -0,0 +1,5 @@
"""
Context Prioritization Module.
Provides context ranking and selection.
"""

View File

@@ -0,0 +1,5 @@
"""
Context Scoring Module.
Provides relevance, recency, and composite scoring.
"""

View File

@@ -0,0 +1,49 @@
"""
Context Types Module.
Provides all context types used in the Context Management Engine.
"""
from .base import (
AssembledContext,
BaseContext,
ContextPriority,
ContextType,
)
from .conversation import (
ConversationContext,
MessageRole,
)
from .knowledge import KnowledgeContext
from .system import SystemContext
from .task import (
TaskComplexity,
TaskContext,
TaskStatus,
)
from .tool import (
ToolContext,
ToolResultStatus,
)
__all__ = [
# Base types
"AssembledContext",
"BaseContext",
"ContextPriority",
"ContextType",
# Conversation
"ConversationContext",
"MessageRole",
# Knowledge
"KnowledgeContext",
# System
"SystemContext",
# Task
"TaskComplexity",
"TaskContext",
"TaskStatus",
# Tool
"ToolContext",
"ToolResultStatus",
]

View File

@@ -0,0 +1,320 @@
"""
Base Context Types and Enums.
Provides the foundation for all context types used in
the Context Management Engine.
"""
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from datetime import UTC, datetime
from enum import Enum
from typing import Any
from uuid import uuid4
class ContextType(str, Enum):
"""
Types of context that can be assembled.
Each type has specific handling, formatting, and
budget allocation rules.
"""
SYSTEM = "system"
TASK = "task"
KNOWLEDGE = "knowledge"
CONVERSATION = "conversation"
TOOL = "tool"
@classmethod
def from_string(cls, value: str) -> "ContextType":
"""
Convert string to ContextType.
Args:
value: String value
Returns:
ContextType enum value
Raises:
ValueError: If value is not a valid context type
"""
try:
return cls(value.lower())
except ValueError:
valid = ", ".join(t.value for t in cls)
raise ValueError(f"Invalid context type '{value}'. Valid types: {valid}")
class ContextPriority(int, Enum):
"""
Priority levels for context ordering.
Higher values indicate higher priority.
"""
LOWEST = 0
LOW = 25
NORMAL = 50
HIGH = 75
HIGHEST = 100
CRITICAL = 150 # Never omit
@classmethod
def from_int(cls, value: int) -> "ContextPriority":
"""
Get closest priority level for an integer.
Args:
value: Integer priority value
Returns:
Closest ContextPriority enum value
"""
priorities = sorted(cls, key=lambda p: p.value)
for priority in reversed(priorities):
if value >= priority.value:
return priority
return cls.LOWEST
@dataclass(eq=False)
class BaseContext(ABC):
"""
Abstract base class for all context types.
Provides common fields and methods for context handling,
scoring, and serialization.
"""
# Required fields
content: str
source: str
# Optional fields with defaults
id: str = field(default_factory=lambda: str(uuid4()))
timestamp: datetime = field(default_factory=lambda: datetime.now(UTC))
priority: int = field(default=ContextPriority.NORMAL.value)
metadata: dict[str, Any] = field(default_factory=dict)
# Computed/cached fields
_token_count: int | None = field(default=None, repr=False)
_score: float | None = field(default=None, repr=False)
@property
def token_count(self) -> int | None:
"""Get cached token count (None if not counted yet)."""
return self._token_count
@token_count.setter
def token_count(self, value: int) -> None:
"""Set token count."""
self._token_count = value
@property
def score(self) -> float | None:
"""Get cached score (None if not scored yet)."""
return self._score
@score.setter
def score(self, value: float) -> None:
"""Set score (clamped to 0.0-1.0)."""
self._score = max(0.0, min(1.0, value))
@abstractmethod
def get_type(self) -> ContextType:
"""
Get the type of this context.
Returns:
ContextType enum value
"""
...
def get_age_seconds(self) -> float:
"""
Get age of context in seconds.
Returns:
Age in seconds since creation
"""
now = datetime.now(UTC)
delta = now - self.timestamp
return delta.total_seconds()
def get_age_hours(self) -> float:
"""
Get age of context in hours.
Returns:
Age in hours since creation
"""
return self.get_age_seconds() / 3600
def is_stale(self, max_age_hours: float = 168.0) -> bool:
"""
Check if context is stale.
Args:
max_age_hours: Maximum age before considered stale (default 7 days)
Returns:
True if context is older than max_age_hours
"""
return self.get_age_hours() > max_age_hours
def to_dict(self) -> dict[str, Any]:
"""
Convert context to dictionary for serialization.
Returns:
Dictionary representation
"""
return {
"id": self.id,
"type": self.get_type().value,
"content": self.content,
"source": self.source,
"timestamp": self.timestamp.isoformat(),
"priority": self.priority,
"metadata": self.metadata,
"token_count": self._token_count,
"score": self._score,
}
@classmethod
def from_dict(cls, data: dict[str, Any]) -> "BaseContext":
"""
Create context from dictionary.
Note: Subclasses should override this to return correct type.
Args:
data: Dictionary with context data
Returns:
Context instance
"""
raise NotImplementedError("Subclasses must implement from_dict")
def truncate(self, max_tokens: int, suffix: str = "... [truncated]") -> str:
"""
Truncate content to fit within token limit.
This is a rough estimation based on characters.
For accurate truncation, use the TokenCalculator.
Args:
max_tokens: Maximum tokens allowed
suffix: Suffix to append when truncated
Returns:
Truncated content
"""
if self._token_count is None or self._token_count <= max_tokens:
return self.content
# Rough estimation: 4 chars per token on average
estimated_chars = max_tokens * 4
suffix_chars = len(suffix)
if len(self.content) <= estimated_chars:
return self.content
truncated = self.content[: estimated_chars - suffix_chars]
# Try to break at word boundary
last_space = truncated.rfind(" ")
if last_space > estimated_chars * 0.8:
truncated = truncated[:last_space]
return truncated + suffix
def __hash__(self) -> int:
"""Hash based on ID for set/dict usage."""
return hash(self.id)
def __eq__(self, other: object) -> bool:
"""Equality based on ID."""
if not isinstance(other, BaseContext):
return False
return self.id == other.id
@dataclass
class AssembledContext:
"""
Result of context assembly.
Contains the final formatted context ready for LLM consumption,
along with metadata about the assembly process.
"""
# Main content
content: str
token_count: int
# Assembly metadata
contexts_included: int
contexts_excluded: int = 0
assembly_time_ms: float = 0.0
# Budget tracking
budget_total: int = 0
budget_used: int = 0
# Context breakdown
by_type: dict[str, int] = field(default_factory=dict)
# Cache info
cache_hit: bool = False
cache_key: str | None = None
@property
def budget_utilization(self) -> float:
"""Get budget utilization percentage."""
if self.budget_total == 0:
return 0.0
return self.budget_used / self.budget_total
def to_dict(self) -> dict[str, Any]:
"""Convert to dictionary."""
return {
"content": self.content,
"token_count": self.token_count,
"contexts_included": self.contexts_included,
"contexts_excluded": self.contexts_excluded,
"assembly_time_ms": round(self.assembly_time_ms, 2),
"budget_total": self.budget_total,
"budget_used": self.budget_used,
"budget_utilization": round(self.budget_utilization, 3),
"by_type": self.by_type,
"cache_hit": self.cache_hit,
"cache_key": self.cache_key,
}
def to_json(self) -> str:
"""Convert to JSON string."""
import json
return json.dumps(self.to_dict())
@classmethod
def from_json(cls, json_str: str) -> "AssembledContext":
"""Create from JSON string."""
import json
data = json.loads(json_str)
return cls(
content=data["content"],
token_count=data["token_count"],
contexts_included=data["contexts_included"],
contexts_excluded=data.get("contexts_excluded", 0),
assembly_time_ms=data.get("assembly_time_ms", 0.0),
budget_total=data.get("budget_total", 0),
budget_used=data.get("budget_used", 0),
by_type=data.get("by_type", {}),
cache_hit=data.get("cache_hit", False),
cache_key=data.get("cache_key"),
)

View File

@@ -0,0 +1,182 @@
"""
Conversation Context Type.
Represents conversation history for context continuity.
"""
from dataclasses import dataclass, field
from datetime import UTC, datetime
from enum import Enum
from typing import Any
from .base import BaseContext, ContextPriority, ContextType
class MessageRole(str, Enum):
"""Roles for conversation messages."""
USER = "user"
ASSISTANT = "assistant"
SYSTEM = "system"
TOOL = "tool"
@classmethod
def from_string(cls, value: str) -> "MessageRole":
"""Convert string to MessageRole."""
try:
return cls(value.lower())
except ValueError:
# Default to user for unknown roles
return cls.USER
@dataclass(eq=False)
class ConversationContext(BaseContext):
"""
Context from conversation history.
Represents a single turn in the conversation,
including user messages, assistant responses,
and tool results.
"""
# Conversation-specific fields
role: MessageRole = field(default=MessageRole.USER)
turn_index: int = field(default=0)
session_id: str | None = field(default=None)
parent_message_id: str | None = field(default=None)
def get_type(self) -> ContextType:
"""Return CONVERSATION context type."""
return ContextType.CONVERSATION
def to_dict(self) -> dict[str, Any]:
"""Convert to dictionary with conversation-specific fields."""
base = super().to_dict()
base.update(
{
"role": self.role.value,
"turn_index": self.turn_index,
"session_id": self.session_id,
"parent_message_id": self.parent_message_id,
}
)
return base
@classmethod
def from_dict(cls, data: dict[str, Any]) -> "ConversationContext":
"""Create ConversationContext from dictionary."""
role = data.get("role", "user")
if isinstance(role, str):
role = MessageRole.from_string(role)
return cls(
id=data.get("id", ""),
content=data["content"],
source=data.get("source", "conversation"),
timestamp=datetime.fromisoformat(data["timestamp"])
if isinstance(data.get("timestamp"), str)
else data.get("timestamp", datetime.now(UTC)),
priority=data.get("priority", ContextPriority.NORMAL.value),
metadata=data.get("metadata", {}),
role=role,
turn_index=data.get("turn_index", 0),
session_id=data.get("session_id"),
parent_message_id=data.get("parent_message_id"),
)
@classmethod
def from_message(
cls,
content: str,
role: str | MessageRole,
turn_index: int = 0,
session_id: str | None = None,
timestamp: datetime | None = None,
) -> "ConversationContext":
"""
Create ConversationContext from a message.
Args:
content: Message content
role: Message role (user, assistant, system, tool)
turn_index: Position in conversation
session_id: Session identifier
timestamp: Message timestamp
Returns:
ConversationContext instance
"""
if isinstance(role, str):
role = MessageRole.from_string(role)
# Recent messages have higher priority
priority = ContextPriority.NORMAL.value
return cls(
content=content,
source="conversation",
role=role,
turn_index=turn_index,
session_id=session_id,
timestamp=timestamp or datetime.now(UTC),
priority=priority,
)
@classmethod
def from_history(
cls,
messages: list[dict[str, Any]],
session_id: str | None = None,
) -> list["ConversationContext"]:
"""
Create multiple ConversationContexts from message history.
Args:
messages: List of message dicts with 'role' and 'content'
session_id: Session identifier
Returns:
List of ConversationContext instances
"""
contexts = []
for i, msg in enumerate(messages):
ctx = cls.from_message(
content=msg.get("content", ""),
role=msg.get("role", "user"),
turn_index=i,
session_id=session_id,
timestamp=datetime.fromisoformat(msg["timestamp"])
if "timestamp" in msg
else None,
)
contexts.append(ctx)
return contexts
def is_user_message(self) -> bool:
"""Check if this is a user message."""
return self.role == MessageRole.USER
def is_assistant_message(self) -> bool:
"""Check if this is an assistant message."""
return self.role == MessageRole.ASSISTANT
def is_tool_result(self) -> bool:
"""Check if this is a tool result."""
return self.role == MessageRole.TOOL
def format_for_prompt(self) -> str:
"""
Format message for inclusion in prompt.
Returns:
Formatted message string
"""
role_labels = {
MessageRole.USER: "User",
MessageRole.ASSISTANT: "Assistant",
MessageRole.SYSTEM: "System",
MessageRole.TOOL: "Tool Result",
}
label = role_labels.get(self.role, "Unknown")
return f"{label}: {self.content}"

View File

@@ -0,0 +1,143 @@
"""
Knowledge Context Type.
Represents RAG results from the Knowledge Base MCP server.
"""
from dataclasses import dataclass, field
from datetime import UTC, datetime
from typing import Any
from .base import BaseContext, ContextPriority, ContextType
@dataclass(eq=False)
class KnowledgeContext(BaseContext):
"""
Context from knowledge base / RAG retrieval.
Knowledge context represents chunks retrieved from the
Knowledge Base MCP server, including:
- Code snippets
- Documentation
- Previous conversations
- External knowledge
Each chunk includes relevance scoring from the search.
"""
# Knowledge-specific fields
collection: str = field(default="default")
file_type: str | None = field(default=None)
chunk_index: int = field(default=0)
relevance_score: float = field(default=0.0)
search_query: str = field(default="")
def get_type(self) -> ContextType:
"""Return KNOWLEDGE context type."""
return ContextType.KNOWLEDGE
def to_dict(self) -> dict[str, Any]:
"""Convert to dictionary with knowledge-specific fields."""
base = super().to_dict()
base.update(
{
"collection": self.collection,
"file_type": self.file_type,
"chunk_index": self.chunk_index,
"relevance_score": self.relevance_score,
"search_query": self.search_query,
}
)
return base
@classmethod
def from_dict(cls, data: dict[str, Any]) -> "KnowledgeContext":
"""Create KnowledgeContext from dictionary."""
return cls(
id=data.get("id", ""),
content=data["content"],
source=data["source"],
timestamp=datetime.fromisoformat(data["timestamp"])
if isinstance(data.get("timestamp"), str)
else data.get("timestamp", datetime.now(UTC)),
priority=data.get("priority", ContextPriority.NORMAL.value),
metadata=data.get("metadata", {}),
collection=data.get("collection", "default"),
file_type=data.get("file_type"),
chunk_index=data.get("chunk_index", 0),
relevance_score=data.get("relevance_score", 0.0),
search_query=data.get("search_query", ""),
)
@classmethod
def from_search_result(
cls,
result: dict[str, Any],
query: str,
) -> "KnowledgeContext":
"""
Create KnowledgeContext from a Knowledge Base search result.
Args:
result: Search result from Knowledge Base MCP
query: Search query used
Returns:
KnowledgeContext instance
"""
return cls(
content=result.get("content", ""),
source=result.get("source_path", "unknown"),
collection=result.get("collection", "default"),
file_type=result.get("file_type"),
chunk_index=result.get("chunk_index", 0),
relevance_score=result.get("score", 0.0),
search_query=query,
metadata={
"chunk_id": result.get("id"),
"content_hash": result.get("content_hash"),
},
)
@classmethod
def from_search_results(
cls,
results: list[dict[str, Any]],
query: str,
) -> list["KnowledgeContext"]:
"""
Create multiple KnowledgeContexts from search results.
Args:
results: List of search results
query: Search query used
Returns:
List of KnowledgeContext instances
"""
return [cls.from_search_result(r, query) for r in results]
def is_code(self) -> bool:
"""Check if this is code content."""
code_types = {"python", "javascript", "typescript", "go", "rust", "java", "c", "cpp"}
return self.file_type is not None and self.file_type.lower() in code_types
def is_documentation(self) -> bool:
"""Check if this is documentation content."""
doc_types = {"markdown", "rst", "txt", "md"}
return self.file_type is not None and self.file_type.lower() in doc_types
def get_formatted_source(self) -> str:
"""
Get a formatted source string for display.
Returns:
Formatted source string
"""
parts = [self.source]
if self.file_type:
parts.append(f"({self.file_type})")
if self.collection != "default":
parts.insert(0, f"[{self.collection}]")
return " ".join(parts)

View File

@@ -0,0 +1,138 @@
"""
System Context Type.
Represents system prompts, instructions, and agent personas.
"""
from dataclasses import dataclass, field
from datetime import UTC, datetime
from typing import Any
from .base import BaseContext, ContextPriority, ContextType
@dataclass(eq=False)
class SystemContext(BaseContext):
"""
Context for system prompts and instructions.
System context typically includes:
- Agent persona and role definitions
- Behavioral instructions
- Safety guidelines
- Output format requirements
System context is usually high priority and should
rarely be truncated or omitted.
"""
# System context specific fields
role: str = field(default="assistant")
instructions_type: str = field(default="general")
def __post_init__(self) -> None:
"""Set high priority for system context."""
# System context defaults to high priority
if self.priority == ContextPriority.NORMAL.value:
self.priority = ContextPriority.HIGH.value
def get_type(self) -> ContextType:
"""Return SYSTEM context type."""
return ContextType.SYSTEM
def to_dict(self) -> dict[str, Any]:
"""Convert to dictionary with system-specific fields."""
base = super().to_dict()
base.update(
{
"role": self.role,
"instructions_type": self.instructions_type,
}
)
return base
@classmethod
def from_dict(cls, data: dict[str, Any]) -> "SystemContext":
"""Create SystemContext from dictionary."""
return cls(
id=data.get("id", ""),
content=data["content"],
source=data["source"],
timestamp=datetime.fromisoformat(data["timestamp"])
if isinstance(data.get("timestamp"), str)
else data.get("timestamp", datetime.now(UTC)),
priority=data.get("priority", ContextPriority.HIGH.value),
metadata=data.get("metadata", {}),
role=data.get("role", "assistant"),
instructions_type=data.get("instructions_type", "general"),
)
@classmethod
def create_persona(
cls,
name: str,
description: str,
capabilities: list[str] | None = None,
constraints: list[str] | None = None,
) -> "SystemContext":
"""
Create a persona system context.
Args:
name: Agent name/role
description: Role description
capabilities: List of things the agent can do
constraints: List of limitations
Returns:
SystemContext with formatted persona
"""
parts = [f"You are {name}.", "", description]
if capabilities:
parts.append("")
parts.append("You can:")
for cap in capabilities:
parts.append(f"- {cap}")
if constraints:
parts.append("")
parts.append("You must not:")
for constraint in constraints:
parts.append(f"- {constraint}")
return cls(
content="\n".join(parts),
source="persona_builder",
role=name.lower().replace(" ", "_"),
instructions_type="persona",
priority=ContextPriority.HIGHEST.value,
)
@classmethod
def create_instructions(
cls,
instructions: str | list[str],
source: str = "instructions",
) -> "SystemContext":
"""
Create an instructions system context.
Args:
instructions: Instructions string or list of instruction strings
source: Source identifier
Returns:
SystemContext with instructions
"""
if isinstance(instructions, list):
content = "\n".join(f"- {inst}" for inst in instructions)
else:
content = instructions
return cls(
content=content,
source=source,
instructions_type="instructions",
priority=ContextPriority.HIGH.value,
)

View File

@@ -0,0 +1,195 @@
"""
Task Context Type.
Represents the current task or objective for the agent.
"""
from dataclasses import dataclass, field
from datetime import UTC, datetime
from enum import Enum
from typing import Any
from .base import BaseContext, ContextPriority, ContextType
class TaskStatus(str, Enum):
"""Status of a task."""
PENDING = "pending"
IN_PROGRESS = "in_progress"
BLOCKED = "blocked"
COMPLETED = "completed"
FAILED = "failed"
class TaskComplexity(str, Enum):
"""Complexity level of a task."""
TRIVIAL = "trivial"
SIMPLE = "simple"
MODERATE = "moderate"
COMPLEX = "complex"
VERY_COMPLEX = "very_complex"
@dataclass(eq=False)
class TaskContext(BaseContext):
"""
Context for the current task or objective.
Task context provides information about what the agent
should accomplish, including:
- Task description and goals
- Acceptance criteria
- Constraints and requirements
- Related issue/ticket information
"""
# Task-specific fields
title: str = field(default="")
status: TaskStatus = field(default=TaskStatus.PENDING)
complexity: TaskComplexity = field(default=TaskComplexity.MODERATE)
issue_id: str | None = field(default=None)
project_id: str | None = field(default=None)
acceptance_criteria: list[str] = field(default_factory=list)
constraints: list[str] = field(default_factory=list)
parent_task_id: str | None = field(default=None)
def __post_init__(self) -> None:
"""Set high priority for task context."""
# Task context defaults to high priority
if self.priority == ContextPriority.NORMAL.value:
self.priority = ContextPriority.HIGH.value
def get_type(self) -> ContextType:
"""Return TASK context type."""
return ContextType.TASK
def to_dict(self) -> dict[str, Any]:
"""Convert to dictionary with task-specific fields."""
base = super().to_dict()
base.update(
{
"title": self.title,
"status": self.status.value,
"complexity": self.complexity.value,
"issue_id": self.issue_id,
"project_id": self.project_id,
"acceptance_criteria": self.acceptance_criteria,
"constraints": self.constraints,
"parent_task_id": self.parent_task_id,
}
)
return base
@classmethod
def from_dict(cls, data: dict[str, Any]) -> "TaskContext":
"""Create TaskContext from dictionary."""
status = data.get("status", "pending")
if isinstance(status, str):
status = TaskStatus(status)
complexity = data.get("complexity", "moderate")
if isinstance(complexity, str):
complexity = TaskComplexity(complexity)
return cls(
id=data.get("id", ""),
content=data["content"],
source=data.get("source", "task"),
timestamp=datetime.fromisoformat(data["timestamp"])
if isinstance(data.get("timestamp"), str)
else data.get("timestamp", datetime.now(UTC)),
priority=data.get("priority", ContextPriority.HIGH.value),
metadata=data.get("metadata", {}),
title=data.get("title", ""),
status=status,
complexity=complexity,
issue_id=data.get("issue_id"),
project_id=data.get("project_id"),
acceptance_criteria=data.get("acceptance_criteria", []),
constraints=data.get("constraints", []),
parent_task_id=data.get("parent_task_id"),
)
@classmethod
def create(
cls,
title: str,
description: str,
acceptance_criteria: list[str] | None = None,
constraints: list[str] | None = None,
issue_id: str | None = None,
project_id: str | None = None,
complexity: TaskComplexity | str = TaskComplexity.MODERATE,
) -> "TaskContext":
"""
Create a task context.
Args:
title: Task title
description: Task description
acceptance_criteria: List of acceptance criteria
constraints: List of constraints
issue_id: Related issue ID
project_id: Project ID
complexity: Task complexity
Returns:
TaskContext instance
"""
if isinstance(complexity, str):
complexity = TaskComplexity(complexity)
return cls(
content=description,
source=f"task:{issue_id}" if issue_id else "task",
title=title,
status=TaskStatus.IN_PROGRESS,
complexity=complexity,
issue_id=issue_id,
project_id=project_id,
acceptance_criteria=acceptance_criteria or [],
constraints=constraints or [],
)
def format_for_prompt(self) -> str:
"""
Format task for inclusion in prompt.
Returns:
Formatted task string
"""
parts = []
if self.title:
parts.append(f"Task: {self.title}")
parts.append("")
parts.append(self.content)
if self.acceptance_criteria:
parts.append("")
parts.append("Acceptance Criteria:")
for criterion in self.acceptance_criteria:
parts.append(f"- {criterion}")
if self.constraints:
parts.append("")
parts.append("Constraints:")
for constraint in self.constraints:
parts.append(f"- {constraint}")
return "\n".join(parts)
def is_active(self) -> bool:
"""Check if task is currently active."""
return self.status in (TaskStatus.PENDING, TaskStatus.IN_PROGRESS)
def is_complete(self) -> bool:
"""Check if task is complete."""
return self.status == TaskStatus.COMPLETED
def is_blocked(self) -> bool:
"""Check if task is blocked."""
return self.status == TaskStatus.BLOCKED

View File

@@ -0,0 +1,207 @@
"""
Tool Context Type.
Represents available tools and recent tool execution results.
"""
from dataclasses import dataclass, field
from datetime import UTC, datetime
from enum import Enum
from typing import Any
from .base import BaseContext, ContextPriority, ContextType
class ToolResultStatus(str, Enum):
"""Status of a tool execution result."""
SUCCESS = "success"
ERROR = "error"
TIMEOUT = "timeout"
CANCELLED = "cancelled"
@dataclass(eq=False)
class ToolContext(BaseContext):
"""
Context for tools and tool execution results.
Tool context includes:
- Tool descriptions and parameters
- Recent tool execution results
- Tool availability information
This helps the LLM understand what tools are available
and what results previous tool calls produced.
"""
# Tool-specific fields
tool_name: str = field(default="")
tool_description: str = field(default="")
is_result: bool = field(default=False)
result_status: ToolResultStatus | None = field(default=None)
execution_time_ms: float | None = field(default=None)
parameters: dict[str, Any] = field(default_factory=dict)
server_name: str | None = field(default=None)
def get_type(self) -> ContextType:
"""Return TOOL context type."""
return ContextType.TOOL
def to_dict(self) -> dict[str, Any]:
"""Convert to dictionary with tool-specific fields."""
base = super().to_dict()
base.update(
{
"tool_name": self.tool_name,
"tool_description": self.tool_description,
"is_result": self.is_result,
"result_status": self.result_status.value if self.result_status else None,
"execution_time_ms": self.execution_time_ms,
"parameters": self.parameters,
"server_name": self.server_name,
}
)
return base
@classmethod
def from_dict(cls, data: dict[str, Any]) -> "ToolContext":
"""Create ToolContext from dictionary."""
result_status = data.get("result_status")
if isinstance(result_status, str):
result_status = ToolResultStatus(result_status)
return cls(
id=data.get("id", ""),
content=data["content"],
source=data.get("source", "tool"),
timestamp=datetime.fromisoformat(data["timestamp"])
if isinstance(data.get("timestamp"), str)
else data.get("timestamp", datetime.now(UTC)),
priority=data.get("priority", ContextPriority.NORMAL.value),
metadata=data.get("metadata", {}),
tool_name=data.get("tool_name", ""),
tool_description=data.get("tool_description", ""),
is_result=data.get("is_result", False),
result_status=result_status,
execution_time_ms=data.get("execution_time_ms"),
parameters=data.get("parameters", {}),
server_name=data.get("server_name"),
)
@classmethod
def from_tool_definition(
cls,
name: str,
description: str,
parameters: dict[str, Any] | None = None,
server_name: str | None = None,
) -> "ToolContext":
"""
Create a ToolContext from a tool definition.
Args:
name: Tool name
description: Tool description
parameters: Tool parameter schema
server_name: MCP server name
Returns:
ToolContext instance
"""
# Format content as tool documentation
content_parts = [f"Tool: {name}", "", description]
if parameters:
content_parts.append("")
content_parts.append("Parameters:")
for param_name, param_info in parameters.items():
param_type = param_info.get("type", "any")
param_desc = param_info.get("description", "")
required = param_info.get("required", False)
req_marker = " (required)" if required else ""
content_parts.append(f" - {param_name}: {param_type}{req_marker}")
if param_desc:
content_parts.append(f" {param_desc}")
return cls(
content="\n".join(content_parts),
source=f"tool:{server_name}:{name}" if server_name else f"tool:{name}",
tool_name=name,
tool_description=description,
is_result=False,
parameters=parameters or {},
server_name=server_name,
priority=ContextPriority.LOW.value,
)
@classmethod
def from_tool_result(
cls,
tool_name: str,
result: Any,
status: ToolResultStatus = ToolResultStatus.SUCCESS,
execution_time_ms: float | None = None,
parameters: dict[str, Any] | None = None,
server_name: str | None = None,
) -> "ToolContext":
"""
Create a ToolContext from a tool execution result.
Args:
tool_name: Name of the tool that was executed
result: Result content (will be converted to string)
status: Execution status
execution_time_ms: Execution time in milliseconds
parameters: Parameters that were passed to the tool
server_name: MCP server name
Returns:
ToolContext instance
"""
# Convert result to string content
if isinstance(result, str):
content = result
elif isinstance(result, dict):
import json
try:
content = json.dumps(result, indent=2)
except (TypeError, ValueError):
content = str(result)
else:
content = str(result)
return cls(
content=content,
source=f"tool_result:{server_name}:{tool_name}" if server_name else f"tool_result:{tool_name}",
tool_name=tool_name,
is_result=True,
result_status=status,
execution_time_ms=execution_time_ms,
parameters=parameters or {},
server_name=server_name,
priority=ContextPriority.HIGH.value, # Recent results are high priority
)
def is_successful(self) -> bool:
"""Check if this is a successful tool result."""
return self.is_result and self.result_status == ToolResultStatus.SUCCESS
def is_error(self) -> bool:
"""Check if this is an error result."""
return self.is_result and self.result_status == ToolResultStatus.ERROR
def format_for_prompt(self) -> str:
"""
Format tool context for inclusion in prompt.
Returns:
Formatted tool string
"""
if self.is_result:
status_str = self.result_status.value if self.result_status else "unknown"
header = f"Tool Result ({self.tool_name}, {status_str}):"
return f"{header}\n{self.content}"
else:
return self.content

View File

@@ -0,0 +1 @@
"""Tests for Context Management Engine."""

View File

@@ -0,0 +1,243 @@
"""Tests for context management configuration."""
import os
from unittest.mock import patch
import pytest
from app.services.context.config import (
ContextSettings,
get_context_settings,
get_default_settings,
reset_context_settings,
)
class TestContextSettings:
"""Tests for ContextSettings."""
def test_default_values(self) -> None:
"""Test default settings values."""
settings = ContextSettings()
# Budget defaults should sum to 1.0
total = (
settings.budget_system
+ settings.budget_task
+ settings.budget_knowledge
+ settings.budget_conversation
+ settings.budget_tools
+ settings.budget_response
+ settings.budget_buffer
)
assert abs(total - 1.0) < 0.001
# Scoring weights should sum to 1.0
weights_total = (
settings.scoring_relevance_weight
+ settings.scoring_recency_weight
+ settings.scoring_priority_weight
)
assert abs(weights_total - 1.0) < 0.001
def test_budget_allocation_values(self) -> None:
"""Test specific budget allocation values."""
settings = ContextSettings()
assert settings.budget_system == 0.05
assert settings.budget_task == 0.10
assert settings.budget_knowledge == 0.40
assert settings.budget_conversation == 0.20
assert settings.budget_tools == 0.05
assert settings.budget_response == 0.15
assert settings.budget_buffer == 0.05
def test_scoring_weights(self) -> None:
"""Test scoring weights."""
settings = ContextSettings()
assert settings.scoring_relevance_weight == 0.5
assert settings.scoring_recency_weight == 0.3
assert settings.scoring_priority_weight == 0.2
def test_cache_settings(self) -> None:
"""Test cache settings."""
settings = ContextSettings()
assert settings.cache_enabled is True
assert settings.cache_ttl_seconds == 3600
assert settings.cache_prefix == "ctx"
def test_performance_settings(self) -> None:
"""Test performance settings."""
settings = ContextSettings()
assert settings.max_assembly_time_ms == 100
assert settings.parallel_scoring is True
assert settings.max_parallel_scores == 10
def test_get_budget_allocation(self) -> None:
"""Test get_budget_allocation method."""
settings = ContextSettings()
allocation = settings.get_budget_allocation()
assert isinstance(allocation, dict)
assert "system" in allocation
assert "knowledge" in allocation
assert allocation["system"] == 0.05
assert allocation["knowledge"] == 0.40
def test_get_scoring_weights(self) -> None:
"""Test get_scoring_weights method."""
settings = ContextSettings()
weights = settings.get_scoring_weights()
assert isinstance(weights, dict)
assert "relevance" in weights
assert "recency" in weights
assert "priority" in weights
assert weights["relevance"] == 0.5
def test_to_dict(self) -> None:
"""Test to_dict method."""
settings = ContextSettings()
result = settings.to_dict()
assert "budget" in result
assert "scoring" in result
assert "compression" in result
assert "cache" in result
assert "performance" in result
assert "knowledge" in result
assert "conversation" in result
def test_budget_validation_fails_on_wrong_sum(self) -> None:
"""Test that budget validation fails when sum != 1.0."""
with pytest.raises(ValueError) as exc_info:
ContextSettings(
budget_system=0.5,
budget_task=0.5,
# Other budgets default to non-zero, so total > 1.0
)
assert "sum to 1.0" in str(exc_info.value)
def test_scoring_validation_fails_on_wrong_sum(self) -> None:
"""Test that scoring validation fails when sum != 1.0."""
with pytest.raises(ValueError) as exc_info:
ContextSettings(
scoring_relevance_weight=0.8,
scoring_recency_weight=0.8,
scoring_priority_weight=0.8,
)
assert "sum to 1.0" in str(exc_info.value)
def test_search_type_validation(self) -> None:
"""Test search type validation."""
# Valid types should work
ContextSettings(knowledge_search_type="semantic")
ContextSettings(knowledge_search_type="keyword")
ContextSettings(knowledge_search_type="hybrid")
# Invalid type should fail
with pytest.raises(ValueError):
ContextSettings(knowledge_search_type="invalid")
def test_custom_budget_allocation(self) -> None:
"""Test custom budget allocation that sums to 1.0."""
settings = ContextSettings(
budget_system=0.10,
budget_task=0.10,
budget_knowledge=0.30,
budget_conversation=0.25,
budget_tools=0.05,
budget_response=0.15,
budget_buffer=0.05,
)
total = (
settings.budget_system
+ settings.budget_task
+ settings.budget_knowledge
+ settings.budget_conversation
+ settings.budget_tools
+ settings.budget_response
+ settings.budget_buffer
)
assert abs(total - 1.0) < 0.001
class TestSettingsSingleton:
"""Tests for settings singleton pattern."""
def setup_method(self) -> None:
"""Reset settings before each test."""
reset_context_settings()
def teardown_method(self) -> None:
"""Clean up after each test."""
reset_context_settings()
def test_get_context_settings_returns_instance(self) -> None:
"""Test that get_context_settings returns a settings instance."""
settings = get_context_settings()
assert isinstance(settings, ContextSettings)
def test_get_context_settings_returns_same_instance(self) -> None:
"""Test that get_context_settings returns the same instance."""
settings1 = get_context_settings()
settings2 = get_context_settings()
assert settings1 is settings2
def test_reset_creates_new_instance(self) -> None:
"""Test that reset creates a new instance."""
settings1 = get_context_settings()
reset_context_settings()
settings2 = get_context_settings()
# Should be different instances
assert settings1 is not settings2
def test_get_default_settings_cached(self) -> None:
"""Test that get_default_settings is cached."""
settings1 = get_default_settings()
settings2 = get_default_settings()
assert settings1 is settings2
class TestEnvironmentOverrides:
"""Tests for environment variable overrides."""
def setup_method(self) -> None:
"""Reset settings before each test."""
reset_context_settings()
def teardown_method(self) -> None:
"""Clean up after each test."""
reset_context_settings()
# Clean up any env vars we set
for key in list(os.environ.keys()):
if key.startswith("CTX_"):
del os.environ[key]
def test_env_override_cache_enabled(self) -> None:
"""Test that CTX_CACHE_ENABLED env var works."""
with patch.dict(os.environ, {"CTX_CACHE_ENABLED": "false"}):
reset_context_settings()
settings = ContextSettings()
assert settings.cache_enabled is False
def test_env_override_cache_ttl(self) -> None:
"""Test that CTX_CACHE_TTL_SECONDS env var works."""
with patch.dict(os.environ, {"CTX_CACHE_TTL_SECONDS": "7200"}):
reset_context_settings()
settings = ContextSettings()
assert settings.cache_ttl_seconds == 7200
def test_env_override_max_assembly_time(self) -> None:
"""Test that CTX_MAX_ASSEMBLY_TIME_MS env var works."""
with patch.dict(os.environ, {"CTX_MAX_ASSEMBLY_TIME_MS": "200"}):
reset_context_settings()
settings = ContextSettings()
assert settings.max_assembly_time_ms == 200

View File

@@ -0,0 +1,252 @@
"""Tests for context management exceptions."""
import pytest
from app.services.context.exceptions import (
AssemblyTimeoutError,
BudgetExceededError,
CacheError,
CompressionError,
ContextError,
ContextNotFoundError,
FormattingError,
InvalidContextError,
ScoringError,
TokenCountError,
)
class TestContextError:
"""Tests for base ContextError."""
def test_basic_initialization(self) -> None:
"""Test basic error initialization."""
error = ContextError("Test error")
assert error.message == "Test error"
assert error.details == {}
assert str(error) == "Test error"
def test_with_details(self) -> None:
"""Test error with details."""
error = ContextError("Test error", {"key": "value", "count": 42})
assert error.details == {"key": "value", "count": 42}
def test_to_dict(self) -> None:
"""Test conversion to dictionary."""
error = ContextError("Test error", {"key": "value"})
result = error.to_dict()
assert result["error_type"] == "ContextError"
assert result["message"] == "Test error"
assert result["details"] == {"key": "value"}
def test_inheritance(self) -> None:
"""Test that ContextError inherits from Exception."""
error = ContextError("Test")
assert isinstance(error, Exception)
class TestBudgetExceededError:
"""Tests for BudgetExceededError."""
def test_default_message(self) -> None:
"""Test default error message."""
error = BudgetExceededError()
assert "exceeded" in error.message.lower()
def test_with_budget_info(self) -> None:
"""Test with budget information."""
error = BudgetExceededError(
allocated=1000,
requested=1500,
context_type="knowledge",
)
assert error.allocated == 1000
assert error.requested == 1500
assert error.context_type == "knowledge"
assert error.details["overage"] == 500
def test_to_dict_includes_budget_info(self) -> None:
"""Test that to_dict includes budget info."""
error = BudgetExceededError(
allocated=1000,
requested=1500,
)
result = error.to_dict()
assert result["details"]["allocated"] == 1000
assert result["details"]["requested"] == 1500
assert result["details"]["overage"] == 500
class TestTokenCountError:
"""Tests for TokenCountError."""
def test_basic_error(self) -> None:
"""Test basic token count error."""
error = TokenCountError()
assert "token" in error.message.lower()
def test_with_model_info(self) -> None:
"""Test with model information."""
error = TokenCountError(
message="Failed to count",
model="claude-3-sonnet",
text_length=5000,
)
assert error.model == "claude-3-sonnet"
assert error.text_length == 5000
assert error.details["model"] == "claude-3-sonnet"
class TestCompressionError:
"""Tests for CompressionError."""
def test_basic_error(self) -> None:
"""Test basic compression error."""
error = CompressionError()
assert "compress" in error.message.lower()
def test_with_token_info(self) -> None:
"""Test with token information."""
error = CompressionError(
original_tokens=2000,
target_tokens=1000,
achieved_tokens=1500,
)
assert error.original_tokens == 2000
assert error.target_tokens == 1000
assert error.achieved_tokens == 1500
class TestAssemblyTimeoutError:
"""Tests for AssemblyTimeoutError."""
def test_basic_error(self) -> None:
"""Test basic timeout error."""
error = AssemblyTimeoutError()
assert "timed out" in error.message.lower()
def test_with_timing_info(self) -> None:
"""Test with timing information."""
error = AssemblyTimeoutError(
timeout_ms=100,
elapsed_ms=150.5,
stage="scoring",
)
assert error.timeout_ms == 100
assert error.elapsed_ms == 150.5
assert error.stage == "scoring"
assert error.details["stage"] == "scoring"
class TestScoringError:
"""Tests for ScoringError."""
def test_basic_error(self) -> None:
"""Test basic scoring error."""
error = ScoringError()
assert "score" in error.message.lower()
def test_with_scorer_info(self) -> None:
"""Test with scorer information."""
error = ScoringError(
scorer_type="relevance",
context_id="ctx-123",
)
assert error.scorer_type == "relevance"
assert error.context_id == "ctx-123"
class TestFormattingError:
"""Tests for FormattingError."""
def test_basic_error(self) -> None:
"""Test basic formatting error."""
error = FormattingError()
assert "format" in error.message.lower()
def test_with_model_info(self) -> None:
"""Test with model information."""
error = FormattingError(
model="claude-3-opus",
adapter="ClaudeAdapter",
)
assert error.model == "claude-3-opus"
assert error.adapter == "ClaudeAdapter"
class TestCacheError:
"""Tests for CacheError."""
def test_basic_error(self) -> None:
"""Test basic cache error."""
error = CacheError()
assert "cache" in error.message.lower()
def test_with_operation_info(self) -> None:
"""Test with operation information."""
error = CacheError(
operation="get",
cache_key="ctx:abc123",
)
assert error.operation == "get"
assert error.cache_key == "ctx:abc123"
class TestContextNotFoundError:
"""Tests for ContextNotFoundError."""
def test_basic_error(self) -> None:
"""Test basic not found error."""
error = ContextNotFoundError()
assert "not found" in error.message.lower()
def test_with_source_info(self) -> None:
"""Test with source information."""
error = ContextNotFoundError(
source="knowledge-base",
query="authentication flow",
)
assert error.source == "knowledge-base"
assert error.query == "authentication flow"
class TestInvalidContextError:
"""Tests for InvalidContextError."""
def test_basic_error(self) -> None:
"""Test basic invalid error."""
error = InvalidContextError()
assert "invalid" in error.message.lower()
def test_with_field_info(self) -> None:
"""Test with field information."""
error = InvalidContextError(
field="content",
value="",
reason="Content cannot be empty",
)
assert error.field == "content"
assert error.value == ""
assert error.reason == "Content cannot be empty"
def test_value_type_only_in_details(self) -> None:
"""Test that only value type is included in details (not actual value)."""
error = InvalidContextError(
field="api_key",
value="secret-key-here",
)
# Actual value should not be in details
assert "secret-key-here" not in str(error.details)
assert error.details["value_type"] == "str"

View File

@@ -0,0 +1,579 @@
"""Tests for context types."""
import json
from datetime import UTC, datetime, timedelta
import pytest
from app.services.context.types import (
AssembledContext,
BaseContext,
ContextPriority,
ContextType,
ConversationContext,
KnowledgeContext,
MessageRole,
SystemContext,
TaskComplexity,
TaskContext,
TaskStatus,
ToolContext,
ToolResultStatus,
)
class TestContextType:
"""Tests for ContextType enum."""
def test_all_types_exist(self) -> None:
"""Test that all expected context types exist."""
assert ContextType.SYSTEM
assert ContextType.TASK
assert ContextType.KNOWLEDGE
assert ContextType.CONVERSATION
assert ContextType.TOOL
def test_from_string_valid(self) -> None:
"""Test from_string with valid values."""
assert ContextType.from_string("system") == ContextType.SYSTEM
assert ContextType.from_string("KNOWLEDGE") == ContextType.KNOWLEDGE
assert ContextType.from_string("Task") == ContextType.TASK
def test_from_string_invalid(self) -> None:
"""Test from_string with invalid value."""
with pytest.raises(ValueError) as exc_info:
ContextType.from_string("invalid")
assert "Invalid context type" in str(exc_info.value)
class TestContextPriority:
"""Tests for ContextPriority enum."""
def test_priority_ordering(self) -> None:
"""Test that priorities are ordered correctly."""
assert ContextPriority.LOWEST.value < ContextPriority.LOW.value
assert ContextPriority.LOW.value < ContextPriority.NORMAL.value
assert ContextPriority.NORMAL.value < ContextPriority.HIGH.value
assert ContextPriority.HIGH.value < ContextPriority.HIGHEST.value
assert ContextPriority.HIGHEST.value < ContextPriority.CRITICAL.value
def test_from_int(self) -> None:
"""Test from_int conversion."""
assert ContextPriority.from_int(0) == ContextPriority.LOWEST
assert ContextPriority.from_int(50) == ContextPriority.NORMAL
assert ContextPriority.from_int(100) == ContextPriority.HIGHEST
assert ContextPriority.from_int(200) == ContextPriority.CRITICAL
def test_from_int_intermediate(self) -> None:
"""Test from_int with intermediate values."""
# Should return closest lower priority
assert ContextPriority.from_int(30) == ContextPriority.LOW
assert ContextPriority.from_int(60) == ContextPriority.NORMAL
class TestSystemContext:
"""Tests for SystemContext."""
def test_creation(self) -> None:
"""Test basic creation."""
ctx = SystemContext(
content="You are a helpful assistant.",
source="system_prompt",
)
assert ctx.content == "You are a helpful assistant."
assert ctx.source == "system_prompt"
assert ctx.get_type() == ContextType.SYSTEM
def test_default_high_priority(self) -> None:
"""Test that system context defaults to high priority."""
ctx = SystemContext(content="Test", source="test")
assert ctx.priority == ContextPriority.HIGH.value
def test_create_persona(self) -> None:
"""Test create_persona factory method."""
ctx = SystemContext.create_persona(
name="Code Assistant",
description="A helpful coding assistant.",
capabilities=["Write code", "Debug"],
constraints=["Never expose secrets"],
)
assert "Code Assistant" in ctx.content
assert "helpful coding assistant" in ctx.content
assert "Write code" in ctx.content
assert "Never expose secrets" in ctx.content
assert ctx.priority == ContextPriority.HIGHEST.value
def test_create_instructions(self) -> None:
"""Test create_instructions factory method."""
ctx = SystemContext.create_instructions(
["Always be helpful", "Be concise"],
source="rules",
)
assert "Always be helpful" in ctx.content
assert "Be concise" in ctx.content
def test_to_dict(self) -> None:
"""Test serialization to dict."""
ctx = SystemContext(
content="Test",
source="test",
role="assistant",
instructions_type="general",
)
data = ctx.to_dict()
assert data["role"] == "assistant"
assert data["instructions_type"] == "general"
assert data["type"] == "system"
class TestKnowledgeContext:
"""Tests for KnowledgeContext."""
def test_creation(self) -> None:
"""Test basic creation."""
ctx = KnowledgeContext(
content="def authenticate(user): ...",
source="/src/auth.py",
collection="code",
file_type="python",
)
assert ctx.content == "def authenticate(user): ..."
assert ctx.source == "/src/auth.py"
assert ctx.collection == "code"
assert ctx.get_type() == ContextType.KNOWLEDGE
def test_from_search_result(self) -> None:
"""Test from_search_result factory method."""
result = {
"content": "Test content",
"source_path": "/test/file.py",
"collection": "code",
"file_type": "python",
"chunk_index": 2,
"score": 0.85,
"id": "chunk-123",
}
ctx = KnowledgeContext.from_search_result(result, "test query")
assert ctx.content == "Test content"
assert ctx.source == "/test/file.py"
assert ctx.relevance_score == 0.85
assert ctx.search_query == "test query"
def test_from_search_results(self) -> None:
"""Test from_search_results factory method."""
results = [
{"content": "Content 1", "source_path": "/a.py", "score": 0.9},
{"content": "Content 2", "source_path": "/b.py", "score": 0.8},
]
contexts = KnowledgeContext.from_search_results(results, "query")
assert len(contexts) == 2
assert contexts[0].relevance_score == 0.9
assert contexts[1].source == "/b.py"
def test_is_code(self) -> None:
"""Test is_code method."""
code_ctx = KnowledgeContext(
content="code", source="test", file_type="python"
)
doc_ctx = KnowledgeContext(
content="docs", source="test", file_type="markdown"
)
assert code_ctx.is_code() is True
assert doc_ctx.is_code() is False
def test_is_documentation(self) -> None:
"""Test is_documentation method."""
doc_ctx = KnowledgeContext(
content="docs", source="test", file_type="markdown"
)
code_ctx = KnowledgeContext(
content="code", source="test", file_type="python"
)
assert doc_ctx.is_documentation() is True
assert code_ctx.is_documentation() is False
class TestConversationContext:
"""Tests for ConversationContext."""
def test_creation(self) -> None:
"""Test basic creation."""
ctx = ConversationContext(
content="Hello, how can I help?",
source="conversation",
role=MessageRole.ASSISTANT,
turn_index=1,
)
assert ctx.content == "Hello, how can I help?"
assert ctx.role == MessageRole.ASSISTANT
assert ctx.get_type() == ContextType.CONVERSATION
def test_from_message(self) -> None:
"""Test from_message factory method."""
ctx = ConversationContext.from_message(
content="What is Python?",
role="user",
turn_index=0,
)
assert ctx.content == "What is Python?"
assert ctx.role == MessageRole.USER
assert ctx.turn_index == 0
def test_from_history(self) -> None:
"""Test from_history factory method."""
messages = [
{"role": "user", "content": "Hello"},
{"role": "assistant", "content": "Hi there!"},
{"role": "user", "content": "Help me"},
]
contexts = ConversationContext.from_history(messages)
assert len(contexts) == 3
assert contexts[0].role == MessageRole.USER
assert contexts[1].role == MessageRole.ASSISTANT
assert contexts[2].turn_index == 2
def test_is_user_message(self) -> None:
"""Test is_user_message method."""
user_ctx = ConversationContext(
content="test", source="test", role=MessageRole.USER
)
assistant_ctx = ConversationContext(
content="test", source="test", role=MessageRole.ASSISTANT
)
assert user_ctx.is_user_message() is True
assert assistant_ctx.is_user_message() is False
def test_format_for_prompt(self) -> None:
"""Test format_for_prompt method."""
ctx = ConversationContext.from_message(
content="What is 2+2?",
role="user",
)
formatted = ctx.format_for_prompt()
assert "User:" in formatted
assert "What is 2+2?" in formatted
class TestTaskContext:
"""Tests for TaskContext."""
def test_creation(self) -> None:
"""Test basic creation."""
ctx = TaskContext(
content="Implement login feature",
source="task",
title="Login Feature",
)
assert ctx.content == "Implement login feature"
assert ctx.title == "Login Feature"
assert ctx.get_type() == ContextType.TASK
def test_default_high_priority(self) -> None:
"""Test that task context defaults to high priority."""
ctx = TaskContext(content="Test", source="test")
assert ctx.priority == ContextPriority.HIGH.value
def test_create_factory(self) -> None:
"""Test create factory method."""
ctx = TaskContext.create(
title="Add Auth",
description="Implement authentication",
acceptance_criteria=["Tests pass", "Code reviewed"],
constraints=["Use JWT"],
issue_id="123",
)
assert ctx.title == "Add Auth"
assert ctx.content == "Implement authentication"
assert len(ctx.acceptance_criteria) == 2
assert "Use JWT" in ctx.constraints
assert ctx.status == TaskStatus.IN_PROGRESS
def test_format_for_prompt(self) -> None:
"""Test format_for_prompt method."""
ctx = TaskContext.create(
title="Test Task",
description="Do something",
acceptance_criteria=["Works correctly"],
)
formatted = ctx.format_for_prompt()
assert "Task: Test Task" in formatted
assert "Do something" in formatted
assert "Works correctly" in formatted
def test_status_checks(self) -> None:
"""Test status check methods."""
pending = TaskContext(
content="test", source="test", status=TaskStatus.PENDING
)
completed = TaskContext(
content="test", source="test", status=TaskStatus.COMPLETED
)
blocked = TaskContext(
content="test", source="test", status=TaskStatus.BLOCKED
)
assert pending.is_active() is True
assert completed.is_complete() is True
assert blocked.is_blocked() is True
class TestToolContext:
"""Tests for ToolContext."""
def test_creation(self) -> None:
"""Test basic creation."""
ctx = ToolContext(
content="Tool result here",
source="tool:search",
tool_name="search",
)
assert ctx.tool_name == "search"
assert ctx.get_type() == ContextType.TOOL
def test_from_tool_definition(self) -> None:
"""Test from_tool_definition factory method."""
ctx = ToolContext.from_tool_definition(
name="search_knowledge",
description="Search the knowledge base",
parameters={
"query": {"type": "string", "required": True},
"limit": {"type": "integer", "required": False},
},
server_name="knowledge-base",
)
assert ctx.tool_name == "search_knowledge"
assert "Search the knowledge base" in ctx.content
assert ctx.is_result is False
assert ctx.server_name == "knowledge-base"
def test_from_tool_result(self) -> None:
"""Test from_tool_result factory method."""
ctx = ToolContext.from_tool_result(
tool_name="search",
result={"found": 5, "items": ["a", "b"]},
status=ToolResultStatus.SUCCESS,
execution_time_ms=150.5,
)
assert ctx.tool_name == "search"
assert ctx.is_result is True
assert ctx.result_status == ToolResultStatus.SUCCESS
assert "found" in ctx.content
def test_is_successful(self) -> None:
"""Test is_successful method."""
success = ToolContext.from_tool_result(
"test", "ok", ToolResultStatus.SUCCESS
)
error = ToolContext.from_tool_result(
"test", "error", ToolResultStatus.ERROR
)
assert success.is_successful() is True
assert error.is_successful() is False
def test_format_for_prompt(self) -> None:
"""Test format_for_prompt method."""
ctx = ToolContext.from_tool_result(
"search",
"Found 3 results",
ToolResultStatus.SUCCESS,
)
formatted = ctx.format_for_prompt()
assert "Tool Result" in formatted
assert "search" in formatted
assert "success" in formatted
class TestAssembledContext:
"""Tests for AssembledContext."""
def test_creation(self) -> None:
"""Test basic creation."""
ctx = AssembledContext(
content="Assembled content here",
token_count=500,
contexts_included=5,
)
assert ctx.content == "Assembled content here"
assert ctx.token_count == 500
assert ctx.contexts_included == 5
def test_budget_utilization(self) -> None:
"""Test budget_utilization property."""
ctx = AssembledContext(
content="test",
token_count=800,
contexts_included=5,
budget_total=1000,
budget_used=800,
)
assert ctx.budget_utilization == 0.8
def test_budget_utilization_zero_budget(self) -> None:
"""Test budget_utilization with zero budget."""
ctx = AssembledContext(
content="test",
token_count=0,
contexts_included=0,
budget_total=0,
budget_used=0,
)
assert ctx.budget_utilization == 0.0
def test_to_dict(self) -> None:
"""Test to_dict method."""
ctx = AssembledContext(
content="test",
token_count=100,
contexts_included=2,
assembly_time_ms=50.123,
)
data = ctx.to_dict()
assert data["content"] == "test"
assert data["token_count"] == 100
assert data["assembly_time_ms"] == 50.12 # Rounded
def test_to_json_and_from_json(self) -> None:
"""Test JSON serialization round-trip."""
original = AssembledContext(
content="Test content",
token_count=100,
contexts_included=3,
contexts_excluded=2,
assembly_time_ms=45.5,
budget_total=1000,
budget_used=100,
by_type={"system": 20, "knowledge": 80},
cache_hit=True,
cache_key="abc123",
)
json_str = original.to_json()
restored = AssembledContext.from_json(json_str)
assert restored.content == original.content
assert restored.token_count == original.token_count
assert restored.contexts_included == original.contexts_included
assert restored.cache_hit == original.cache_hit
assert restored.cache_key == original.cache_key
class TestBaseContextMethods:
"""Tests for BaseContext methods."""
def test_get_age_seconds(self) -> None:
"""Test get_age_seconds method."""
old_time = datetime.now(UTC) - timedelta(hours=2)
ctx = SystemContext(
content="test", source="test", timestamp=old_time
)
age = ctx.get_age_seconds()
# Should be approximately 2 hours in seconds
assert 7100 < age < 7300
def test_get_age_hours(self) -> None:
"""Test get_age_hours method."""
old_time = datetime.now(UTC) - timedelta(hours=5)
ctx = SystemContext(
content="test", source="test", timestamp=old_time
)
age = ctx.get_age_hours()
assert 4.9 < age < 5.1
def test_is_stale(self) -> None:
"""Test is_stale method."""
old_time = datetime.now(UTC) - timedelta(days=10)
new_time = datetime.now(UTC) - timedelta(hours=1)
old_ctx = SystemContext(
content="test", source="test", timestamp=old_time
)
new_ctx = SystemContext(
content="test", source="test", timestamp=new_time
)
# Default max_age is 168 hours (7 days)
assert old_ctx.is_stale() is True
assert new_ctx.is_stale() is False
def test_token_count_property(self) -> None:
"""Test token_count property."""
ctx = SystemContext(content="test", source="test")
# Initially None
assert ctx.token_count is None
# Can be set
ctx.token_count = 100
assert ctx.token_count == 100
def test_score_property_clamping(self) -> None:
"""Test that score is clamped to 0.0-1.0."""
ctx = SystemContext(content="test", source="test")
ctx.score = 1.5
assert ctx.score == 1.0
ctx.score = -0.5
assert ctx.score == 0.0
ctx.score = 0.75
assert ctx.score == 0.75
def test_hash_and_equality(self) -> None:
"""Test hash and equality based on ID."""
ctx1 = SystemContext(content="test", source="test")
ctx2 = SystemContext(content="test", source="test")
ctx3 = SystemContext(content="test", source="test")
ctx3.id = ctx1.id # Same ID as ctx1
# Different IDs = not equal
assert ctx1 != ctx2
# Same ID = equal
assert ctx1 == ctx3
# Can be used in sets
context_set = {ctx1, ctx2, ctx3}
assert len(context_set) == 2 # ctx1 and ctx3 are same
def test_truncate(self) -> None:
"""Test truncate method."""
long_content = "word " * 1000 # Long content
ctx = SystemContext(content=long_content, source="test")
ctx.token_count = 1000
truncated = ctx.truncate(100)
assert len(truncated) < len(long_content)
assert "[truncated]" in truncated