forked from cardosofelipe/fast-next-template
Compare commits
14 Commits
2ab69f8561
...
60ebeaa582
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
60ebeaa582 | ||
|
|
758052dcff | ||
|
|
1628eacf2b | ||
|
|
2bea057fb1 | ||
|
|
9e54f16e56 | ||
|
|
96e6400bd8 | ||
|
|
6c7b72f130 | ||
|
|
027ebfc332 | ||
|
|
c2466ab401 | ||
|
|
7828d35e06 | ||
|
|
6b07e62f00 | ||
|
|
0d2005ddcb | ||
|
|
dfa75e682e | ||
|
|
22ecb5e989 |
178
backend/app/services/context/__init__.py
Normal file
178
backend/app/services/context/__init__.py
Normal file
@@ -0,0 +1,178 @@
|
|||||||
|
"""
|
||||||
|
Context Management Engine
|
||||||
|
|
||||||
|
Sophisticated context assembly and optimization for LLM requests.
|
||||||
|
Provides intelligent context selection, token budget management,
|
||||||
|
and model-specific formatting.
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
from app.services.context import (
|
||||||
|
ContextSettings,
|
||||||
|
get_context_settings,
|
||||||
|
SystemContext,
|
||||||
|
KnowledgeContext,
|
||||||
|
ConversationContext,
|
||||||
|
TaskContext,
|
||||||
|
ToolContext,
|
||||||
|
TokenBudget,
|
||||||
|
BudgetAllocator,
|
||||||
|
TokenCalculator,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Get settings
|
||||||
|
settings = get_context_settings()
|
||||||
|
|
||||||
|
# Create budget for a model
|
||||||
|
allocator = BudgetAllocator(settings)
|
||||||
|
budget = allocator.create_budget_for_model("claude-3-sonnet")
|
||||||
|
|
||||||
|
# Create context instances
|
||||||
|
system_ctx = SystemContext.create_persona(
|
||||||
|
name="Code Assistant",
|
||||||
|
description="You are a helpful code assistant.",
|
||||||
|
capabilities=["Write code", "Debug issues"],
|
||||||
|
)
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Budget Management
|
||||||
|
# Adapters
|
||||||
|
from .adapters import (
|
||||||
|
ClaudeAdapter,
|
||||||
|
DefaultAdapter,
|
||||||
|
ModelAdapter,
|
||||||
|
OpenAIAdapter,
|
||||||
|
get_adapter,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Assembly
|
||||||
|
from .assembly import (
|
||||||
|
ContextPipeline,
|
||||||
|
PipelineMetrics,
|
||||||
|
)
|
||||||
|
from .budget import (
|
||||||
|
BudgetAllocator,
|
||||||
|
TokenBudget,
|
||||||
|
TokenCalculator,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Cache
|
||||||
|
from .cache import ContextCache
|
||||||
|
|
||||||
|
# Compression
|
||||||
|
from .compression import (
|
||||||
|
ContextCompressor,
|
||||||
|
TruncationResult,
|
||||||
|
TruncationStrategy,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Configuration
|
||||||
|
from .config import (
|
||||||
|
ContextSettings,
|
||||||
|
get_context_settings,
|
||||||
|
get_default_settings,
|
||||||
|
reset_context_settings,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Engine
|
||||||
|
from .engine import ContextEngine, create_context_engine
|
||||||
|
|
||||||
|
# Exceptions
|
||||||
|
from .exceptions import (
|
||||||
|
AssemblyTimeoutError,
|
||||||
|
BudgetExceededError,
|
||||||
|
CacheError,
|
||||||
|
CompressionError,
|
||||||
|
ContextError,
|
||||||
|
ContextNotFoundError,
|
||||||
|
FormattingError,
|
||||||
|
InvalidContextError,
|
||||||
|
ScoringError,
|
||||||
|
TokenCountError,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Prioritization
|
||||||
|
from .prioritization import (
|
||||||
|
ContextRanker,
|
||||||
|
RankingResult,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Scoring
|
||||||
|
from .scoring import (
|
||||||
|
BaseScorer,
|
||||||
|
CompositeScorer,
|
||||||
|
PriorityScorer,
|
||||||
|
RecencyScorer,
|
||||||
|
RelevanceScorer,
|
||||||
|
ScoredContext,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Types
|
||||||
|
from .types import (
|
||||||
|
AssembledContext,
|
||||||
|
BaseContext,
|
||||||
|
ContextPriority,
|
||||||
|
ContextType,
|
||||||
|
ConversationContext,
|
||||||
|
KnowledgeContext,
|
||||||
|
MessageRole,
|
||||||
|
SystemContext,
|
||||||
|
TaskComplexity,
|
||||||
|
TaskContext,
|
||||||
|
TaskStatus,
|
||||||
|
ToolContext,
|
||||||
|
ToolResultStatus,
|
||||||
|
)
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"AssembledContext",
|
||||||
|
"AssemblyTimeoutError",
|
||||||
|
"BaseContext",
|
||||||
|
"BaseScorer",
|
||||||
|
"BudgetAllocator",
|
||||||
|
"BudgetExceededError",
|
||||||
|
"CacheError",
|
||||||
|
"ClaudeAdapter",
|
||||||
|
"CompositeScorer",
|
||||||
|
"CompressionError",
|
||||||
|
"ContextCache",
|
||||||
|
"ContextCompressor",
|
||||||
|
"ContextEngine",
|
||||||
|
"ContextError",
|
||||||
|
"ContextNotFoundError",
|
||||||
|
"ContextPipeline",
|
||||||
|
"ContextPriority",
|
||||||
|
"ContextRanker",
|
||||||
|
"ContextSettings",
|
||||||
|
"ContextType",
|
||||||
|
"ConversationContext",
|
||||||
|
"DefaultAdapter",
|
||||||
|
"FormattingError",
|
||||||
|
"InvalidContextError",
|
||||||
|
"KnowledgeContext",
|
||||||
|
"MessageRole",
|
||||||
|
"ModelAdapter",
|
||||||
|
"OpenAIAdapter",
|
||||||
|
"PipelineMetrics",
|
||||||
|
"PriorityScorer",
|
||||||
|
"RankingResult",
|
||||||
|
"RecencyScorer",
|
||||||
|
"RelevanceScorer",
|
||||||
|
"ScoredContext",
|
||||||
|
"ScoringError",
|
||||||
|
"SystemContext",
|
||||||
|
"TaskComplexity",
|
||||||
|
"TaskContext",
|
||||||
|
"TaskStatus",
|
||||||
|
"TokenBudget",
|
||||||
|
"TokenCalculator",
|
||||||
|
"TokenCountError",
|
||||||
|
"ToolContext",
|
||||||
|
"ToolResultStatus",
|
||||||
|
"TruncationResult",
|
||||||
|
"TruncationStrategy",
|
||||||
|
"create_context_engine",
|
||||||
|
"get_adapter",
|
||||||
|
"get_context_settings",
|
||||||
|
"get_default_settings",
|
||||||
|
"reset_context_settings",
|
||||||
|
]
|
||||||
35
backend/app/services/context/adapters/__init__.py
Normal file
35
backend/app/services/context/adapters/__init__.py
Normal file
@@ -0,0 +1,35 @@
|
|||||||
|
"""
|
||||||
|
Model Adapters Module.
|
||||||
|
|
||||||
|
Provides model-specific context formatting adapters.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from .base import DefaultAdapter, ModelAdapter
|
||||||
|
from .claude import ClaudeAdapter
|
||||||
|
from .openai import OpenAIAdapter
|
||||||
|
|
||||||
|
|
||||||
|
def get_adapter(model: str) -> ModelAdapter:
|
||||||
|
"""
|
||||||
|
Get the appropriate adapter for a model.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model: Model name
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Adapter instance for the model
|
||||||
|
"""
|
||||||
|
if ClaudeAdapter.matches_model(model):
|
||||||
|
return ClaudeAdapter()
|
||||||
|
elif OpenAIAdapter.matches_model(model):
|
||||||
|
return OpenAIAdapter()
|
||||||
|
return DefaultAdapter()
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"ClaudeAdapter",
|
||||||
|
"DefaultAdapter",
|
||||||
|
"ModelAdapter",
|
||||||
|
"OpenAIAdapter",
|
||||||
|
"get_adapter",
|
||||||
|
]
|
||||||
178
backend/app/services/context/adapters/base.py
Normal file
178
backend/app/services/context/adapters/base.py
Normal file
@@ -0,0 +1,178 @@
|
|||||||
|
"""
|
||||||
|
Base Model Adapter.
|
||||||
|
|
||||||
|
Abstract base class for model-specific context formatting.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from typing import Any, ClassVar
|
||||||
|
|
||||||
|
from ..types import BaseContext, ContextType
|
||||||
|
|
||||||
|
|
||||||
|
class ModelAdapter(ABC):
|
||||||
|
"""
|
||||||
|
Abstract base adapter for model-specific context formatting.
|
||||||
|
|
||||||
|
Each adapter knows how to format contexts for optimal
|
||||||
|
understanding by a specific LLM family (Claude, OpenAI, etc.).
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Model name patterns this adapter handles
|
||||||
|
MODEL_PATTERNS: ClassVar[list[str]] = []
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def matches_model(cls, model: str) -> bool:
|
||||||
|
"""
|
||||||
|
Check if this adapter handles the given model.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model: Model name to check
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if this adapter handles the model
|
||||||
|
"""
|
||||||
|
model_lower = model.lower()
|
||||||
|
return any(pattern in model_lower for pattern in cls.MODEL_PATTERNS)
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def format(
|
||||||
|
self,
|
||||||
|
contexts: list[BaseContext],
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> str:
|
||||||
|
"""
|
||||||
|
Format contexts for the target model.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
contexts: List of contexts to format
|
||||||
|
**kwargs: Additional formatting options
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Formatted context string
|
||||||
|
"""
|
||||||
|
...
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def format_type(
|
||||||
|
self,
|
||||||
|
contexts: list[BaseContext],
|
||||||
|
context_type: ContextType,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> str:
|
||||||
|
"""
|
||||||
|
Format contexts of a specific type.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
contexts: List of contexts of the same type
|
||||||
|
context_type: The type of contexts
|
||||||
|
**kwargs: Additional formatting options
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Formatted string for this context type
|
||||||
|
"""
|
||||||
|
...
|
||||||
|
|
||||||
|
def get_type_order(self) -> list[ContextType]:
|
||||||
|
"""
|
||||||
|
Get the preferred order of context types.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of context types in preferred order
|
||||||
|
"""
|
||||||
|
return [
|
||||||
|
ContextType.SYSTEM,
|
||||||
|
ContextType.TASK,
|
||||||
|
ContextType.KNOWLEDGE,
|
||||||
|
ContextType.CONVERSATION,
|
||||||
|
ContextType.TOOL,
|
||||||
|
]
|
||||||
|
|
||||||
|
def group_by_type(
|
||||||
|
self, contexts: list[BaseContext]
|
||||||
|
) -> dict[ContextType, list[BaseContext]]:
|
||||||
|
"""
|
||||||
|
Group contexts by their type.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
contexts: List of contexts to group
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary mapping context type to list of contexts
|
||||||
|
"""
|
||||||
|
by_type: dict[ContextType, list[BaseContext]] = {}
|
||||||
|
for context in contexts:
|
||||||
|
ct = context.get_type()
|
||||||
|
if ct not in by_type:
|
||||||
|
by_type[ct] = []
|
||||||
|
by_type[ct].append(context)
|
||||||
|
return by_type
|
||||||
|
|
||||||
|
def get_separator(self) -> str:
|
||||||
|
"""
|
||||||
|
Get the separator between context sections.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Separator string
|
||||||
|
"""
|
||||||
|
return "\n\n"
|
||||||
|
|
||||||
|
|
||||||
|
class DefaultAdapter(ModelAdapter):
|
||||||
|
"""
|
||||||
|
Default adapter for unknown models.
|
||||||
|
|
||||||
|
Uses simple plain-text formatting with minimal structure.
|
||||||
|
"""
|
||||||
|
|
||||||
|
MODEL_PATTERNS: ClassVar[list[str]] = [] # Fallback adapter
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def matches_model(cls, model: str) -> bool:
|
||||||
|
"""Always returns True as fallback."""
|
||||||
|
return True
|
||||||
|
|
||||||
|
def format(
|
||||||
|
self,
|
||||||
|
contexts: list[BaseContext],
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> str:
|
||||||
|
"""Format contexts as plain text."""
|
||||||
|
if not contexts:
|
||||||
|
return ""
|
||||||
|
|
||||||
|
by_type = self.group_by_type(contexts)
|
||||||
|
parts: list[str] = []
|
||||||
|
|
||||||
|
for ct in self.get_type_order():
|
||||||
|
if ct in by_type:
|
||||||
|
formatted = self.format_type(by_type[ct], ct, **kwargs)
|
||||||
|
if formatted:
|
||||||
|
parts.append(formatted)
|
||||||
|
|
||||||
|
return self.get_separator().join(parts)
|
||||||
|
|
||||||
|
def format_type(
|
||||||
|
self,
|
||||||
|
contexts: list[BaseContext],
|
||||||
|
context_type: ContextType,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> str:
|
||||||
|
"""Format contexts of a type as plain text."""
|
||||||
|
if not contexts:
|
||||||
|
return ""
|
||||||
|
|
||||||
|
content = "\n\n".join(c.content for c in contexts)
|
||||||
|
|
||||||
|
if context_type == ContextType.SYSTEM:
|
||||||
|
return content
|
||||||
|
elif context_type == ContextType.TASK:
|
||||||
|
return f"Task:\n{content}"
|
||||||
|
elif context_type == ContextType.KNOWLEDGE:
|
||||||
|
return f"Reference Information:\n{content}"
|
||||||
|
elif context_type == ContextType.CONVERSATION:
|
||||||
|
return f"Previous Conversation:\n{content}"
|
||||||
|
elif context_type == ContextType.TOOL:
|
||||||
|
return f"Tool Results:\n{content}"
|
||||||
|
|
||||||
|
return content
|
||||||
212
backend/app/services/context/adapters/claude.py
Normal file
212
backend/app/services/context/adapters/claude.py
Normal file
@@ -0,0 +1,212 @@
|
|||||||
|
"""
|
||||||
|
Claude Model Adapter.
|
||||||
|
|
||||||
|
Provides Claude-specific context formatting using XML tags
|
||||||
|
which Claude models understand natively.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import Any, ClassVar
|
||||||
|
|
||||||
|
from ..types import BaseContext, ContextType
|
||||||
|
from .base import ModelAdapter
|
||||||
|
|
||||||
|
|
||||||
|
class ClaudeAdapter(ModelAdapter):
|
||||||
|
"""
|
||||||
|
Claude-specific context formatting adapter.
|
||||||
|
|
||||||
|
Claude models have native understanding of XML structure,
|
||||||
|
so we use XML tags for clear delineation of context types.
|
||||||
|
|
||||||
|
Features:
|
||||||
|
- XML tags for each context type
|
||||||
|
- Document structure for knowledge contexts
|
||||||
|
- Role-based message formatting for conversations
|
||||||
|
- Tool result wrapping with tool names
|
||||||
|
"""
|
||||||
|
|
||||||
|
MODEL_PATTERNS: ClassVar[list[str]] = ["claude", "anthropic"]
|
||||||
|
|
||||||
|
def format(
|
||||||
|
self,
|
||||||
|
contexts: list[BaseContext],
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> str:
|
||||||
|
"""
|
||||||
|
Format contexts for Claude models.
|
||||||
|
|
||||||
|
Uses XML tags for structured content that Claude
|
||||||
|
understands natively.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
contexts: List of contexts to format
|
||||||
|
**kwargs: Additional formatting options
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
XML-structured context string
|
||||||
|
"""
|
||||||
|
if not contexts:
|
||||||
|
return ""
|
||||||
|
|
||||||
|
by_type = self.group_by_type(contexts)
|
||||||
|
parts: list[str] = []
|
||||||
|
|
||||||
|
for ct in self.get_type_order():
|
||||||
|
if ct in by_type:
|
||||||
|
formatted = self.format_type(by_type[ct], ct, **kwargs)
|
||||||
|
if formatted:
|
||||||
|
parts.append(formatted)
|
||||||
|
|
||||||
|
return self.get_separator().join(parts)
|
||||||
|
|
||||||
|
def format_type(
|
||||||
|
self,
|
||||||
|
contexts: list[BaseContext],
|
||||||
|
context_type: ContextType,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> str:
|
||||||
|
"""
|
||||||
|
Format contexts of a specific type for Claude.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
contexts: List of contexts of the same type
|
||||||
|
context_type: The type of contexts
|
||||||
|
**kwargs: Additional formatting options
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
XML-formatted string for this context type
|
||||||
|
"""
|
||||||
|
if not contexts:
|
||||||
|
return ""
|
||||||
|
|
||||||
|
if context_type == ContextType.SYSTEM:
|
||||||
|
return self._format_system(contexts)
|
||||||
|
elif context_type == ContextType.TASK:
|
||||||
|
return self._format_task(contexts)
|
||||||
|
elif context_type == ContextType.KNOWLEDGE:
|
||||||
|
return self._format_knowledge(contexts)
|
||||||
|
elif context_type == ContextType.CONVERSATION:
|
||||||
|
return self._format_conversation(contexts)
|
||||||
|
elif context_type == ContextType.TOOL:
|
||||||
|
return self._format_tool(contexts)
|
||||||
|
|
||||||
|
# Fallback for any unhandled context types - still escape content
|
||||||
|
# to prevent XML injection if new types are added without updating adapter
|
||||||
|
return "\n".join(self._escape_xml_content(c.content) for c in contexts)
|
||||||
|
|
||||||
|
def _format_system(self, contexts: list[BaseContext]) -> str:
|
||||||
|
"""Format system contexts."""
|
||||||
|
# System prompts are typically admin-controlled, but escape for safety
|
||||||
|
content = "\n\n".join(self._escape_xml_content(c.content) for c in contexts)
|
||||||
|
return f"<system_instructions>\n{content}\n</system_instructions>"
|
||||||
|
|
||||||
|
def _format_task(self, contexts: list[BaseContext]) -> str:
|
||||||
|
"""Format task contexts."""
|
||||||
|
content = "\n\n".join(self._escape_xml_content(c.content) for c in contexts)
|
||||||
|
return f"<current_task>\n{content}\n</current_task>"
|
||||||
|
|
||||||
|
def _format_knowledge(self, contexts: list[BaseContext]) -> str:
|
||||||
|
"""
|
||||||
|
Format knowledge contexts as structured documents.
|
||||||
|
|
||||||
|
Each knowledge context becomes a document with source attribution.
|
||||||
|
All content is XML-escaped to prevent injection attacks.
|
||||||
|
"""
|
||||||
|
parts = ["<reference_documents>"]
|
||||||
|
|
||||||
|
for ctx in contexts:
|
||||||
|
source = self._escape_xml(ctx.source)
|
||||||
|
# Escape content to prevent XML injection
|
||||||
|
content = self._escape_xml_content(ctx.content)
|
||||||
|
score = ctx.metadata.get("score", ctx.metadata.get("relevance_score", ""))
|
||||||
|
|
||||||
|
if score:
|
||||||
|
# Escape score to prevent XML injection via metadata
|
||||||
|
escaped_score = self._escape_xml(str(score))
|
||||||
|
parts.append(
|
||||||
|
f'<document source="{source}" relevance="{escaped_score}">'
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
parts.append(f'<document source="{source}">')
|
||||||
|
|
||||||
|
parts.append(content)
|
||||||
|
parts.append("</document>")
|
||||||
|
|
||||||
|
parts.append("</reference_documents>")
|
||||||
|
return "\n".join(parts)
|
||||||
|
|
||||||
|
def _format_conversation(self, contexts: list[BaseContext]) -> str:
|
||||||
|
"""
|
||||||
|
Format conversation contexts as message history.
|
||||||
|
|
||||||
|
Uses role-based message tags for clear turn delineation.
|
||||||
|
All content is XML-escaped to prevent prompt injection.
|
||||||
|
"""
|
||||||
|
parts = ["<conversation_history>"]
|
||||||
|
|
||||||
|
for ctx in contexts:
|
||||||
|
role = self._escape_xml(ctx.metadata.get("role", "user"))
|
||||||
|
# Escape content to prevent prompt injection via fake XML tags
|
||||||
|
content = self._escape_xml_content(ctx.content)
|
||||||
|
parts.append(f'<message role="{role}">')
|
||||||
|
parts.append(content)
|
||||||
|
parts.append("</message>")
|
||||||
|
|
||||||
|
parts.append("</conversation_history>")
|
||||||
|
return "\n".join(parts)
|
||||||
|
|
||||||
|
def _format_tool(self, contexts: list[BaseContext]) -> str:
|
||||||
|
"""
|
||||||
|
Format tool contexts as tool results.
|
||||||
|
|
||||||
|
Each tool result is wrapped with the tool name.
|
||||||
|
All content is XML-escaped to prevent injection.
|
||||||
|
"""
|
||||||
|
parts = ["<tool_results>"]
|
||||||
|
|
||||||
|
for ctx in contexts:
|
||||||
|
tool_name = self._escape_xml(ctx.metadata.get("tool_name", "unknown"))
|
||||||
|
status = ctx.metadata.get("status", "")
|
||||||
|
|
||||||
|
if status:
|
||||||
|
parts.append(
|
||||||
|
f'<tool_result name="{tool_name}" status="{self._escape_xml(status)}">'
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
parts.append(f'<tool_result name="{tool_name}">')
|
||||||
|
|
||||||
|
# Escape content to prevent injection
|
||||||
|
parts.append(self._escape_xml_content(ctx.content))
|
||||||
|
parts.append("</tool_result>")
|
||||||
|
|
||||||
|
parts.append("</tool_results>")
|
||||||
|
return "\n".join(parts)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _escape_xml(text: str) -> str:
|
||||||
|
"""Escape XML special characters in attribute values."""
|
||||||
|
return (
|
||||||
|
text.replace("&", "&")
|
||||||
|
.replace("<", "<")
|
||||||
|
.replace(">", ">")
|
||||||
|
.replace('"', """)
|
||||||
|
.replace("'", "'")
|
||||||
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _escape_xml_content(text: str) -> str:
|
||||||
|
"""
|
||||||
|
Escape XML special characters in element content.
|
||||||
|
|
||||||
|
This prevents XML injection attacks where malicious content
|
||||||
|
could break out of XML tags or inject fake tags for prompt injection.
|
||||||
|
|
||||||
|
Only escapes &, <, > since quotes don't need escaping in content.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
text: Content text to escape
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
XML-safe content string
|
||||||
|
"""
|
||||||
|
return text.replace("&", "&").replace("<", "<").replace(">", ">")
|
||||||
160
backend/app/services/context/adapters/openai.py
Normal file
160
backend/app/services/context/adapters/openai.py
Normal file
@@ -0,0 +1,160 @@
|
|||||||
|
"""
|
||||||
|
OpenAI Model Adapter.
|
||||||
|
|
||||||
|
Provides OpenAI-specific context formatting using markdown
|
||||||
|
which GPT models understand well.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import Any, ClassVar
|
||||||
|
|
||||||
|
from ..types import BaseContext, ContextType
|
||||||
|
from .base import ModelAdapter
|
||||||
|
|
||||||
|
|
||||||
|
class OpenAIAdapter(ModelAdapter):
|
||||||
|
"""
|
||||||
|
OpenAI-specific context formatting adapter.
|
||||||
|
|
||||||
|
GPT models work well with markdown formatting,
|
||||||
|
so we use headers and structured markdown for clarity.
|
||||||
|
|
||||||
|
Features:
|
||||||
|
- Markdown headers for each context type
|
||||||
|
- Bulleted lists for document sources
|
||||||
|
- Bold role labels for conversations
|
||||||
|
- Code blocks for tool outputs
|
||||||
|
"""
|
||||||
|
|
||||||
|
MODEL_PATTERNS: ClassVar[list[str]] = ["gpt", "openai", "o1", "o3"]
|
||||||
|
|
||||||
|
def format(
|
||||||
|
self,
|
||||||
|
contexts: list[BaseContext],
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> str:
|
||||||
|
"""
|
||||||
|
Format contexts for OpenAI models.
|
||||||
|
|
||||||
|
Uses markdown formatting for structured content.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
contexts: List of contexts to format
|
||||||
|
**kwargs: Additional formatting options
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Markdown-structured context string
|
||||||
|
"""
|
||||||
|
if not contexts:
|
||||||
|
return ""
|
||||||
|
|
||||||
|
by_type = self.group_by_type(contexts)
|
||||||
|
parts: list[str] = []
|
||||||
|
|
||||||
|
for ct in self.get_type_order():
|
||||||
|
if ct in by_type:
|
||||||
|
formatted = self.format_type(by_type[ct], ct, **kwargs)
|
||||||
|
if formatted:
|
||||||
|
parts.append(formatted)
|
||||||
|
|
||||||
|
return self.get_separator().join(parts)
|
||||||
|
|
||||||
|
def format_type(
|
||||||
|
self,
|
||||||
|
contexts: list[BaseContext],
|
||||||
|
context_type: ContextType,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> str:
|
||||||
|
"""
|
||||||
|
Format contexts of a specific type for OpenAI.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
contexts: List of contexts of the same type
|
||||||
|
context_type: The type of contexts
|
||||||
|
**kwargs: Additional formatting options
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Markdown-formatted string for this context type
|
||||||
|
"""
|
||||||
|
if not contexts:
|
||||||
|
return ""
|
||||||
|
|
||||||
|
if context_type == ContextType.SYSTEM:
|
||||||
|
return self._format_system(contexts)
|
||||||
|
elif context_type == ContextType.TASK:
|
||||||
|
return self._format_task(contexts)
|
||||||
|
elif context_type == ContextType.KNOWLEDGE:
|
||||||
|
return self._format_knowledge(contexts)
|
||||||
|
elif context_type == ContextType.CONVERSATION:
|
||||||
|
return self._format_conversation(contexts)
|
||||||
|
elif context_type == ContextType.TOOL:
|
||||||
|
return self._format_tool(contexts)
|
||||||
|
|
||||||
|
return "\n".join(c.content for c in contexts)
|
||||||
|
|
||||||
|
def _format_system(self, contexts: list[BaseContext]) -> str:
|
||||||
|
"""Format system contexts."""
|
||||||
|
content = "\n\n".join(c.content for c in contexts)
|
||||||
|
return content
|
||||||
|
|
||||||
|
def _format_task(self, contexts: list[BaseContext]) -> str:
|
||||||
|
"""Format task contexts."""
|
||||||
|
content = "\n\n".join(c.content for c in contexts)
|
||||||
|
return f"## Current Task\n\n{content}"
|
||||||
|
|
||||||
|
def _format_knowledge(self, contexts: list[BaseContext]) -> str:
|
||||||
|
"""
|
||||||
|
Format knowledge contexts as structured documents.
|
||||||
|
|
||||||
|
Each knowledge context becomes a section with source attribution.
|
||||||
|
"""
|
||||||
|
parts = ["## Reference Documents\n"]
|
||||||
|
|
||||||
|
for ctx in contexts:
|
||||||
|
source = ctx.source
|
||||||
|
score = ctx.metadata.get("score", ctx.metadata.get("relevance_score", ""))
|
||||||
|
|
||||||
|
if score:
|
||||||
|
parts.append(f"### Source: {source} (relevance: {score})\n")
|
||||||
|
else:
|
||||||
|
parts.append(f"### Source: {source}\n")
|
||||||
|
|
||||||
|
parts.append(ctx.content)
|
||||||
|
parts.append("")
|
||||||
|
|
||||||
|
return "\n".join(parts)
|
||||||
|
|
||||||
|
def _format_conversation(self, contexts: list[BaseContext]) -> str:
|
||||||
|
"""
|
||||||
|
Format conversation contexts as message history.
|
||||||
|
|
||||||
|
Uses bold role labels for clear turn delineation.
|
||||||
|
"""
|
||||||
|
parts = []
|
||||||
|
|
||||||
|
for ctx in contexts:
|
||||||
|
role = ctx.metadata.get("role", "user").upper()
|
||||||
|
parts.append(f"**{role}**: {ctx.content}")
|
||||||
|
|
||||||
|
return "\n\n".join(parts)
|
||||||
|
|
||||||
|
def _format_tool(self, contexts: list[BaseContext]) -> str:
|
||||||
|
"""
|
||||||
|
Format tool contexts as tool results.
|
||||||
|
|
||||||
|
Each tool result is in a code block with the tool name.
|
||||||
|
"""
|
||||||
|
parts = ["## Recent Tool Results\n"]
|
||||||
|
|
||||||
|
for ctx in contexts:
|
||||||
|
tool_name = ctx.metadata.get("tool_name", "unknown")
|
||||||
|
status = ctx.metadata.get("status", "")
|
||||||
|
|
||||||
|
if status:
|
||||||
|
parts.append(f"### Tool: {tool_name} ({status})\n")
|
||||||
|
else:
|
||||||
|
parts.append(f"### Tool: {tool_name}\n")
|
||||||
|
|
||||||
|
parts.append(f"```\n{ctx.content}\n```")
|
||||||
|
parts.append("")
|
||||||
|
|
||||||
|
return "\n".join(parts)
|
||||||
12
backend/app/services/context/assembly/__init__.py
Normal file
12
backend/app/services/context/assembly/__init__.py
Normal file
@@ -0,0 +1,12 @@
|
|||||||
|
"""
|
||||||
|
Context Assembly Module.
|
||||||
|
|
||||||
|
Provides the assembly pipeline and formatting.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from .pipeline import ContextPipeline, PipelineMetrics
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"ContextPipeline",
|
||||||
|
"PipelineMetrics",
|
||||||
|
]
|
||||||
362
backend/app/services/context/assembly/pipeline.py
Normal file
362
backend/app/services/context/assembly/pipeline.py
Normal file
@@ -0,0 +1,362 @@
|
|||||||
|
"""
|
||||||
|
Context Assembly Pipeline.
|
||||||
|
|
||||||
|
Orchestrates the full context assembly workflow:
|
||||||
|
Gather → Count → Score → Rank → Compress → Format
|
||||||
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
import time
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from datetime import UTC, datetime
|
||||||
|
from typing import TYPE_CHECKING, Any
|
||||||
|
|
||||||
|
from ..adapters import get_adapter
|
||||||
|
from ..budget import BudgetAllocator, TokenBudget, TokenCalculator
|
||||||
|
from ..compression.truncation import ContextCompressor
|
||||||
|
from ..config import ContextSettings, get_context_settings
|
||||||
|
from ..exceptions import AssemblyTimeoutError
|
||||||
|
from ..prioritization import ContextRanker
|
||||||
|
from ..scoring import CompositeScorer
|
||||||
|
from ..types import AssembledContext, BaseContext, ContextType
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from app.services.mcp.client_manager import MCPClientManager
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class PipelineMetrics:
|
||||||
|
"""Metrics from pipeline execution."""
|
||||||
|
|
||||||
|
start_time: datetime = field(default_factory=lambda: datetime.now(UTC))
|
||||||
|
end_time: datetime | None = None
|
||||||
|
total_contexts: int = 0
|
||||||
|
selected_contexts: int = 0
|
||||||
|
excluded_contexts: int = 0
|
||||||
|
compressed_contexts: int = 0
|
||||||
|
total_tokens: int = 0
|
||||||
|
assembly_time_ms: float = 0.0
|
||||||
|
scoring_time_ms: float = 0.0
|
||||||
|
compression_time_ms: float = 0.0
|
||||||
|
formatting_time_ms: float = 0.0
|
||||||
|
|
||||||
|
def to_dict(self) -> dict[str, Any]:
|
||||||
|
"""Convert to dictionary."""
|
||||||
|
return {
|
||||||
|
"start_time": self.start_time.isoformat(),
|
||||||
|
"end_time": self.end_time.isoformat() if self.end_time else None,
|
||||||
|
"total_contexts": self.total_contexts,
|
||||||
|
"selected_contexts": self.selected_contexts,
|
||||||
|
"excluded_contexts": self.excluded_contexts,
|
||||||
|
"compressed_contexts": self.compressed_contexts,
|
||||||
|
"total_tokens": self.total_tokens,
|
||||||
|
"assembly_time_ms": round(self.assembly_time_ms, 2),
|
||||||
|
"scoring_time_ms": round(self.scoring_time_ms, 2),
|
||||||
|
"compression_time_ms": round(self.compression_time_ms, 2),
|
||||||
|
"formatting_time_ms": round(self.formatting_time_ms, 2),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class ContextPipeline:
|
||||||
|
"""
|
||||||
|
Context assembly pipeline.
|
||||||
|
|
||||||
|
Orchestrates the full workflow of context assembly:
|
||||||
|
1. Validate and count tokens for all contexts
|
||||||
|
2. Score contexts based on relevance, recency, and priority
|
||||||
|
3. Rank and select contexts within budget
|
||||||
|
4. Compress if needed to fit remaining budget
|
||||||
|
5. Format for the target model
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
mcp_manager: "MCPClientManager | None" = None,
|
||||||
|
settings: ContextSettings | None = None,
|
||||||
|
calculator: TokenCalculator | None = None,
|
||||||
|
scorer: CompositeScorer | None = None,
|
||||||
|
ranker: ContextRanker | None = None,
|
||||||
|
compressor: ContextCompressor | None = None,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Initialize the context pipeline.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
mcp_manager: MCP client manager for LLM Gateway integration
|
||||||
|
settings: Context settings
|
||||||
|
calculator: Token calculator
|
||||||
|
scorer: Context scorer
|
||||||
|
ranker: Context ranker
|
||||||
|
compressor: Context compressor
|
||||||
|
"""
|
||||||
|
self._settings = settings or get_context_settings()
|
||||||
|
self._mcp = mcp_manager
|
||||||
|
|
||||||
|
# Initialize components
|
||||||
|
self._calculator = calculator or TokenCalculator(mcp_manager=mcp_manager)
|
||||||
|
self._scorer = scorer or CompositeScorer(
|
||||||
|
mcp_manager=mcp_manager, settings=self._settings
|
||||||
|
)
|
||||||
|
self._ranker = ranker or ContextRanker(
|
||||||
|
scorer=self._scorer, calculator=self._calculator
|
||||||
|
)
|
||||||
|
self._compressor = compressor or ContextCompressor(calculator=self._calculator)
|
||||||
|
self._allocator = BudgetAllocator(self._settings)
|
||||||
|
|
||||||
|
def set_mcp_manager(self, mcp_manager: "MCPClientManager") -> None:
|
||||||
|
"""Set MCP manager for all components."""
|
||||||
|
self._mcp = mcp_manager
|
||||||
|
self._calculator.set_mcp_manager(mcp_manager)
|
||||||
|
self._scorer.set_mcp_manager(mcp_manager)
|
||||||
|
|
||||||
|
async def assemble(
|
||||||
|
self,
|
||||||
|
contexts: list[BaseContext],
|
||||||
|
query: str,
|
||||||
|
model: str,
|
||||||
|
max_tokens: int | None = None,
|
||||||
|
custom_budget: TokenBudget | None = None,
|
||||||
|
compress: bool = True,
|
||||||
|
format_output: bool = True,
|
||||||
|
timeout_ms: int | None = None,
|
||||||
|
) -> AssembledContext:
|
||||||
|
"""
|
||||||
|
Assemble context for an LLM request.
|
||||||
|
|
||||||
|
This is the main entry point for context assembly.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
contexts: List of contexts to assemble
|
||||||
|
query: Query to optimize for
|
||||||
|
model: Target model name
|
||||||
|
max_tokens: Maximum total tokens (uses model default if None)
|
||||||
|
custom_budget: Optional pre-configured budget
|
||||||
|
compress: Whether to compress oversized contexts
|
||||||
|
format_output: Whether to format the final output
|
||||||
|
timeout_ms: Maximum assembly time in milliseconds
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
AssembledContext with optimized content
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
AssemblyTimeoutError: If assembly exceeds timeout
|
||||||
|
"""
|
||||||
|
timeout = timeout_ms or self._settings.max_assembly_time_ms
|
||||||
|
start = time.perf_counter()
|
||||||
|
metrics = PipelineMetrics(total_contexts=len(contexts))
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Create or use budget
|
||||||
|
if custom_budget:
|
||||||
|
budget = custom_budget
|
||||||
|
elif max_tokens:
|
||||||
|
budget = self._allocator.create_budget(max_tokens)
|
||||||
|
else:
|
||||||
|
budget = self._allocator.create_budget_for_model(model)
|
||||||
|
|
||||||
|
# 1. Count tokens for all contexts (with timeout enforcement)
|
||||||
|
try:
|
||||||
|
await asyncio.wait_for(
|
||||||
|
self._ensure_token_counts(contexts, model),
|
||||||
|
timeout=self._remaining_timeout(start, timeout),
|
||||||
|
)
|
||||||
|
except TimeoutError:
|
||||||
|
elapsed_ms = (time.perf_counter() - start) * 1000
|
||||||
|
raise AssemblyTimeoutError(
|
||||||
|
message="Context assembly timed out during token counting",
|
||||||
|
elapsed_ms=elapsed_ms,
|
||||||
|
timeout_ms=timeout,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check timeout (handles edge case where operation finished just at limit)
|
||||||
|
self._check_timeout(start, timeout, "token counting")
|
||||||
|
|
||||||
|
# 2. Score and rank contexts (with timeout enforcement)
|
||||||
|
scoring_start = time.perf_counter()
|
||||||
|
try:
|
||||||
|
ranking_result = await asyncio.wait_for(
|
||||||
|
self._ranker.rank(
|
||||||
|
contexts=contexts,
|
||||||
|
query=query,
|
||||||
|
budget=budget,
|
||||||
|
model=model,
|
||||||
|
),
|
||||||
|
timeout=self._remaining_timeout(start, timeout),
|
||||||
|
)
|
||||||
|
except TimeoutError:
|
||||||
|
elapsed_ms = (time.perf_counter() - start) * 1000
|
||||||
|
raise AssemblyTimeoutError(
|
||||||
|
message="Context assembly timed out during scoring/ranking",
|
||||||
|
elapsed_ms=elapsed_ms,
|
||||||
|
timeout_ms=timeout,
|
||||||
|
)
|
||||||
|
metrics.scoring_time_ms = (time.perf_counter() - scoring_start) * 1000
|
||||||
|
|
||||||
|
selected_contexts = ranking_result.selected_contexts
|
||||||
|
metrics.selected_contexts = len(selected_contexts)
|
||||||
|
metrics.excluded_contexts = len(ranking_result.excluded)
|
||||||
|
|
||||||
|
# Check timeout
|
||||||
|
self._check_timeout(start, timeout, "scoring")
|
||||||
|
|
||||||
|
# 3. Compress if needed and enabled (with timeout enforcement)
|
||||||
|
if compress and self._needs_compression(selected_contexts, budget):
|
||||||
|
compression_start = time.perf_counter()
|
||||||
|
try:
|
||||||
|
selected_contexts = await asyncio.wait_for(
|
||||||
|
self._compressor.compress_contexts(
|
||||||
|
selected_contexts, budget, model
|
||||||
|
),
|
||||||
|
timeout=self._remaining_timeout(start, timeout),
|
||||||
|
)
|
||||||
|
except TimeoutError:
|
||||||
|
elapsed_ms = (time.perf_counter() - start) * 1000
|
||||||
|
raise AssemblyTimeoutError(
|
||||||
|
message="Context assembly timed out during compression",
|
||||||
|
elapsed_ms=elapsed_ms,
|
||||||
|
timeout_ms=timeout,
|
||||||
|
)
|
||||||
|
metrics.compression_time_ms = (
|
||||||
|
time.perf_counter() - compression_start
|
||||||
|
) * 1000
|
||||||
|
metrics.compressed_contexts = sum(
|
||||||
|
1 for c in selected_contexts if c.metadata.get("truncated", False)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check timeout
|
||||||
|
self._check_timeout(start, timeout, "compression")
|
||||||
|
|
||||||
|
# 4. Format output
|
||||||
|
formatting_start = time.perf_counter()
|
||||||
|
if format_output:
|
||||||
|
formatted_content = self._format_contexts(selected_contexts, model)
|
||||||
|
else:
|
||||||
|
formatted_content = "\n\n".join(c.content for c in selected_contexts)
|
||||||
|
metrics.formatting_time_ms = (time.perf_counter() - formatting_start) * 1000
|
||||||
|
|
||||||
|
# Calculate final metrics
|
||||||
|
total_tokens = sum(c.token_count or 0 for c in selected_contexts)
|
||||||
|
metrics.total_tokens = total_tokens
|
||||||
|
metrics.assembly_time_ms = (time.perf_counter() - start) * 1000
|
||||||
|
metrics.end_time = datetime.now(UTC)
|
||||||
|
|
||||||
|
return AssembledContext(
|
||||||
|
content=formatted_content,
|
||||||
|
total_tokens=total_tokens,
|
||||||
|
context_count=len(selected_contexts),
|
||||||
|
assembly_time_ms=metrics.assembly_time_ms,
|
||||||
|
model=model,
|
||||||
|
contexts=selected_contexts,
|
||||||
|
excluded_count=metrics.excluded_contexts,
|
||||||
|
metadata={
|
||||||
|
"metrics": metrics.to_dict(),
|
||||||
|
"query": query,
|
||||||
|
"budget": budget.to_dict(),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
except AssemblyTimeoutError:
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Context assembly failed: {e}", exc_info=True)
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def _ensure_token_counts(
|
||||||
|
self,
|
||||||
|
contexts: list[BaseContext],
|
||||||
|
model: str | None = None,
|
||||||
|
) -> None:
|
||||||
|
"""Ensure all contexts have token counts."""
|
||||||
|
tasks = []
|
||||||
|
for context in contexts:
|
||||||
|
if context.token_count is None:
|
||||||
|
tasks.append(self._count_and_set(context, model))
|
||||||
|
|
||||||
|
if tasks:
|
||||||
|
await asyncio.gather(*tasks)
|
||||||
|
|
||||||
|
async def _count_and_set(
|
||||||
|
self,
|
||||||
|
context: BaseContext,
|
||||||
|
model: str | None = None,
|
||||||
|
) -> None:
|
||||||
|
"""Count tokens and set on context."""
|
||||||
|
count = await self._calculator.count_tokens(context.content, model)
|
||||||
|
context.token_count = count
|
||||||
|
|
||||||
|
def _needs_compression(
|
||||||
|
self,
|
||||||
|
contexts: list[BaseContext],
|
||||||
|
budget: TokenBudget,
|
||||||
|
) -> bool:
|
||||||
|
"""Check if any contexts exceed their type budget."""
|
||||||
|
# Group by type and check totals
|
||||||
|
by_type: dict[ContextType, int] = {}
|
||||||
|
for context in contexts:
|
||||||
|
ct = context.get_type()
|
||||||
|
by_type[ct] = by_type.get(ct, 0) + (context.token_count or 0)
|
||||||
|
|
||||||
|
for ct, total in by_type.items():
|
||||||
|
if total > budget.get_allocation(ct):
|
||||||
|
return True
|
||||||
|
|
||||||
|
# Also check if utilization exceeds threshold
|
||||||
|
return budget.utilization() > self._settings.compression_threshold
|
||||||
|
|
||||||
|
def _format_contexts(
|
||||||
|
self,
|
||||||
|
contexts: list[BaseContext],
|
||||||
|
model: str,
|
||||||
|
) -> str:
|
||||||
|
"""
|
||||||
|
Format contexts for the target model.
|
||||||
|
|
||||||
|
Uses model-specific adapters (ClaudeAdapter, OpenAIAdapter, etc.)
|
||||||
|
to format contexts optimally for each model family.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
contexts: Contexts to format
|
||||||
|
model: Target model name
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Formatted context string
|
||||||
|
"""
|
||||||
|
adapter = get_adapter(model)
|
||||||
|
return adapter.format(contexts)
|
||||||
|
|
||||||
|
def _check_timeout(
|
||||||
|
self,
|
||||||
|
start: float,
|
||||||
|
timeout_ms: int,
|
||||||
|
phase: str,
|
||||||
|
) -> None:
|
||||||
|
"""Check if timeout exceeded and raise if so."""
|
||||||
|
elapsed_ms = (time.perf_counter() - start) * 1000
|
||||||
|
if elapsed_ms >= timeout_ms:
|
||||||
|
raise AssemblyTimeoutError(
|
||||||
|
message=f"Context assembly timed out during {phase}",
|
||||||
|
elapsed_ms=elapsed_ms,
|
||||||
|
timeout_ms=timeout_ms,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _remaining_timeout(self, start: float, timeout_ms: int) -> float:
|
||||||
|
"""
|
||||||
|
Calculate remaining timeout in seconds for asyncio.wait_for.
|
||||||
|
|
||||||
|
Returns at least a small positive value to avoid immediate timeout
|
||||||
|
edge cases with wait_for.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
start: Start time from time.perf_counter()
|
||||||
|
timeout_ms: Total timeout in milliseconds
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Remaining timeout in seconds (minimum 0.001)
|
||||||
|
"""
|
||||||
|
elapsed_ms = (time.perf_counter() - start) * 1000
|
||||||
|
remaining_ms = timeout_ms - elapsed_ms
|
||||||
|
# Return at least 1ms to avoid zero/negative timeout edge cases
|
||||||
|
return max(remaining_ms / 1000.0, 0.001)
|
||||||
14
backend/app/services/context/budget/__init__.py
Normal file
14
backend/app/services/context/budget/__init__.py
Normal file
@@ -0,0 +1,14 @@
|
|||||||
|
"""
|
||||||
|
Token Budget Management Module.
|
||||||
|
|
||||||
|
Provides token counting and budget allocation.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from .allocator import BudgetAllocator, TokenBudget
|
||||||
|
from .calculator import TokenCalculator
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"BudgetAllocator",
|
||||||
|
"TokenBudget",
|
||||||
|
"TokenCalculator",
|
||||||
|
]
|
||||||
433
backend/app/services/context/budget/allocator.py
Normal file
433
backend/app/services/context/budget/allocator.py
Normal file
@@ -0,0 +1,433 @@
|
|||||||
|
"""
|
||||||
|
Token Budget Allocator for Context Management.
|
||||||
|
|
||||||
|
Manages token budget allocation across context types.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from ..config import ContextSettings, get_context_settings
|
||||||
|
from ..exceptions import BudgetExceededError
|
||||||
|
from ..types import ContextType
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class TokenBudget:
|
||||||
|
"""
|
||||||
|
Token budget allocation and tracking.
|
||||||
|
|
||||||
|
Tracks allocated tokens per context type and
|
||||||
|
monitors usage to prevent overflows.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Total budget
|
||||||
|
total: int
|
||||||
|
|
||||||
|
# Allocated per type
|
||||||
|
system: int = 0
|
||||||
|
task: int = 0
|
||||||
|
knowledge: int = 0
|
||||||
|
conversation: int = 0
|
||||||
|
tools: int = 0
|
||||||
|
response_reserve: int = 0
|
||||||
|
buffer: int = 0
|
||||||
|
|
||||||
|
# Usage tracking
|
||||||
|
used: dict[str, int] = field(default_factory=dict)
|
||||||
|
|
||||||
|
def __post_init__(self) -> None:
|
||||||
|
"""Initialize usage tracking."""
|
||||||
|
if not self.used:
|
||||||
|
self.used = {ct.value: 0 for ct in ContextType}
|
||||||
|
|
||||||
|
def get_allocation(self, context_type: ContextType | str) -> int:
|
||||||
|
"""
|
||||||
|
Get allocated tokens for a context type.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
context_type: Context type to get allocation for
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Allocated token count
|
||||||
|
"""
|
||||||
|
if isinstance(context_type, ContextType):
|
||||||
|
context_type = context_type.value
|
||||||
|
|
||||||
|
allocation_map = {
|
||||||
|
"system": self.system,
|
||||||
|
"task": self.task,
|
||||||
|
"knowledge": self.knowledge,
|
||||||
|
"conversation": self.conversation,
|
||||||
|
"tool": self.tools,
|
||||||
|
}
|
||||||
|
return allocation_map.get(context_type, 0)
|
||||||
|
|
||||||
|
def get_used(self, context_type: ContextType | str) -> int:
|
||||||
|
"""
|
||||||
|
Get used tokens for a context type.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
context_type: Context type to check
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Used token count
|
||||||
|
"""
|
||||||
|
if isinstance(context_type, ContextType):
|
||||||
|
context_type = context_type.value
|
||||||
|
return self.used.get(context_type, 0)
|
||||||
|
|
||||||
|
def remaining(self, context_type: ContextType | str) -> int:
|
||||||
|
"""
|
||||||
|
Get remaining tokens for a context type.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
context_type: Context type to check
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Remaining token count
|
||||||
|
"""
|
||||||
|
allocated = self.get_allocation(context_type)
|
||||||
|
used = self.get_used(context_type)
|
||||||
|
return max(0, allocated - used)
|
||||||
|
|
||||||
|
def total_remaining(self) -> int:
|
||||||
|
"""
|
||||||
|
Get total remaining tokens across all types.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Total remaining tokens
|
||||||
|
"""
|
||||||
|
total_used = sum(self.used.values())
|
||||||
|
usable = self.total - self.response_reserve - self.buffer
|
||||||
|
return max(0, usable - total_used)
|
||||||
|
|
||||||
|
def total_used(self) -> int:
|
||||||
|
"""
|
||||||
|
Get total used tokens.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Total used tokens
|
||||||
|
"""
|
||||||
|
return sum(self.used.values())
|
||||||
|
|
||||||
|
def can_fit(self, context_type: ContextType | str, tokens: int) -> bool:
|
||||||
|
"""
|
||||||
|
Check if tokens fit within budget for a type.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
context_type: Context type to check
|
||||||
|
tokens: Number of tokens to fit
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if tokens fit within remaining budget
|
||||||
|
"""
|
||||||
|
return tokens <= self.remaining(context_type)
|
||||||
|
|
||||||
|
def allocate(
|
||||||
|
self,
|
||||||
|
context_type: ContextType | str,
|
||||||
|
tokens: int,
|
||||||
|
force: bool = False,
|
||||||
|
) -> bool:
|
||||||
|
"""
|
||||||
|
Allocate (use) tokens from a context type's budget.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
context_type: Context type to allocate from
|
||||||
|
tokens: Number of tokens to allocate
|
||||||
|
force: If True, allow exceeding budget
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if allocation succeeded
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
BudgetExceededError: If tokens exceed budget and force=False
|
||||||
|
"""
|
||||||
|
if isinstance(context_type, ContextType):
|
||||||
|
context_type = context_type.value
|
||||||
|
|
||||||
|
if not force and not self.can_fit(context_type, tokens):
|
||||||
|
raise BudgetExceededError(
|
||||||
|
message=f"Token budget exceeded for {context_type}",
|
||||||
|
allocated=self.get_allocation(context_type),
|
||||||
|
requested=self.get_used(context_type) + tokens,
|
||||||
|
context_type=context_type,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.used[context_type] = self.used.get(context_type, 0) + tokens
|
||||||
|
return True
|
||||||
|
|
||||||
|
def deallocate(
|
||||||
|
self,
|
||||||
|
context_type: ContextType | str,
|
||||||
|
tokens: int,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Deallocate (return) tokens to a context type's budget.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
context_type: Context type to return to
|
||||||
|
tokens: Number of tokens to return
|
||||||
|
"""
|
||||||
|
if isinstance(context_type, ContextType):
|
||||||
|
context_type = context_type.value
|
||||||
|
|
||||||
|
current = self.used.get(context_type, 0)
|
||||||
|
self.used[context_type] = max(0, current - tokens)
|
||||||
|
|
||||||
|
def reset(self) -> None:
|
||||||
|
"""Reset all usage tracking."""
|
||||||
|
self.used = {ct.value: 0 for ct in ContextType}
|
||||||
|
|
||||||
|
def utilization(self, context_type: ContextType | str | None = None) -> float:
|
||||||
|
"""
|
||||||
|
Get budget utilization percentage.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
context_type: Specific type or None for total
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Utilization as a fraction (0.0 to 1.0+)
|
||||||
|
"""
|
||||||
|
if context_type is None:
|
||||||
|
usable = self.total - self.response_reserve - self.buffer
|
||||||
|
if usable <= 0:
|
||||||
|
return 0.0
|
||||||
|
return self.total_used() / usable
|
||||||
|
|
||||||
|
allocated = self.get_allocation(context_type)
|
||||||
|
if allocated <= 0:
|
||||||
|
return 0.0
|
||||||
|
return self.get_used(context_type) / allocated
|
||||||
|
|
||||||
|
def to_dict(self) -> dict[str, Any]:
|
||||||
|
"""Convert budget to dictionary."""
|
||||||
|
return {
|
||||||
|
"total": self.total,
|
||||||
|
"allocations": {
|
||||||
|
"system": self.system,
|
||||||
|
"task": self.task,
|
||||||
|
"knowledge": self.knowledge,
|
||||||
|
"conversation": self.conversation,
|
||||||
|
"tools": self.tools,
|
||||||
|
"response_reserve": self.response_reserve,
|
||||||
|
"buffer": self.buffer,
|
||||||
|
},
|
||||||
|
"used": dict(self.used),
|
||||||
|
"remaining": {ct.value: self.remaining(ct) for ct in ContextType},
|
||||||
|
"total_used": self.total_used(),
|
||||||
|
"total_remaining": self.total_remaining(),
|
||||||
|
"utilization": round(self.utilization(), 3),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class BudgetAllocator:
|
||||||
|
"""
|
||||||
|
Budget allocator for context management.
|
||||||
|
|
||||||
|
Creates token budgets based on configuration and
|
||||||
|
model context window sizes.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, settings: ContextSettings | None = None) -> None:
|
||||||
|
"""
|
||||||
|
Initialize budget allocator.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
settings: Context settings (uses default if None)
|
||||||
|
"""
|
||||||
|
self._settings = settings or get_context_settings()
|
||||||
|
|
||||||
|
def create_budget(
|
||||||
|
self,
|
||||||
|
total_tokens: int,
|
||||||
|
custom_allocations: dict[str, float] | None = None,
|
||||||
|
) -> TokenBudget:
|
||||||
|
"""
|
||||||
|
Create a token budget with allocations.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
total_tokens: Total available tokens
|
||||||
|
custom_allocations: Optional custom allocation percentages
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
TokenBudget with allocations set
|
||||||
|
"""
|
||||||
|
# Use custom or default allocations
|
||||||
|
if custom_allocations:
|
||||||
|
alloc = custom_allocations
|
||||||
|
else:
|
||||||
|
alloc = self._settings.get_budget_allocation()
|
||||||
|
|
||||||
|
return TokenBudget(
|
||||||
|
total=total_tokens,
|
||||||
|
system=int(total_tokens * alloc.get("system", 0.05)),
|
||||||
|
task=int(total_tokens * alloc.get("task", 0.10)),
|
||||||
|
knowledge=int(total_tokens * alloc.get("knowledge", 0.40)),
|
||||||
|
conversation=int(total_tokens * alloc.get("conversation", 0.20)),
|
||||||
|
tools=int(total_tokens * alloc.get("tools", 0.05)),
|
||||||
|
response_reserve=int(total_tokens * alloc.get("response", 0.15)),
|
||||||
|
buffer=int(total_tokens * alloc.get("buffer", 0.05)),
|
||||||
|
)
|
||||||
|
|
||||||
|
def adjust_budget(
|
||||||
|
self,
|
||||||
|
budget: TokenBudget,
|
||||||
|
context_type: ContextType | str,
|
||||||
|
adjustment: int,
|
||||||
|
) -> TokenBudget:
|
||||||
|
"""
|
||||||
|
Adjust a specific allocation in a budget.
|
||||||
|
|
||||||
|
Takes tokens from buffer and adds to specified type.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
budget: Budget to adjust
|
||||||
|
context_type: Type to adjust
|
||||||
|
adjustment: Positive to increase, negative to decrease
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Adjusted budget
|
||||||
|
"""
|
||||||
|
if isinstance(context_type, ContextType):
|
||||||
|
context_type = context_type.value
|
||||||
|
|
||||||
|
# Calculate adjustment (limited by buffer for increases, by current allocation for decreases)
|
||||||
|
if adjustment > 0:
|
||||||
|
# Taking from buffer - limited by available buffer
|
||||||
|
actual_adjustment = min(adjustment, budget.buffer)
|
||||||
|
budget.buffer -= actual_adjustment
|
||||||
|
else:
|
||||||
|
# Returning to buffer - limited by current allocation of target type
|
||||||
|
current_allocation = budget.get_allocation(context_type)
|
||||||
|
# Can't return more than current allocation
|
||||||
|
actual_adjustment = max(adjustment, -current_allocation)
|
||||||
|
# Add returned tokens back to buffer (adjustment is negative, so subtract)
|
||||||
|
budget.buffer -= actual_adjustment
|
||||||
|
|
||||||
|
# Apply to target type
|
||||||
|
if context_type == "system":
|
||||||
|
budget.system = max(0, budget.system + actual_adjustment)
|
||||||
|
elif context_type == "task":
|
||||||
|
budget.task = max(0, budget.task + actual_adjustment)
|
||||||
|
elif context_type == "knowledge":
|
||||||
|
budget.knowledge = max(0, budget.knowledge + actual_adjustment)
|
||||||
|
elif context_type == "conversation":
|
||||||
|
budget.conversation = max(0, budget.conversation + actual_adjustment)
|
||||||
|
elif context_type == "tool":
|
||||||
|
budget.tools = max(0, budget.tools + actual_adjustment)
|
||||||
|
|
||||||
|
return budget
|
||||||
|
|
||||||
|
def rebalance_budget(
|
||||||
|
self,
|
||||||
|
budget: TokenBudget,
|
||||||
|
prioritize: list[ContextType] | None = None,
|
||||||
|
) -> TokenBudget:
|
||||||
|
"""
|
||||||
|
Rebalance budget based on actual usage.
|
||||||
|
|
||||||
|
Moves unused allocations to prioritized types.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
budget: Budget to rebalance
|
||||||
|
prioritize: Types to prioritize (in order)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Rebalanced budget
|
||||||
|
"""
|
||||||
|
if prioritize is None:
|
||||||
|
prioritize = [ContextType.KNOWLEDGE, ContextType.TASK, ContextType.SYSTEM]
|
||||||
|
|
||||||
|
# Calculate unused tokens per type
|
||||||
|
unused: dict[str, int] = {}
|
||||||
|
for ct in ContextType:
|
||||||
|
remaining = budget.remaining(ct)
|
||||||
|
if remaining > 0:
|
||||||
|
unused[ct.value] = remaining
|
||||||
|
|
||||||
|
# Calculate total reclaimable (excluding prioritized types)
|
||||||
|
prioritize_values = {ct.value for ct in prioritize}
|
||||||
|
reclaimable = sum(
|
||||||
|
tokens for ct, tokens in unused.items() if ct not in prioritize_values
|
||||||
|
)
|
||||||
|
|
||||||
|
# Redistribute to prioritized types that are near capacity
|
||||||
|
for ct in prioritize:
|
||||||
|
utilization = budget.utilization(ct)
|
||||||
|
|
||||||
|
if utilization > 0.8: # Near capacity
|
||||||
|
# Give more tokens from reclaimable pool
|
||||||
|
bonus = min(reclaimable, budget.get_allocation(ct) // 2)
|
||||||
|
self.adjust_budget(budget, ct, bonus)
|
||||||
|
reclaimable -= bonus
|
||||||
|
|
||||||
|
if reclaimable <= 0:
|
||||||
|
break
|
||||||
|
|
||||||
|
return budget
|
||||||
|
|
||||||
|
def get_model_context_size(self, model: str) -> int:
|
||||||
|
"""
|
||||||
|
Get context window size for a model.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model: Model name
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Context window size in tokens
|
||||||
|
"""
|
||||||
|
# Common model context sizes
|
||||||
|
context_sizes = {
|
||||||
|
"claude-3-opus": 200000,
|
||||||
|
"claude-3-sonnet": 200000,
|
||||||
|
"claude-3-haiku": 200000,
|
||||||
|
"claude-3-5-sonnet": 200000,
|
||||||
|
"claude-3-5-haiku": 200000,
|
||||||
|
"claude-opus-4": 200000,
|
||||||
|
"gpt-4-turbo": 128000,
|
||||||
|
"gpt-4": 8192,
|
||||||
|
"gpt-4-32k": 32768,
|
||||||
|
"gpt-4o": 128000,
|
||||||
|
"gpt-4o-mini": 128000,
|
||||||
|
"gpt-3.5-turbo": 16385,
|
||||||
|
"gemini-1.5-pro": 2000000,
|
||||||
|
"gemini-1.5-flash": 1000000,
|
||||||
|
"gemini-2.0-flash": 1000000,
|
||||||
|
"qwen-plus": 32000,
|
||||||
|
"qwen-turbo": 8000,
|
||||||
|
"deepseek-chat": 64000,
|
||||||
|
"deepseek-reasoner": 64000,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Check exact match first
|
||||||
|
model_lower = model.lower()
|
||||||
|
if model_lower in context_sizes:
|
||||||
|
return context_sizes[model_lower]
|
||||||
|
|
||||||
|
# Check prefix match
|
||||||
|
for model_name, size in context_sizes.items():
|
||||||
|
if model_lower.startswith(model_name):
|
||||||
|
return size
|
||||||
|
|
||||||
|
# Default fallback
|
||||||
|
return 8192
|
||||||
|
|
||||||
|
def create_budget_for_model(
|
||||||
|
self,
|
||||||
|
model: str,
|
||||||
|
custom_allocations: dict[str, float] | None = None,
|
||||||
|
) -> TokenBudget:
|
||||||
|
"""
|
||||||
|
Create a budget based on model's context window.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model: Model name
|
||||||
|
custom_allocations: Optional custom allocation percentages
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
TokenBudget sized for the model
|
||||||
|
"""
|
||||||
|
context_size = self.get_model_context_size(model)
|
||||||
|
return self.create_budget(context_size, custom_allocations)
|
||||||
285
backend/app/services/context/budget/calculator.py
Normal file
285
backend/app/services/context/budget/calculator.py
Normal file
@@ -0,0 +1,285 @@
|
|||||||
|
"""
|
||||||
|
Token Calculator for Context Management.
|
||||||
|
|
||||||
|
Provides token counting with caching and fallback estimation.
|
||||||
|
Integrates with LLM Gateway for accurate counts.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import hashlib
|
||||||
|
import logging
|
||||||
|
from typing import TYPE_CHECKING, Any, ClassVar, Protocol
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from app.services.mcp.client_manager import MCPClientManager
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class TokenCounterProtocol(Protocol):
|
||||||
|
"""Protocol for token counting implementations."""
|
||||||
|
|
||||||
|
async def count_tokens(
|
||||||
|
self,
|
||||||
|
text: str,
|
||||||
|
model: str | None = None,
|
||||||
|
) -> int:
|
||||||
|
"""Count tokens in text."""
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
|
class TokenCalculator:
|
||||||
|
"""
|
||||||
|
Token calculator with LLM Gateway integration.
|
||||||
|
|
||||||
|
Features:
|
||||||
|
- In-memory caching for repeated text
|
||||||
|
- Fallback to character-based estimation
|
||||||
|
- Model-specific counting when possible
|
||||||
|
|
||||||
|
The calculator uses the LLM Gateway's count_tokens tool
|
||||||
|
for accurate counting, with a local cache to avoid
|
||||||
|
repeated calls for the same content.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Default characters per token ratio for estimation
|
||||||
|
DEFAULT_CHARS_PER_TOKEN: ClassVar[float] = 4.0
|
||||||
|
|
||||||
|
# Model-specific ratios (more accurate estimation)
|
||||||
|
MODEL_CHAR_RATIOS: ClassVar[dict[str, float]] = {
|
||||||
|
"claude": 3.5,
|
||||||
|
"gpt-4": 4.0,
|
||||||
|
"gpt-3.5": 4.0,
|
||||||
|
"gemini": 4.0,
|
||||||
|
}
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
mcp_manager: "MCPClientManager | None" = None,
|
||||||
|
project_id: str = "system",
|
||||||
|
agent_id: str = "context-engine",
|
||||||
|
cache_enabled: bool = True,
|
||||||
|
cache_max_size: int = 10000,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Initialize token calculator.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
mcp_manager: MCP client manager for LLM Gateway calls
|
||||||
|
project_id: Project ID for LLM Gateway calls
|
||||||
|
agent_id: Agent ID for LLM Gateway calls
|
||||||
|
cache_enabled: Whether to enable in-memory caching
|
||||||
|
cache_max_size: Maximum cache entries
|
||||||
|
"""
|
||||||
|
self._mcp = mcp_manager
|
||||||
|
self._project_id = project_id
|
||||||
|
self._agent_id = agent_id
|
||||||
|
self._cache_enabled = cache_enabled
|
||||||
|
self._cache_max_size = cache_max_size
|
||||||
|
|
||||||
|
# In-memory cache: hash(model:text) -> token_count
|
||||||
|
self._cache: dict[str, int] = {}
|
||||||
|
self._cache_hits = 0
|
||||||
|
self._cache_misses = 0
|
||||||
|
|
||||||
|
def _get_cache_key(self, text: str, model: str | None) -> str:
|
||||||
|
"""Generate cache key from text and model."""
|
||||||
|
# Use hash for efficient storage
|
||||||
|
content = f"{model or 'default'}:{text}"
|
||||||
|
return hashlib.sha256(content.encode()).hexdigest()[:32]
|
||||||
|
|
||||||
|
def _check_cache(self, cache_key: str) -> int | None:
|
||||||
|
"""Check cache for existing count."""
|
||||||
|
if not self._cache_enabled:
|
||||||
|
return None
|
||||||
|
|
||||||
|
if cache_key in self._cache:
|
||||||
|
self._cache_hits += 1
|
||||||
|
return self._cache[cache_key]
|
||||||
|
|
||||||
|
self._cache_misses += 1
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _store_cache(self, cache_key: str, count: int) -> None:
|
||||||
|
"""Store count in cache."""
|
||||||
|
if not self._cache_enabled:
|
||||||
|
return
|
||||||
|
|
||||||
|
# Simple LRU-like eviction: remove oldest entries when full
|
||||||
|
if len(self._cache) >= self._cache_max_size:
|
||||||
|
# Remove first 10% of entries
|
||||||
|
entries_to_remove = self._cache_max_size // 10
|
||||||
|
keys_to_remove = list(self._cache.keys())[:entries_to_remove]
|
||||||
|
for key in keys_to_remove:
|
||||||
|
del self._cache[key]
|
||||||
|
|
||||||
|
self._cache[cache_key] = count
|
||||||
|
|
||||||
|
def estimate_tokens(self, text: str, model: str | None = None) -> int:
|
||||||
|
"""
|
||||||
|
Estimate token count based on character count.
|
||||||
|
|
||||||
|
This is a fast fallback when LLM Gateway is unavailable.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
text: Text to count
|
||||||
|
model: Optional model for more accurate ratio
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Estimated token count
|
||||||
|
"""
|
||||||
|
if not text:
|
||||||
|
return 0
|
||||||
|
|
||||||
|
# Get model-specific ratio
|
||||||
|
ratio = self.DEFAULT_CHARS_PER_TOKEN
|
||||||
|
if model:
|
||||||
|
model_lower = model.lower()
|
||||||
|
for model_prefix, model_ratio in self.MODEL_CHAR_RATIOS.items():
|
||||||
|
if model_prefix in model_lower:
|
||||||
|
ratio = model_ratio
|
||||||
|
break
|
||||||
|
|
||||||
|
return max(1, int(len(text) / ratio))
|
||||||
|
|
||||||
|
async def count_tokens(
|
||||||
|
self,
|
||||||
|
text: str,
|
||||||
|
model: str | None = None,
|
||||||
|
) -> int:
|
||||||
|
"""
|
||||||
|
Count tokens in text.
|
||||||
|
|
||||||
|
Uses LLM Gateway for accurate counts with fallback to estimation.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
text: Text to count
|
||||||
|
model: Optional model for accurate counting
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Token count
|
||||||
|
"""
|
||||||
|
if not text:
|
||||||
|
return 0
|
||||||
|
|
||||||
|
# Check cache first
|
||||||
|
cache_key = self._get_cache_key(text, model)
|
||||||
|
cached = self._check_cache(cache_key)
|
||||||
|
if cached is not None:
|
||||||
|
return cached
|
||||||
|
|
||||||
|
# Try LLM Gateway
|
||||||
|
if self._mcp is not None:
|
||||||
|
try:
|
||||||
|
result = await self._mcp.call_tool(
|
||||||
|
server="llm-gateway",
|
||||||
|
tool="count_tokens",
|
||||||
|
args={
|
||||||
|
"project_id": self._project_id,
|
||||||
|
"agent_id": self._agent_id,
|
||||||
|
"text": text,
|
||||||
|
"model": model,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
# Parse result
|
||||||
|
if result.success and result.data:
|
||||||
|
count = self._parse_token_count(result.data)
|
||||||
|
if count is not None:
|
||||||
|
self._store_cache(cache_key, count)
|
||||||
|
return count
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"LLM Gateway token count failed, using estimation: {e}")
|
||||||
|
|
||||||
|
# Fallback to estimation
|
||||||
|
count = self.estimate_tokens(text, model)
|
||||||
|
self._store_cache(cache_key, count)
|
||||||
|
return count
|
||||||
|
|
||||||
|
def _parse_token_count(self, data: Any) -> int | None:
|
||||||
|
"""Parse token count from LLM Gateway response."""
|
||||||
|
if isinstance(data, dict):
|
||||||
|
if "token_count" in data:
|
||||||
|
return int(data["token_count"])
|
||||||
|
if "tokens" in data:
|
||||||
|
return int(data["tokens"])
|
||||||
|
if "count" in data:
|
||||||
|
return int(data["count"])
|
||||||
|
|
||||||
|
if isinstance(data, int):
|
||||||
|
return data
|
||||||
|
|
||||||
|
if isinstance(data, str):
|
||||||
|
# Try to parse from text content
|
||||||
|
try:
|
||||||
|
# Handle {"token_count": 123} or just "123"
|
||||||
|
import json
|
||||||
|
|
||||||
|
parsed = json.loads(data)
|
||||||
|
if isinstance(parsed, dict) and "token_count" in parsed:
|
||||||
|
return int(parsed["token_count"])
|
||||||
|
if isinstance(parsed, int):
|
||||||
|
return parsed
|
||||||
|
except (json.JSONDecodeError, ValueError):
|
||||||
|
# Try direct int conversion
|
||||||
|
try:
|
||||||
|
return int(data)
|
||||||
|
except ValueError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def count_tokens_batch(
|
||||||
|
self,
|
||||||
|
texts: list[str],
|
||||||
|
model: str | None = None,
|
||||||
|
) -> list[int]:
|
||||||
|
"""
|
||||||
|
Count tokens for multiple texts.
|
||||||
|
|
||||||
|
Efficient batch counting with caching and parallel execution.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
texts: List of texts to count
|
||||||
|
model: Optional model for accurate counting
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of token counts (same order as input)
|
||||||
|
"""
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
if not texts:
|
||||||
|
return []
|
||||||
|
|
||||||
|
# Execute all token counts in parallel for better performance
|
||||||
|
tasks = [self.count_tokens(text, model) for text in texts]
|
||||||
|
return await asyncio.gather(*tasks)
|
||||||
|
|
||||||
|
def clear_cache(self) -> None:
|
||||||
|
"""Clear the token count cache."""
|
||||||
|
self._cache.clear()
|
||||||
|
self._cache_hits = 0
|
||||||
|
self._cache_misses = 0
|
||||||
|
|
||||||
|
def get_cache_stats(self) -> dict[str, Any]:
|
||||||
|
"""Get cache statistics."""
|
||||||
|
total = self._cache_hits + self._cache_misses
|
||||||
|
hit_rate = self._cache_hits / total if total > 0 else 0.0
|
||||||
|
|
||||||
|
return {
|
||||||
|
"enabled": self._cache_enabled,
|
||||||
|
"size": len(self._cache),
|
||||||
|
"max_size": self._cache_max_size,
|
||||||
|
"hits": self._cache_hits,
|
||||||
|
"misses": self._cache_misses,
|
||||||
|
"hit_rate": round(hit_rate, 3),
|
||||||
|
}
|
||||||
|
|
||||||
|
def set_mcp_manager(self, mcp_manager: "MCPClientManager") -> None:
|
||||||
|
"""
|
||||||
|
Set the MCP manager (for lazy initialization).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
mcp_manager: MCP client manager instance
|
||||||
|
"""
|
||||||
|
self._mcp = mcp_manager
|
||||||
11
backend/app/services/context/cache/__init__.py
vendored
Normal file
11
backend/app/services/context/cache/__init__.py
vendored
Normal file
@@ -0,0 +1,11 @@
|
|||||||
|
"""
|
||||||
|
Context Cache Module.
|
||||||
|
|
||||||
|
Provides Redis-based caching for assembled contexts.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from .context_cache import ContextCache
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"ContextCache",
|
||||||
|
]
|
||||||
434
backend/app/services/context/cache/context_cache.py
vendored
Normal file
434
backend/app/services/context/cache/context_cache.py
vendored
Normal file
@@ -0,0 +1,434 @@
|
|||||||
|
"""
|
||||||
|
Context Cache Implementation.
|
||||||
|
|
||||||
|
Provides Redis-based caching for context operations including
|
||||||
|
assembled contexts, token counts, and scoring results.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import hashlib
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
from typing import TYPE_CHECKING, Any
|
||||||
|
|
||||||
|
from ..config import ContextSettings, get_context_settings
|
||||||
|
from ..exceptions import CacheError
|
||||||
|
from ..types import AssembledContext, BaseContext
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from redis.asyncio import Redis
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class ContextCache:
|
||||||
|
"""
|
||||||
|
Redis-based caching for context operations.
|
||||||
|
|
||||||
|
Provides caching for:
|
||||||
|
- Assembled contexts (fingerprint-based)
|
||||||
|
- Token counts (content hash-based)
|
||||||
|
- Scoring results (context + query hash-based)
|
||||||
|
|
||||||
|
Cache keys use a hierarchical structure:
|
||||||
|
- ctx:assembled:{fingerprint}
|
||||||
|
- ctx:tokens:{model}:{content_hash}
|
||||||
|
- ctx:score:{scorer}:{context_hash}:{query_hash}
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
redis: "Redis | None" = None,
|
||||||
|
settings: ContextSettings | None = None,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Initialize the context cache.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
redis: Redis connection (optional for testing)
|
||||||
|
settings: Cache settings
|
||||||
|
"""
|
||||||
|
self._redis = redis
|
||||||
|
self._settings = settings or get_context_settings()
|
||||||
|
self._prefix = self._settings.cache_prefix
|
||||||
|
self._ttl = self._settings.cache_ttl_seconds
|
||||||
|
|
||||||
|
# In-memory fallback cache when Redis unavailable
|
||||||
|
self._memory_cache: dict[str, tuple[str, float]] = {}
|
||||||
|
self._max_memory_items = self._settings.cache_memory_max_items
|
||||||
|
|
||||||
|
def set_redis(self, redis: "Redis") -> None:
|
||||||
|
"""Set Redis connection."""
|
||||||
|
self._redis = redis
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_enabled(self) -> bool:
|
||||||
|
"""Check if caching is enabled and available."""
|
||||||
|
return self._settings.cache_enabled and self._redis is not None
|
||||||
|
|
||||||
|
def _cache_key(self, *parts: str) -> str:
|
||||||
|
"""
|
||||||
|
Build a cache key from parts.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
*parts: Key components
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Colon-separated cache key
|
||||||
|
"""
|
||||||
|
return f"{self._prefix}:{':'.join(parts)}"
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _hash_content(content: str) -> str:
|
||||||
|
"""
|
||||||
|
Compute hash of content for cache key.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
content: Content to hash
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
32-character hex hash
|
||||||
|
"""
|
||||||
|
return hashlib.sha256(content.encode()).hexdigest()[:32]
|
||||||
|
|
||||||
|
def compute_fingerprint(
|
||||||
|
self,
|
||||||
|
contexts: list[BaseContext],
|
||||||
|
query: str,
|
||||||
|
model: str,
|
||||||
|
project_id: str | None = None,
|
||||||
|
agent_id: str | None = None,
|
||||||
|
) -> str:
|
||||||
|
"""
|
||||||
|
Compute a fingerprint for a context assembly request.
|
||||||
|
|
||||||
|
The fingerprint is based on:
|
||||||
|
- Project and agent IDs (for tenant isolation)
|
||||||
|
- Context content hash and metadata (not full content for performance)
|
||||||
|
- Query string
|
||||||
|
- Target model
|
||||||
|
|
||||||
|
SECURITY: project_id and agent_id MUST be included to prevent
|
||||||
|
cross-tenant cache pollution. Without these, one tenant could
|
||||||
|
receive cached contexts from another tenant with the same query.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
contexts: List of contexts
|
||||||
|
query: Query string
|
||||||
|
model: Model name
|
||||||
|
project_id: Project ID for tenant isolation
|
||||||
|
agent_id: Agent ID for tenant isolation
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
32-character hex fingerprint
|
||||||
|
"""
|
||||||
|
# Build a deterministic representation using content hashes for performance
|
||||||
|
# This avoids JSON serializing potentially large content strings
|
||||||
|
context_data = []
|
||||||
|
for ctx in contexts:
|
||||||
|
context_data.append(
|
||||||
|
{
|
||||||
|
"type": ctx.get_type().value,
|
||||||
|
"content_hash": self._hash_content(
|
||||||
|
ctx.content
|
||||||
|
), # Hash instead of full content
|
||||||
|
"source": ctx.source,
|
||||||
|
"priority": ctx.priority, # Already an int
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
data = {
|
||||||
|
# CRITICAL: Include tenant identifiers for cache isolation
|
||||||
|
"project_id": project_id or "",
|
||||||
|
"agent_id": agent_id or "",
|
||||||
|
"contexts": context_data,
|
||||||
|
"query": query,
|
||||||
|
"model": model,
|
||||||
|
}
|
||||||
|
|
||||||
|
content = json.dumps(data, sort_keys=True)
|
||||||
|
return self._hash_content(content)
|
||||||
|
|
||||||
|
async def get_assembled(
|
||||||
|
self,
|
||||||
|
fingerprint: str,
|
||||||
|
) -> AssembledContext | None:
|
||||||
|
"""
|
||||||
|
Get cached assembled context by fingerprint.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
fingerprint: Assembly fingerprint
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Cached AssembledContext or None if not found
|
||||||
|
"""
|
||||||
|
if not self.is_enabled:
|
||||||
|
return None
|
||||||
|
|
||||||
|
key = self._cache_key("assembled", fingerprint)
|
||||||
|
|
||||||
|
try:
|
||||||
|
data = await self._redis.get(key) # type: ignore
|
||||||
|
if data:
|
||||||
|
logger.debug(f"Cache hit for assembled context: {fingerprint}")
|
||||||
|
result = AssembledContext.from_json(data)
|
||||||
|
result.cache_hit = True
|
||||||
|
result.cache_key = fingerprint
|
||||||
|
return result
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Cache get error: {e}")
|
||||||
|
raise CacheError(f"Failed to get assembled context: {e}") from e
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def set_assembled(
|
||||||
|
self,
|
||||||
|
fingerprint: str,
|
||||||
|
context: AssembledContext,
|
||||||
|
ttl: int | None = None,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Cache an assembled context.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
fingerprint: Assembly fingerprint
|
||||||
|
context: Assembled context to cache
|
||||||
|
ttl: Optional TTL override in seconds
|
||||||
|
"""
|
||||||
|
if not self.is_enabled:
|
||||||
|
return
|
||||||
|
|
||||||
|
key = self._cache_key("assembled", fingerprint)
|
||||||
|
expire = ttl or self._ttl
|
||||||
|
|
||||||
|
try:
|
||||||
|
await self._redis.setex(key, expire, context.to_json()) # type: ignore
|
||||||
|
logger.debug(f"Cached assembled context: {fingerprint}")
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Cache set error: {e}")
|
||||||
|
raise CacheError(f"Failed to cache assembled context: {e}") from e
|
||||||
|
|
||||||
|
async def get_token_count(
|
||||||
|
self,
|
||||||
|
content: str,
|
||||||
|
model: str | None = None,
|
||||||
|
) -> int | None:
|
||||||
|
"""
|
||||||
|
Get cached token count.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
content: Content to look up
|
||||||
|
model: Model name for model-specific tokenization
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Cached token count or None if not found
|
||||||
|
"""
|
||||||
|
model_key = model or "default"
|
||||||
|
content_hash = self._hash_content(content)
|
||||||
|
key = self._cache_key("tokens", model_key, content_hash)
|
||||||
|
|
||||||
|
# Try in-memory first
|
||||||
|
if key in self._memory_cache:
|
||||||
|
return int(self._memory_cache[key][0])
|
||||||
|
|
||||||
|
if not self.is_enabled:
|
||||||
|
return None
|
||||||
|
|
||||||
|
try:
|
||||||
|
data = await self._redis.get(key) # type: ignore
|
||||||
|
if data:
|
||||||
|
count = int(data)
|
||||||
|
# Store in memory for faster subsequent access
|
||||||
|
self._set_memory(key, str(count))
|
||||||
|
return count
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Cache get error for tokens: {e}")
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def set_token_count(
|
||||||
|
self,
|
||||||
|
content: str,
|
||||||
|
count: int,
|
||||||
|
model: str | None = None,
|
||||||
|
ttl: int | None = None,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Cache a token count.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
content: Content that was counted
|
||||||
|
count: Token count
|
||||||
|
model: Model name
|
||||||
|
ttl: Optional TTL override in seconds
|
||||||
|
"""
|
||||||
|
model_key = model or "default"
|
||||||
|
content_hash = self._hash_content(content)
|
||||||
|
key = self._cache_key("tokens", model_key, content_hash)
|
||||||
|
expire = ttl or self._ttl
|
||||||
|
|
||||||
|
# Always store in memory
|
||||||
|
self._set_memory(key, str(count))
|
||||||
|
|
||||||
|
if not self.is_enabled:
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
await self._redis.setex(key, expire, str(count)) # type: ignore
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Cache set error for tokens: {e}")
|
||||||
|
|
||||||
|
async def get_score(
|
||||||
|
self,
|
||||||
|
scorer_name: str,
|
||||||
|
context_id: str,
|
||||||
|
query: str,
|
||||||
|
) -> float | None:
|
||||||
|
"""
|
||||||
|
Get cached score.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
scorer_name: Name of the scorer
|
||||||
|
context_id: Context identifier
|
||||||
|
query: Query string
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Cached score or None if not found
|
||||||
|
"""
|
||||||
|
query_hash = self._hash_content(query)[:16]
|
||||||
|
key = self._cache_key("score", scorer_name, context_id, query_hash)
|
||||||
|
|
||||||
|
# Try in-memory first
|
||||||
|
if key in self._memory_cache:
|
||||||
|
return float(self._memory_cache[key][0])
|
||||||
|
|
||||||
|
if not self.is_enabled:
|
||||||
|
return None
|
||||||
|
|
||||||
|
try:
|
||||||
|
data = await self._redis.get(key) # type: ignore
|
||||||
|
if data:
|
||||||
|
score = float(data)
|
||||||
|
self._set_memory(key, str(score))
|
||||||
|
return score
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Cache get error for score: {e}")
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def set_score(
|
||||||
|
self,
|
||||||
|
scorer_name: str,
|
||||||
|
context_id: str,
|
||||||
|
query: str,
|
||||||
|
score: float,
|
||||||
|
ttl: int | None = None,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Cache a score.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
scorer_name: Name of the scorer
|
||||||
|
context_id: Context identifier
|
||||||
|
query: Query string
|
||||||
|
score: Score value
|
||||||
|
ttl: Optional TTL override in seconds
|
||||||
|
"""
|
||||||
|
query_hash = self._hash_content(query)[:16]
|
||||||
|
key = self._cache_key("score", scorer_name, context_id, query_hash)
|
||||||
|
expire = ttl or self._ttl
|
||||||
|
|
||||||
|
# Always store in memory
|
||||||
|
self._set_memory(key, str(score))
|
||||||
|
|
||||||
|
if not self.is_enabled:
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
await self._redis.setex(key, expire, str(score)) # type: ignore
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Cache set error for score: {e}")
|
||||||
|
|
||||||
|
async def invalidate(self, pattern: str) -> int:
|
||||||
|
"""
|
||||||
|
Invalidate cache entries matching a pattern.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
pattern: Key pattern (supports * wildcard)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Number of keys deleted
|
||||||
|
"""
|
||||||
|
if not self.is_enabled:
|
||||||
|
return 0
|
||||||
|
|
||||||
|
full_pattern = self._cache_key(pattern)
|
||||||
|
deleted = 0
|
||||||
|
|
||||||
|
try:
|
||||||
|
async for key in self._redis.scan_iter(match=full_pattern): # type: ignore
|
||||||
|
await self._redis.delete(key) # type: ignore
|
||||||
|
deleted += 1
|
||||||
|
|
||||||
|
logger.info(f"Invalidated {deleted} cache entries matching {pattern}")
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Cache invalidation error: {e}")
|
||||||
|
raise CacheError(f"Failed to invalidate cache: {e}") from e
|
||||||
|
|
||||||
|
return deleted
|
||||||
|
|
||||||
|
async def clear_all(self) -> int:
|
||||||
|
"""
|
||||||
|
Clear all context cache entries.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Number of keys deleted
|
||||||
|
"""
|
||||||
|
self._memory_cache.clear()
|
||||||
|
return await self.invalidate("*")
|
||||||
|
|
||||||
|
def _set_memory(self, key: str, value: str) -> None:
|
||||||
|
"""
|
||||||
|
Set a value in the memory cache.
|
||||||
|
|
||||||
|
Uses LRU-style eviction when max items reached.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
key: Cache key
|
||||||
|
value: Value to store
|
||||||
|
"""
|
||||||
|
import time
|
||||||
|
|
||||||
|
if len(self._memory_cache) >= self._max_memory_items:
|
||||||
|
# Evict oldest entries
|
||||||
|
sorted_keys = sorted(
|
||||||
|
self._memory_cache.keys(),
|
||||||
|
key=lambda k: self._memory_cache[k][1],
|
||||||
|
)
|
||||||
|
for k in sorted_keys[: len(sorted_keys) // 2]:
|
||||||
|
del self._memory_cache[k]
|
||||||
|
|
||||||
|
self._memory_cache[key] = (value, time.time())
|
||||||
|
|
||||||
|
async def get_stats(self) -> dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Get cache statistics.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary with cache stats
|
||||||
|
"""
|
||||||
|
stats = {
|
||||||
|
"enabled": self._settings.cache_enabled,
|
||||||
|
"redis_available": self._redis is not None,
|
||||||
|
"memory_items": len(self._memory_cache),
|
||||||
|
"ttl_seconds": self._ttl,
|
||||||
|
}
|
||||||
|
|
||||||
|
if self.is_enabled:
|
||||||
|
try:
|
||||||
|
# Get Redis info
|
||||||
|
info = await self._redis.info("memory") # type: ignore
|
||||||
|
stats["redis_memory_used"] = info.get("used_memory_human", "unknown")
|
||||||
|
except Exception as e:
|
||||||
|
logger.debug(f"Failed to get Redis stats: {e}")
|
||||||
|
|
||||||
|
return stats
|
||||||
13
backend/app/services/context/compression/__init__.py
Normal file
13
backend/app/services/context/compression/__init__.py
Normal file
@@ -0,0 +1,13 @@
|
|||||||
|
"""
|
||||||
|
Context Compression Module.
|
||||||
|
|
||||||
|
Provides truncation and compression strategies.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from .truncation import ContextCompressor, TruncationResult, TruncationStrategy
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"ContextCompressor",
|
||||||
|
"TruncationResult",
|
||||||
|
"TruncationStrategy",
|
||||||
|
]
|
||||||
453
backend/app/services/context/compression/truncation.py
Normal file
453
backend/app/services/context/compression/truncation.py
Normal file
@@ -0,0 +1,453 @@
|
|||||||
|
"""
|
||||||
|
Smart Truncation for Context Compression.
|
||||||
|
|
||||||
|
Provides intelligent truncation strategies to reduce context size
|
||||||
|
while preserving the most important information.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import re
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
from ..config import ContextSettings, get_context_settings
|
||||||
|
from ..types import BaseContext, ContextType
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from ..budget import TokenBudget, TokenCalculator
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def _estimate_tokens(text: str, model: str | None = None) -> int:
|
||||||
|
"""
|
||||||
|
Estimate token count using model-specific character ratios.
|
||||||
|
|
||||||
|
Module-level function for reuse across classes. Uses the same ratios
|
||||||
|
as TokenCalculator for consistency.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
text: Text to estimate tokens for
|
||||||
|
model: Optional model name for model-specific ratios
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Estimated token count (minimum 1)
|
||||||
|
"""
|
||||||
|
# Model-specific character ratios (chars per token)
|
||||||
|
model_ratios = {
|
||||||
|
"claude": 3.5,
|
||||||
|
"gpt-4": 4.0,
|
||||||
|
"gpt-3.5": 4.0,
|
||||||
|
"gemini": 4.0,
|
||||||
|
}
|
||||||
|
default_ratio = 4.0
|
||||||
|
|
||||||
|
ratio = default_ratio
|
||||||
|
if model:
|
||||||
|
model_lower = model.lower()
|
||||||
|
for model_prefix, model_ratio in model_ratios.items():
|
||||||
|
if model_prefix in model_lower:
|
||||||
|
ratio = model_ratio
|
||||||
|
break
|
||||||
|
|
||||||
|
return max(1, int(len(text) / ratio))
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class TruncationResult:
|
||||||
|
"""Result of truncation operation."""
|
||||||
|
|
||||||
|
original_tokens: int
|
||||||
|
truncated_tokens: int
|
||||||
|
content: str
|
||||||
|
truncated: bool
|
||||||
|
truncation_ratio: float # 0.0 = no truncation, 1.0 = completely removed
|
||||||
|
|
||||||
|
@property
|
||||||
|
def tokens_saved(self) -> int:
|
||||||
|
"""Calculate tokens saved by truncation."""
|
||||||
|
return self.original_tokens - self.truncated_tokens
|
||||||
|
|
||||||
|
|
||||||
|
class TruncationStrategy:
|
||||||
|
"""
|
||||||
|
Smart truncation strategies for context compression.
|
||||||
|
|
||||||
|
Strategies:
|
||||||
|
1. End truncation: Cut from end (for knowledge/docs)
|
||||||
|
2. Middle truncation: Keep start and end (for code)
|
||||||
|
3. Sentence-aware: Truncate at sentence boundaries
|
||||||
|
4. Semantic chunking: Keep most relevant chunks
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
calculator: "TokenCalculator | None" = None,
|
||||||
|
preserve_ratio_start: float | None = None,
|
||||||
|
min_content_length: int | None = None,
|
||||||
|
settings: ContextSettings | None = None,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Initialize truncation strategy.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
calculator: Token calculator for accurate counting
|
||||||
|
preserve_ratio_start: Ratio of content to keep from start (overrides settings)
|
||||||
|
min_content_length: Minimum characters to preserve (overrides settings)
|
||||||
|
settings: Context settings (uses global if None)
|
||||||
|
"""
|
||||||
|
self._settings = settings or get_context_settings()
|
||||||
|
self._calculator = calculator
|
||||||
|
|
||||||
|
# Use provided values or fall back to settings
|
||||||
|
self._preserve_ratio_start = (
|
||||||
|
preserve_ratio_start
|
||||||
|
if preserve_ratio_start is not None
|
||||||
|
else self._settings.truncation_preserve_ratio
|
||||||
|
)
|
||||||
|
self._min_content_length = (
|
||||||
|
min_content_length
|
||||||
|
if min_content_length is not None
|
||||||
|
else self._settings.truncation_min_content_length
|
||||||
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def truncation_marker(self) -> str:
|
||||||
|
"""Get truncation marker from settings."""
|
||||||
|
return self._settings.truncation_marker
|
||||||
|
|
||||||
|
def set_calculator(self, calculator: "TokenCalculator") -> None:
|
||||||
|
"""Set token calculator."""
|
||||||
|
self._calculator = calculator
|
||||||
|
|
||||||
|
async def truncate_to_tokens(
|
||||||
|
self,
|
||||||
|
content: str,
|
||||||
|
max_tokens: int,
|
||||||
|
strategy: str = "end",
|
||||||
|
model: str | None = None,
|
||||||
|
) -> TruncationResult:
|
||||||
|
"""
|
||||||
|
Truncate content to fit within token limit.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
content: Content to truncate
|
||||||
|
max_tokens: Maximum tokens allowed
|
||||||
|
strategy: Truncation strategy ('end', 'middle', 'sentence')
|
||||||
|
model: Model for token counting
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
TruncationResult with truncated content
|
||||||
|
"""
|
||||||
|
if not content:
|
||||||
|
return TruncationResult(
|
||||||
|
original_tokens=0,
|
||||||
|
truncated_tokens=0,
|
||||||
|
content="",
|
||||||
|
truncated=False,
|
||||||
|
truncation_ratio=0.0,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Get original token count
|
||||||
|
original_tokens = await self._count_tokens(content, model)
|
||||||
|
|
||||||
|
if original_tokens <= max_tokens:
|
||||||
|
return TruncationResult(
|
||||||
|
original_tokens=original_tokens,
|
||||||
|
truncated_tokens=original_tokens,
|
||||||
|
content=content,
|
||||||
|
truncated=False,
|
||||||
|
truncation_ratio=0.0,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Apply truncation strategy
|
||||||
|
if strategy == "middle":
|
||||||
|
truncated = await self._truncate_middle(content, max_tokens, model)
|
||||||
|
elif strategy == "sentence":
|
||||||
|
truncated = await self._truncate_sentence(content, max_tokens, model)
|
||||||
|
else: # "end"
|
||||||
|
truncated = await self._truncate_end(content, max_tokens, model)
|
||||||
|
|
||||||
|
truncated_tokens = await self._count_tokens(truncated, model)
|
||||||
|
|
||||||
|
return TruncationResult(
|
||||||
|
original_tokens=original_tokens,
|
||||||
|
truncated_tokens=truncated_tokens,
|
||||||
|
content=truncated,
|
||||||
|
truncated=True,
|
||||||
|
truncation_ratio=0.0
|
||||||
|
if original_tokens == 0
|
||||||
|
else 1 - (truncated_tokens / original_tokens),
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _truncate_end(
|
||||||
|
self,
|
||||||
|
content: str,
|
||||||
|
max_tokens: int,
|
||||||
|
model: str | None = None,
|
||||||
|
) -> str:
|
||||||
|
"""
|
||||||
|
Truncate from end of content.
|
||||||
|
|
||||||
|
Simple but effective for most content types.
|
||||||
|
"""
|
||||||
|
# Binary search for optimal truncation point
|
||||||
|
marker_tokens = await self._count_tokens(self.truncation_marker, model)
|
||||||
|
available_tokens = max(0, max_tokens - marker_tokens)
|
||||||
|
|
||||||
|
# Edge case: if no tokens available for content, return just the marker
|
||||||
|
if available_tokens <= 0:
|
||||||
|
return self.truncation_marker
|
||||||
|
|
||||||
|
# Estimate characters per token (guard against division by zero)
|
||||||
|
content_tokens = await self._count_tokens(content, model)
|
||||||
|
if content_tokens == 0:
|
||||||
|
return content + self.truncation_marker
|
||||||
|
chars_per_token = len(content) / content_tokens
|
||||||
|
|
||||||
|
# Start with estimated position
|
||||||
|
estimated_chars = int(available_tokens * chars_per_token)
|
||||||
|
truncated = content[:estimated_chars]
|
||||||
|
|
||||||
|
# Refine with binary search
|
||||||
|
low, high = len(truncated) // 2, len(truncated)
|
||||||
|
best = truncated
|
||||||
|
|
||||||
|
for _ in range(5): # Max 5 iterations
|
||||||
|
mid = (low + high) // 2
|
||||||
|
candidate = content[:mid]
|
||||||
|
tokens = await self._count_tokens(candidate, model)
|
||||||
|
|
||||||
|
if tokens <= available_tokens:
|
||||||
|
best = candidate
|
||||||
|
low = mid + 1
|
||||||
|
else:
|
||||||
|
high = mid - 1
|
||||||
|
|
||||||
|
return best + self.truncation_marker
|
||||||
|
|
||||||
|
async def _truncate_middle(
|
||||||
|
self,
|
||||||
|
content: str,
|
||||||
|
max_tokens: int,
|
||||||
|
model: str | None = None,
|
||||||
|
) -> str:
|
||||||
|
"""
|
||||||
|
Truncate from middle, keeping start and end.
|
||||||
|
|
||||||
|
Good for code or content where context at boundaries matters.
|
||||||
|
"""
|
||||||
|
marker_tokens = await self._count_tokens(self.truncation_marker, model)
|
||||||
|
available_tokens = max_tokens - marker_tokens
|
||||||
|
|
||||||
|
# Split between start and end
|
||||||
|
start_tokens = int(available_tokens * self._preserve_ratio_start)
|
||||||
|
end_tokens = available_tokens - start_tokens
|
||||||
|
|
||||||
|
# Get start portion
|
||||||
|
start_content = await self._get_content_for_tokens(
|
||||||
|
content, start_tokens, from_start=True, model=model
|
||||||
|
)
|
||||||
|
|
||||||
|
# Get end portion
|
||||||
|
end_content = await self._get_content_for_tokens(
|
||||||
|
content, end_tokens, from_start=False, model=model
|
||||||
|
)
|
||||||
|
|
||||||
|
return start_content + self.truncation_marker + end_content
|
||||||
|
|
||||||
|
async def _truncate_sentence(
|
||||||
|
self,
|
||||||
|
content: str,
|
||||||
|
max_tokens: int,
|
||||||
|
model: str | None = None,
|
||||||
|
) -> str:
|
||||||
|
"""
|
||||||
|
Truncate at sentence boundaries.
|
||||||
|
|
||||||
|
Produces cleaner output by not cutting mid-sentence.
|
||||||
|
"""
|
||||||
|
# Split into sentences
|
||||||
|
sentences = re.split(r"(?<=[.!?])\s+", content)
|
||||||
|
|
||||||
|
result: list[str] = []
|
||||||
|
total_tokens = 0
|
||||||
|
marker_tokens = await self._count_tokens(self.truncation_marker, model)
|
||||||
|
available = max_tokens - marker_tokens
|
||||||
|
|
||||||
|
for sentence in sentences:
|
||||||
|
sentence_tokens = await self._count_tokens(sentence, model)
|
||||||
|
if total_tokens + sentence_tokens <= available:
|
||||||
|
result.append(sentence)
|
||||||
|
total_tokens += sentence_tokens
|
||||||
|
else:
|
||||||
|
break
|
||||||
|
|
||||||
|
if len(result) < len(sentences):
|
||||||
|
return " ".join(result) + self.truncation_marker
|
||||||
|
return " ".join(result)
|
||||||
|
|
||||||
|
async def _get_content_for_tokens(
|
||||||
|
self,
|
||||||
|
content: str,
|
||||||
|
target_tokens: int,
|
||||||
|
from_start: bool = True,
|
||||||
|
model: str | None = None,
|
||||||
|
) -> str:
|
||||||
|
"""Get portion of content fitting within token limit."""
|
||||||
|
if target_tokens <= 0:
|
||||||
|
return ""
|
||||||
|
|
||||||
|
current_tokens = await self._count_tokens(content, model)
|
||||||
|
if current_tokens <= target_tokens:
|
||||||
|
return content
|
||||||
|
|
||||||
|
# Estimate characters (guard against division by zero)
|
||||||
|
if current_tokens == 0:
|
||||||
|
return content
|
||||||
|
chars_per_token = len(content) / current_tokens
|
||||||
|
estimated_chars = int(target_tokens * chars_per_token)
|
||||||
|
|
||||||
|
if from_start:
|
||||||
|
return content[:estimated_chars]
|
||||||
|
else:
|
||||||
|
return content[-estimated_chars:]
|
||||||
|
|
||||||
|
async def _count_tokens(self, text: str, model: str | None = None) -> int:
|
||||||
|
"""Count tokens using calculator or estimation."""
|
||||||
|
if self._calculator is not None:
|
||||||
|
return await self._calculator.count_tokens(text, model)
|
||||||
|
|
||||||
|
# Fallback estimation with model-specific ratios
|
||||||
|
return _estimate_tokens(text, model)
|
||||||
|
|
||||||
|
|
||||||
|
class ContextCompressor:
|
||||||
|
"""
|
||||||
|
Compresses contexts to fit within budget constraints.
|
||||||
|
|
||||||
|
Uses truncation strategies to reduce context size while
|
||||||
|
preserving the most important information.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
truncation: TruncationStrategy | None = None,
|
||||||
|
calculator: "TokenCalculator | None" = None,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Initialize context compressor.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
truncation: Truncation strategy to use
|
||||||
|
calculator: Token calculator for counting
|
||||||
|
"""
|
||||||
|
self._truncation = truncation or TruncationStrategy(calculator)
|
||||||
|
self._calculator = calculator
|
||||||
|
|
||||||
|
if calculator:
|
||||||
|
self._truncation.set_calculator(calculator)
|
||||||
|
|
||||||
|
def set_calculator(self, calculator: "TokenCalculator") -> None:
|
||||||
|
"""Set token calculator."""
|
||||||
|
self._calculator = calculator
|
||||||
|
self._truncation.set_calculator(calculator)
|
||||||
|
|
||||||
|
async def compress_context(
|
||||||
|
self,
|
||||||
|
context: BaseContext,
|
||||||
|
max_tokens: int,
|
||||||
|
model: str | None = None,
|
||||||
|
) -> BaseContext:
|
||||||
|
"""
|
||||||
|
Compress a single context to fit token limit.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
context: Context to compress
|
||||||
|
max_tokens: Maximum tokens allowed
|
||||||
|
model: Model for token counting
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Compressed context (may be same object if no compression needed)
|
||||||
|
"""
|
||||||
|
current_tokens = context.token_count or await self._count_tokens(
|
||||||
|
context.content, model
|
||||||
|
)
|
||||||
|
|
||||||
|
if current_tokens <= max_tokens:
|
||||||
|
return context
|
||||||
|
|
||||||
|
# Choose strategy based on context type
|
||||||
|
strategy = self._get_strategy_for_type(context.get_type())
|
||||||
|
|
||||||
|
result = await self._truncation.truncate_to_tokens(
|
||||||
|
content=context.content,
|
||||||
|
max_tokens=max_tokens,
|
||||||
|
strategy=strategy,
|
||||||
|
model=model,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Update context with truncated content
|
||||||
|
context.content = result.content
|
||||||
|
context.token_count = result.truncated_tokens
|
||||||
|
context.metadata["truncated"] = True
|
||||||
|
context.metadata["original_tokens"] = result.original_tokens
|
||||||
|
|
||||||
|
return context
|
||||||
|
|
||||||
|
async def compress_contexts(
|
||||||
|
self,
|
||||||
|
contexts: list[BaseContext],
|
||||||
|
budget: "TokenBudget",
|
||||||
|
model: str | None = None,
|
||||||
|
) -> list[BaseContext]:
|
||||||
|
"""
|
||||||
|
Compress multiple contexts to fit within budget.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
contexts: Contexts to potentially compress
|
||||||
|
budget: Token budget constraints
|
||||||
|
model: Model for token counting
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of contexts (compressed as needed)
|
||||||
|
"""
|
||||||
|
result: list[BaseContext] = []
|
||||||
|
|
||||||
|
for context in contexts:
|
||||||
|
context_type = context.get_type()
|
||||||
|
remaining = budget.remaining(context_type)
|
||||||
|
current_tokens = context.token_count or await self._count_tokens(
|
||||||
|
context.content, model
|
||||||
|
)
|
||||||
|
|
||||||
|
if current_tokens > remaining:
|
||||||
|
# Need to compress
|
||||||
|
compressed = await self.compress_context(context, remaining, model)
|
||||||
|
result.append(compressed)
|
||||||
|
logger.debug(
|
||||||
|
f"Compressed {context_type.value} context from "
|
||||||
|
f"{current_tokens} to {compressed.token_count} tokens"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
result.append(context)
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
def _get_strategy_for_type(self, context_type: ContextType) -> str:
|
||||||
|
"""Get optimal truncation strategy for context type."""
|
||||||
|
strategies = {
|
||||||
|
ContextType.SYSTEM: "end", # Keep instructions at start
|
||||||
|
ContextType.TASK: "end", # Keep task description start
|
||||||
|
ContextType.KNOWLEDGE: "sentence", # Clean sentence boundaries
|
||||||
|
ContextType.CONVERSATION: "end", # Keep recent conversation
|
||||||
|
ContextType.TOOL: "middle", # Keep command and result summary
|
||||||
|
}
|
||||||
|
return strategies.get(context_type, "end")
|
||||||
|
|
||||||
|
async def _count_tokens(self, text: str, model: str | None = None) -> int:
|
||||||
|
"""Count tokens using calculator or estimation."""
|
||||||
|
if self._calculator is not None:
|
||||||
|
return await self._calculator.count_tokens(text, model)
|
||||||
|
# Use model-specific estimation for consistency
|
||||||
|
return _estimate_tokens(text, model)
|
||||||
380
backend/app/services/context/config.py
Normal file
380
backend/app/services/context/config.py
Normal file
@@ -0,0 +1,380 @@
|
|||||||
|
"""
|
||||||
|
Context Management Engine Configuration.
|
||||||
|
|
||||||
|
Provides Pydantic settings for context assembly,
|
||||||
|
token budget allocation, and caching.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import threading
|
||||||
|
from functools import lru_cache
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from pydantic import Field, field_validator, model_validator
|
||||||
|
from pydantic_settings import BaseSettings
|
||||||
|
|
||||||
|
|
||||||
|
class ContextSettings(BaseSettings):
|
||||||
|
"""
|
||||||
|
Configuration for the Context Management Engine.
|
||||||
|
|
||||||
|
All settings can be overridden via environment variables
|
||||||
|
with the CTX_ prefix.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Budget allocation percentages (must sum to 1.0)
|
||||||
|
budget_system: float = Field(
|
||||||
|
default=0.05,
|
||||||
|
ge=0.0,
|
||||||
|
le=1.0,
|
||||||
|
description="Percentage of budget for system prompts (5%)",
|
||||||
|
)
|
||||||
|
budget_task: float = Field(
|
||||||
|
default=0.10,
|
||||||
|
ge=0.0,
|
||||||
|
le=1.0,
|
||||||
|
description="Percentage of budget for task context (10%)",
|
||||||
|
)
|
||||||
|
budget_knowledge: float = Field(
|
||||||
|
default=0.40,
|
||||||
|
ge=0.0,
|
||||||
|
le=1.0,
|
||||||
|
description="Percentage of budget for RAG/knowledge (40%)",
|
||||||
|
)
|
||||||
|
budget_conversation: float = Field(
|
||||||
|
default=0.20,
|
||||||
|
ge=0.0,
|
||||||
|
le=1.0,
|
||||||
|
description="Percentage of budget for conversation history (20%)",
|
||||||
|
)
|
||||||
|
budget_tools: float = Field(
|
||||||
|
default=0.05,
|
||||||
|
ge=0.0,
|
||||||
|
le=1.0,
|
||||||
|
description="Percentage of budget for tool descriptions (5%)",
|
||||||
|
)
|
||||||
|
budget_response: float = Field(
|
||||||
|
default=0.15,
|
||||||
|
ge=0.0,
|
||||||
|
le=1.0,
|
||||||
|
description="Percentage reserved for response (15%)",
|
||||||
|
)
|
||||||
|
budget_buffer: float = Field(
|
||||||
|
default=0.05,
|
||||||
|
ge=0.0,
|
||||||
|
le=1.0,
|
||||||
|
description="Percentage buffer for safety margin (5%)",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Scoring weights
|
||||||
|
scoring_relevance_weight: float = Field(
|
||||||
|
default=0.5,
|
||||||
|
ge=0.0,
|
||||||
|
le=1.0,
|
||||||
|
description="Weight for relevance scoring",
|
||||||
|
)
|
||||||
|
scoring_recency_weight: float = Field(
|
||||||
|
default=0.3,
|
||||||
|
ge=0.0,
|
||||||
|
le=1.0,
|
||||||
|
description="Weight for recency scoring",
|
||||||
|
)
|
||||||
|
scoring_priority_weight: float = Field(
|
||||||
|
default=0.2,
|
||||||
|
ge=0.0,
|
||||||
|
le=1.0,
|
||||||
|
description="Weight for priority scoring",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Recency decay settings
|
||||||
|
recency_decay_hours: float = Field(
|
||||||
|
default=24.0,
|
||||||
|
gt=0.0,
|
||||||
|
description="Hours until recency score decays to 50%",
|
||||||
|
)
|
||||||
|
recency_max_age_hours: float = Field(
|
||||||
|
default=168.0,
|
||||||
|
gt=0.0,
|
||||||
|
description="Hours until context is considered stale (7 days)",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Compression settings
|
||||||
|
compression_threshold: float = Field(
|
||||||
|
default=0.8,
|
||||||
|
ge=0.0,
|
||||||
|
le=1.0,
|
||||||
|
description="Compress when budget usage exceeds this percentage",
|
||||||
|
)
|
||||||
|
truncation_marker: str = Field(
|
||||||
|
default="\n\n[...content truncated...]\n\n",
|
||||||
|
description="Marker text to insert where content was truncated",
|
||||||
|
)
|
||||||
|
truncation_preserve_ratio: float = Field(
|
||||||
|
default=0.7,
|
||||||
|
ge=0.1,
|
||||||
|
le=0.9,
|
||||||
|
description="Ratio of content to preserve from start in middle truncation (0.7 = 70% start, 30% end)",
|
||||||
|
)
|
||||||
|
truncation_min_content_length: int = Field(
|
||||||
|
default=100,
|
||||||
|
ge=10,
|
||||||
|
le=1000,
|
||||||
|
description="Minimum content length in characters before truncation applies",
|
||||||
|
)
|
||||||
|
summary_model_group: str = Field(
|
||||||
|
default="fast",
|
||||||
|
description="Model group to use for summarization",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Caching settings
|
||||||
|
cache_enabled: bool = Field(
|
||||||
|
default=True,
|
||||||
|
description="Enable Redis caching for assembled contexts",
|
||||||
|
)
|
||||||
|
cache_ttl_seconds: int = Field(
|
||||||
|
default=3600,
|
||||||
|
ge=60,
|
||||||
|
le=86400,
|
||||||
|
description="Cache TTL in seconds (1 hour default, max 24 hours)",
|
||||||
|
)
|
||||||
|
cache_prefix: str = Field(
|
||||||
|
default="ctx",
|
||||||
|
description="Redis key prefix for context cache",
|
||||||
|
)
|
||||||
|
cache_memory_max_items: int = Field(
|
||||||
|
default=1000,
|
||||||
|
ge=100,
|
||||||
|
le=100000,
|
||||||
|
description="Maximum items in memory fallback cache when Redis unavailable",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Performance settings
|
||||||
|
max_assembly_time_ms: int = Field(
|
||||||
|
default=2000,
|
||||||
|
ge=10,
|
||||||
|
le=30000,
|
||||||
|
description="Maximum time for context assembly in milliseconds. "
|
||||||
|
"Should be high enough to accommodate MCP calls for knowledge retrieval.",
|
||||||
|
)
|
||||||
|
parallel_scoring: bool = Field(
|
||||||
|
default=True,
|
||||||
|
description="Score contexts in parallel for better performance",
|
||||||
|
)
|
||||||
|
max_parallel_scores: int = Field(
|
||||||
|
default=10,
|
||||||
|
ge=1,
|
||||||
|
le=50,
|
||||||
|
description="Maximum number of contexts to score in parallel",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Knowledge retrieval settings
|
||||||
|
knowledge_search_type: str = Field(
|
||||||
|
default="hybrid",
|
||||||
|
description="Default search type for knowledge retrieval",
|
||||||
|
)
|
||||||
|
knowledge_max_results: int = Field(
|
||||||
|
default=10,
|
||||||
|
ge=1,
|
||||||
|
le=50,
|
||||||
|
description="Maximum knowledge chunks to retrieve",
|
||||||
|
)
|
||||||
|
knowledge_min_score: float = Field(
|
||||||
|
default=0.5,
|
||||||
|
ge=0.0,
|
||||||
|
le=1.0,
|
||||||
|
description="Minimum relevance score for knowledge",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Relevance scoring settings
|
||||||
|
relevance_keyword_fallback_weight: float = Field(
|
||||||
|
default=0.5,
|
||||||
|
ge=0.0,
|
||||||
|
le=1.0,
|
||||||
|
description="Maximum score for keyword-based fallback scoring (when semantic unavailable)",
|
||||||
|
)
|
||||||
|
relevance_semantic_max_chars: int = Field(
|
||||||
|
default=2000,
|
||||||
|
ge=100,
|
||||||
|
le=10000,
|
||||||
|
description="Maximum content length in chars for semantic similarity computation",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Diversity/ranking settings
|
||||||
|
diversity_max_per_source: int = Field(
|
||||||
|
default=3,
|
||||||
|
ge=1,
|
||||||
|
le=20,
|
||||||
|
description="Maximum contexts from the same source in diversity reranking",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Conversation history settings
|
||||||
|
conversation_max_turns: int = Field(
|
||||||
|
default=20,
|
||||||
|
ge=1,
|
||||||
|
le=100,
|
||||||
|
description="Maximum conversation turns to include",
|
||||||
|
)
|
||||||
|
conversation_recent_priority: bool = Field(
|
||||||
|
default=True,
|
||||||
|
description="Prioritize recent conversation turns",
|
||||||
|
)
|
||||||
|
|
||||||
|
@field_validator("knowledge_search_type")
|
||||||
|
@classmethod
|
||||||
|
def validate_search_type(cls, v: str) -> str:
|
||||||
|
"""Validate search type is valid."""
|
||||||
|
valid_types = {"semantic", "keyword", "hybrid"}
|
||||||
|
if v not in valid_types:
|
||||||
|
raise ValueError(f"search_type must be one of: {valid_types}")
|
||||||
|
return v
|
||||||
|
|
||||||
|
@model_validator(mode="after")
|
||||||
|
def validate_budget_allocation(self) -> "ContextSettings":
|
||||||
|
"""Validate that budget percentages sum to 1.0."""
|
||||||
|
total = (
|
||||||
|
self.budget_system
|
||||||
|
+ self.budget_task
|
||||||
|
+ self.budget_knowledge
|
||||||
|
+ self.budget_conversation
|
||||||
|
+ self.budget_tools
|
||||||
|
+ self.budget_response
|
||||||
|
+ self.budget_buffer
|
||||||
|
)
|
||||||
|
# Allow small floating point error
|
||||||
|
if abs(total - 1.0) > 0.001:
|
||||||
|
raise ValueError(
|
||||||
|
f"Budget percentages must sum to 1.0, got {total:.3f}. "
|
||||||
|
f"Current allocation: system={self.budget_system}, task={self.budget_task}, "
|
||||||
|
f"knowledge={self.budget_knowledge}, conversation={self.budget_conversation}, "
|
||||||
|
f"tools={self.budget_tools}, response={self.budget_response}, buffer={self.budget_buffer}"
|
||||||
|
)
|
||||||
|
return self
|
||||||
|
|
||||||
|
@model_validator(mode="after")
|
||||||
|
def validate_scoring_weights(self) -> "ContextSettings":
|
||||||
|
"""Validate that scoring weights sum to 1.0."""
|
||||||
|
total = (
|
||||||
|
self.scoring_relevance_weight
|
||||||
|
+ self.scoring_recency_weight
|
||||||
|
+ self.scoring_priority_weight
|
||||||
|
)
|
||||||
|
# Allow small floating point error
|
||||||
|
if abs(total - 1.0) > 0.001:
|
||||||
|
raise ValueError(
|
||||||
|
f"Scoring weights must sum to 1.0, got {total:.3f}. "
|
||||||
|
f"Current weights: relevance={self.scoring_relevance_weight}, "
|
||||||
|
f"recency={self.scoring_recency_weight}, priority={self.scoring_priority_weight}"
|
||||||
|
)
|
||||||
|
return self
|
||||||
|
|
||||||
|
def get_budget_allocation(self) -> dict[str, float]:
|
||||||
|
"""Get budget allocation as a dictionary."""
|
||||||
|
return {
|
||||||
|
"system": self.budget_system,
|
||||||
|
"task": self.budget_task,
|
||||||
|
"knowledge": self.budget_knowledge,
|
||||||
|
"conversation": self.budget_conversation,
|
||||||
|
"tools": self.budget_tools,
|
||||||
|
"response": self.budget_response,
|
||||||
|
"buffer": self.budget_buffer,
|
||||||
|
}
|
||||||
|
|
||||||
|
def get_scoring_weights(self) -> dict[str, float]:
|
||||||
|
"""Get scoring weights as a dictionary."""
|
||||||
|
return {
|
||||||
|
"relevance": self.scoring_relevance_weight,
|
||||||
|
"recency": self.scoring_recency_weight,
|
||||||
|
"priority": self.scoring_priority_weight,
|
||||||
|
}
|
||||||
|
|
||||||
|
def to_dict(self) -> dict[str, Any]:
|
||||||
|
"""Convert settings to dictionary for logging/debugging."""
|
||||||
|
return {
|
||||||
|
"budget": self.get_budget_allocation(),
|
||||||
|
"scoring": self.get_scoring_weights(),
|
||||||
|
"compression": {
|
||||||
|
"threshold": self.compression_threshold,
|
||||||
|
"summary_model_group": self.summary_model_group,
|
||||||
|
"truncation_marker": self.truncation_marker,
|
||||||
|
"truncation_preserve_ratio": self.truncation_preserve_ratio,
|
||||||
|
"truncation_min_content_length": self.truncation_min_content_length,
|
||||||
|
},
|
||||||
|
"cache": {
|
||||||
|
"enabled": self.cache_enabled,
|
||||||
|
"ttl_seconds": self.cache_ttl_seconds,
|
||||||
|
"prefix": self.cache_prefix,
|
||||||
|
"memory_max_items": self.cache_memory_max_items,
|
||||||
|
},
|
||||||
|
"performance": {
|
||||||
|
"max_assembly_time_ms": self.max_assembly_time_ms,
|
||||||
|
"parallel_scoring": self.parallel_scoring,
|
||||||
|
"max_parallel_scores": self.max_parallel_scores,
|
||||||
|
},
|
||||||
|
"knowledge": {
|
||||||
|
"search_type": self.knowledge_search_type,
|
||||||
|
"max_results": self.knowledge_max_results,
|
||||||
|
"min_score": self.knowledge_min_score,
|
||||||
|
},
|
||||||
|
"relevance": {
|
||||||
|
"keyword_fallback_weight": self.relevance_keyword_fallback_weight,
|
||||||
|
"semantic_max_chars": self.relevance_semantic_max_chars,
|
||||||
|
},
|
||||||
|
"diversity": {
|
||||||
|
"max_per_source": self.diversity_max_per_source,
|
||||||
|
},
|
||||||
|
"conversation": {
|
||||||
|
"max_turns": self.conversation_max_turns,
|
||||||
|
"recent_priority": self.conversation_recent_priority,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
model_config = {
|
||||||
|
"env_prefix": "CTX_",
|
||||||
|
"env_file": "../.env",
|
||||||
|
"env_file_encoding": "utf-8",
|
||||||
|
"case_sensitive": False,
|
||||||
|
"extra": "ignore",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
# Thread-safe singleton pattern
|
||||||
|
_settings: ContextSettings | None = None
|
||||||
|
_settings_lock = threading.Lock()
|
||||||
|
|
||||||
|
|
||||||
|
def get_context_settings() -> ContextSettings:
|
||||||
|
"""
|
||||||
|
Get the global ContextSettings instance.
|
||||||
|
|
||||||
|
Thread-safe with double-checked locking pattern.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ContextSettings instance
|
||||||
|
"""
|
||||||
|
global _settings
|
||||||
|
if _settings is None:
|
||||||
|
with _settings_lock:
|
||||||
|
if _settings is None:
|
||||||
|
_settings = ContextSettings()
|
||||||
|
return _settings
|
||||||
|
|
||||||
|
|
||||||
|
def reset_context_settings() -> None:
|
||||||
|
"""
|
||||||
|
Reset the global settings instance.
|
||||||
|
|
||||||
|
Primarily used for testing.
|
||||||
|
"""
|
||||||
|
global _settings
|
||||||
|
with _settings_lock:
|
||||||
|
_settings = None
|
||||||
|
|
||||||
|
|
||||||
|
@lru_cache(maxsize=1)
|
||||||
|
def get_default_settings() -> ContextSettings:
|
||||||
|
"""
|
||||||
|
Get default settings (cached).
|
||||||
|
|
||||||
|
Use this for read-only access to defaults.
|
||||||
|
For mutable access, use get_context_settings().
|
||||||
|
"""
|
||||||
|
return ContextSettings()
|
||||||
485
backend/app/services/context/engine.py
Normal file
485
backend/app/services/context/engine.py
Normal file
@@ -0,0 +1,485 @@
|
|||||||
|
"""
|
||||||
|
Context Management Engine.
|
||||||
|
|
||||||
|
Main orchestration layer for context assembly and optimization.
|
||||||
|
Provides a high-level API for assembling optimized context for LLM requests.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from typing import TYPE_CHECKING, Any
|
||||||
|
|
||||||
|
from .assembly import ContextPipeline
|
||||||
|
from .budget import BudgetAllocator, TokenBudget, TokenCalculator
|
||||||
|
from .cache import ContextCache
|
||||||
|
from .compression import ContextCompressor
|
||||||
|
from .config import ContextSettings, get_context_settings
|
||||||
|
from .prioritization import ContextRanker
|
||||||
|
from .scoring import CompositeScorer
|
||||||
|
from .types import (
|
||||||
|
AssembledContext,
|
||||||
|
BaseContext,
|
||||||
|
ConversationContext,
|
||||||
|
KnowledgeContext,
|
||||||
|
MessageRole,
|
||||||
|
SystemContext,
|
||||||
|
TaskContext,
|
||||||
|
ToolContext,
|
||||||
|
)
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from redis.asyncio import Redis
|
||||||
|
|
||||||
|
from app.services.mcp.client_manager import MCPClientManager
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class ContextEngine:
|
||||||
|
"""
|
||||||
|
Main context management engine.
|
||||||
|
|
||||||
|
Provides high-level API for context assembly and optimization.
|
||||||
|
Integrates all components: scoring, ranking, compression, formatting, and caching.
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
engine = ContextEngine(mcp_manager=mcp, redis=redis)
|
||||||
|
|
||||||
|
# Assemble context for an LLM request
|
||||||
|
result = await engine.assemble_context(
|
||||||
|
project_id="proj-123",
|
||||||
|
agent_id="agent-456",
|
||||||
|
query="implement user authentication",
|
||||||
|
model="claude-3-sonnet",
|
||||||
|
system_prompt="You are an expert developer.",
|
||||||
|
knowledge_query="authentication best practices",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Use the assembled context
|
||||||
|
print(result.content)
|
||||||
|
print(f"Tokens: {result.total_tokens}")
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
mcp_manager: "MCPClientManager | None" = None,
|
||||||
|
redis: "Redis | None" = None,
|
||||||
|
settings: ContextSettings | None = None,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Initialize the context engine.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
mcp_manager: MCP client manager for LLM Gateway/Knowledge Base
|
||||||
|
redis: Redis connection for caching
|
||||||
|
settings: Context settings
|
||||||
|
"""
|
||||||
|
self._mcp = mcp_manager
|
||||||
|
self._settings = settings or get_context_settings()
|
||||||
|
|
||||||
|
# Initialize components
|
||||||
|
self._calculator = TokenCalculator(mcp_manager=mcp_manager)
|
||||||
|
self._scorer = CompositeScorer(mcp_manager=mcp_manager, settings=self._settings)
|
||||||
|
self._ranker = ContextRanker(scorer=self._scorer, calculator=self._calculator)
|
||||||
|
self._compressor = ContextCompressor(calculator=self._calculator)
|
||||||
|
self._allocator = BudgetAllocator(self._settings)
|
||||||
|
self._cache = ContextCache(redis=redis, settings=self._settings)
|
||||||
|
|
||||||
|
# Pipeline for assembly
|
||||||
|
self._pipeline = ContextPipeline(
|
||||||
|
mcp_manager=mcp_manager,
|
||||||
|
settings=self._settings,
|
||||||
|
calculator=self._calculator,
|
||||||
|
scorer=self._scorer,
|
||||||
|
ranker=self._ranker,
|
||||||
|
compressor=self._compressor,
|
||||||
|
)
|
||||||
|
|
||||||
|
def set_mcp_manager(self, mcp_manager: "MCPClientManager") -> None:
|
||||||
|
"""
|
||||||
|
Set MCP manager for all components.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
mcp_manager: MCP client manager
|
||||||
|
"""
|
||||||
|
self._mcp = mcp_manager
|
||||||
|
self._calculator.set_mcp_manager(mcp_manager)
|
||||||
|
self._scorer.set_mcp_manager(mcp_manager)
|
||||||
|
self._pipeline.set_mcp_manager(mcp_manager)
|
||||||
|
|
||||||
|
def set_redis(self, redis: "Redis") -> None:
|
||||||
|
"""
|
||||||
|
Set Redis connection for caching.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
redis: Redis connection
|
||||||
|
"""
|
||||||
|
self._cache.set_redis(redis)
|
||||||
|
|
||||||
|
async def assemble_context(
|
||||||
|
self,
|
||||||
|
project_id: str,
|
||||||
|
agent_id: str,
|
||||||
|
query: str,
|
||||||
|
model: str,
|
||||||
|
max_tokens: int | None = None,
|
||||||
|
system_prompt: str | None = None,
|
||||||
|
task_description: str | None = None,
|
||||||
|
knowledge_query: str | None = None,
|
||||||
|
knowledge_limit: int = 10,
|
||||||
|
conversation_history: list[dict[str, str]] | None = None,
|
||||||
|
tool_results: list[dict[str, Any]] | None = None,
|
||||||
|
custom_contexts: list[BaseContext] | None = None,
|
||||||
|
custom_budget: TokenBudget | None = None,
|
||||||
|
compress: bool = True,
|
||||||
|
format_output: bool = True,
|
||||||
|
use_cache: bool = True,
|
||||||
|
) -> AssembledContext:
|
||||||
|
"""
|
||||||
|
Assemble optimized context for an LLM request.
|
||||||
|
|
||||||
|
This is the main entry point for context management.
|
||||||
|
It gathers context from various sources, scores and ranks them,
|
||||||
|
compresses if needed, and formats for the target model.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
project_id: Project identifier
|
||||||
|
agent_id: Agent identifier
|
||||||
|
query: User's query or current request
|
||||||
|
model: Target model name
|
||||||
|
max_tokens: Maximum context tokens (uses model default if None)
|
||||||
|
system_prompt: System prompt/instructions
|
||||||
|
task_description: Current task description
|
||||||
|
knowledge_query: Query for knowledge base search
|
||||||
|
knowledge_limit: Max number of knowledge results
|
||||||
|
conversation_history: List of {"role": str, "content": str}
|
||||||
|
tool_results: List of tool results to include
|
||||||
|
custom_contexts: Additional custom contexts
|
||||||
|
custom_budget: Custom token budget
|
||||||
|
compress: Whether to apply compression
|
||||||
|
format_output: Whether to format for the model
|
||||||
|
use_cache: Whether to use caching
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
AssembledContext with optimized content
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
AssemblyTimeoutError: If assembly exceeds timeout
|
||||||
|
BudgetExceededError: If context exceeds budget
|
||||||
|
"""
|
||||||
|
# Gather all contexts
|
||||||
|
contexts: list[BaseContext] = []
|
||||||
|
|
||||||
|
# 1. System context
|
||||||
|
if system_prompt:
|
||||||
|
contexts.append(
|
||||||
|
SystemContext(
|
||||||
|
content=system_prompt,
|
||||||
|
source="system_prompt",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# 2. Task context
|
||||||
|
if task_description:
|
||||||
|
contexts.append(
|
||||||
|
TaskContext(
|
||||||
|
content=task_description,
|
||||||
|
source=f"task:{project_id}:{agent_id}",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# 3. Knowledge context from Knowledge Base
|
||||||
|
if knowledge_query and self._mcp:
|
||||||
|
knowledge_contexts = await self._fetch_knowledge(
|
||||||
|
project_id=project_id,
|
||||||
|
agent_id=agent_id,
|
||||||
|
query=knowledge_query,
|
||||||
|
limit=knowledge_limit,
|
||||||
|
)
|
||||||
|
contexts.extend(knowledge_contexts)
|
||||||
|
|
||||||
|
# 4. Conversation history
|
||||||
|
if conversation_history:
|
||||||
|
contexts.extend(self._convert_conversation(conversation_history))
|
||||||
|
|
||||||
|
# 5. Tool results
|
||||||
|
if tool_results:
|
||||||
|
contexts.extend(self._convert_tool_results(tool_results))
|
||||||
|
|
||||||
|
# 6. Custom contexts
|
||||||
|
if custom_contexts:
|
||||||
|
contexts.extend(custom_contexts)
|
||||||
|
|
||||||
|
# Check cache if enabled
|
||||||
|
fingerprint: str | None = None
|
||||||
|
if use_cache and self._cache.is_enabled:
|
||||||
|
# Include project_id and agent_id for tenant isolation
|
||||||
|
fingerprint = self._cache.compute_fingerprint(
|
||||||
|
contexts, query, model, project_id=project_id, agent_id=agent_id
|
||||||
|
)
|
||||||
|
cached = await self._cache.get_assembled(fingerprint)
|
||||||
|
if cached:
|
||||||
|
logger.debug(f"Cache hit for context assembly: {fingerprint}")
|
||||||
|
return cached
|
||||||
|
|
||||||
|
# Run assembly pipeline
|
||||||
|
result = await self._pipeline.assemble(
|
||||||
|
contexts=contexts,
|
||||||
|
query=query,
|
||||||
|
model=model,
|
||||||
|
max_tokens=max_tokens,
|
||||||
|
custom_budget=custom_budget,
|
||||||
|
compress=compress,
|
||||||
|
format_output=format_output,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Cache result if enabled (reuse fingerprint computed above)
|
||||||
|
if use_cache and self._cache.is_enabled and fingerprint is not None:
|
||||||
|
await self._cache.set_assembled(fingerprint, result)
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
async def _fetch_knowledge(
|
||||||
|
self,
|
||||||
|
project_id: str,
|
||||||
|
agent_id: str,
|
||||||
|
query: str,
|
||||||
|
limit: int = 10,
|
||||||
|
) -> list[KnowledgeContext]:
|
||||||
|
"""
|
||||||
|
Fetch relevant knowledge from Knowledge Base via MCP.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
project_id: Project identifier
|
||||||
|
agent_id: Agent identifier
|
||||||
|
query: Search query
|
||||||
|
limit: Maximum results
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of KnowledgeContext instances
|
||||||
|
"""
|
||||||
|
if not self._mcp:
|
||||||
|
return []
|
||||||
|
|
||||||
|
try:
|
||||||
|
result = await self._mcp.call_tool(
|
||||||
|
"knowledge-base",
|
||||||
|
"search_knowledge",
|
||||||
|
{
|
||||||
|
"project_id": project_id,
|
||||||
|
"agent_id": agent_id,
|
||||||
|
"query": query,
|
||||||
|
"search_type": "hybrid",
|
||||||
|
"limit": limit,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check both ToolResult.success AND response success
|
||||||
|
if not result.success:
|
||||||
|
logger.warning(f"Knowledge search failed: {result.error}")
|
||||||
|
return []
|
||||||
|
|
||||||
|
if not isinstance(result.data, dict) or not result.data.get(
|
||||||
|
"success", True
|
||||||
|
):
|
||||||
|
logger.warning("Knowledge search returned unsuccessful response")
|
||||||
|
return []
|
||||||
|
|
||||||
|
contexts = []
|
||||||
|
results = result.data.get("results", [])
|
||||||
|
for chunk in results:
|
||||||
|
contexts.append(
|
||||||
|
KnowledgeContext(
|
||||||
|
content=chunk.get("content", ""),
|
||||||
|
source=chunk.get("source_path", "unknown"),
|
||||||
|
relevance_score=chunk.get("score", 0.0),
|
||||||
|
metadata={
|
||||||
|
"chunk_id": chunk.get(
|
||||||
|
"id"
|
||||||
|
), # Server returns 'id' not 'chunk_id'
|
||||||
|
"document_id": chunk.get("document_id"),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.debug(f"Fetched {len(contexts)} knowledge chunks for query: {query}")
|
||||||
|
return contexts
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Failed to fetch knowledge: {e}")
|
||||||
|
return []
|
||||||
|
|
||||||
|
def _convert_conversation(
|
||||||
|
self,
|
||||||
|
history: list[dict[str, str]],
|
||||||
|
) -> list[ConversationContext]:
|
||||||
|
"""
|
||||||
|
Convert conversation history to ConversationContext instances.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
history: List of {"role": str, "content": str}
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of ConversationContext instances
|
||||||
|
"""
|
||||||
|
contexts = []
|
||||||
|
for i, turn in enumerate(history):
|
||||||
|
role_str = turn.get("role", "user").lower()
|
||||||
|
role = (
|
||||||
|
MessageRole.ASSISTANT if role_str == "assistant" else MessageRole.USER
|
||||||
|
)
|
||||||
|
|
||||||
|
contexts.append(
|
||||||
|
ConversationContext(
|
||||||
|
content=turn.get("content", ""),
|
||||||
|
source=f"conversation:{i}",
|
||||||
|
role=role,
|
||||||
|
metadata={"role": role_str, "turn": i},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
return contexts
|
||||||
|
|
||||||
|
def _convert_tool_results(
|
||||||
|
self,
|
||||||
|
results: list[dict[str, Any]],
|
||||||
|
) -> list[ToolContext]:
|
||||||
|
"""
|
||||||
|
Convert tool results to ToolContext instances.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
results: List of tool result dictionaries
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of ToolContext instances
|
||||||
|
"""
|
||||||
|
contexts = []
|
||||||
|
for result in results:
|
||||||
|
tool_name = result.get("tool_name", "unknown")
|
||||||
|
content = result.get("content", result.get("result", ""))
|
||||||
|
|
||||||
|
# Handle dict content
|
||||||
|
if isinstance(content, dict):
|
||||||
|
import json
|
||||||
|
|
||||||
|
content = json.dumps(content, indent=2)
|
||||||
|
|
||||||
|
contexts.append(
|
||||||
|
ToolContext(
|
||||||
|
content=str(content),
|
||||||
|
source=f"tool:{tool_name}",
|
||||||
|
metadata={
|
||||||
|
"tool_name": tool_name,
|
||||||
|
"status": result.get("status", "success"),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
return contexts
|
||||||
|
|
||||||
|
async def get_budget_for_model(
|
||||||
|
self,
|
||||||
|
model: str,
|
||||||
|
max_tokens: int | None = None,
|
||||||
|
) -> TokenBudget:
|
||||||
|
"""
|
||||||
|
Get the token budget for a specific model.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model: Model name
|
||||||
|
max_tokens: Optional max tokens override
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
TokenBudget instance
|
||||||
|
"""
|
||||||
|
if max_tokens:
|
||||||
|
return self._allocator.create_budget(max_tokens)
|
||||||
|
return self._allocator.create_budget_for_model(model)
|
||||||
|
|
||||||
|
async def count_tokens(
|
||||||
|
self,
|
||||||
|
content: str,
|
||||||
|
model: str | None = None,
|
||||||
|
) -> int:
|
||||||
|
"""
|
||||||
|
Count tokens in content.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
content: Content to count
|
||||||
|
model: Model for model-specific tokenization
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Token count
|
||||||
|
"""
|
||||||
|
# Check cache first
|
||||||
|
cached = await self._cache.get_token_count(content, model)
|
||||||
|
if cached is not None:
|
||||||
|
return cached
|
||||||
|
|
||||||
|
count = await self._calculator.count_tokens(content, model)
|
||||||
|
|
||||||
|
# Cache the result
|
||||||
|
await self._cache.set_token_count(content, count, model)
|
||||||
|
|
||||||
|
return count
|
||||||
|
|
||||||
|
async def invalidate_cache(
|
||||||
|
self,
|
||||||
|
project_id: str | None = None,
|
||||||
|
pattern: str | None = None,
|
||||||
|
) -> int:
|
||||||
|
"""
|
||||||
|
Invalidate cache entries.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
project_id: Invalidate all cache for a project
|
||||||
|
pattern: Custom pattern to match
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Number of entries invalidated
|
||||||
|
"""
|
||||||
|
if pattern:
|
||||||
|
return await self._cache.invalidate(pattern)
|
||||||
|
elif project_id:
|
||||||
|
return await self._cache.invalidate(f"*{project_id}*")
|
||||||
|
else:
|
||||||
|
return await self._cache.clear_all()
|
||||||
|
|
||||||
|
async def get_stats(self) -> dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Get engine statistics.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary with engine stats
|
||||||
|
"""
|
||||||
|
return {
|
||||||
|
"cache": await self._cache.get_stats(),
|
||||||
|
"settings": {
|
||||||
|
"compression_threshold": self._settings.compression_threshold,
|
||||||
|
"max_assembly_time_ms": self._settings.max_assembly_time_ms,
|
||||||
|
"cache_enabled": self._settings.cache_enabled,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
# Convenience factory function
|
||||||
|
def create_context_engine(
|
||||||
|
mcp_manager: "MCPClientManager | None" = None,
|
||||||
|
redis: "Redis | None" = None,
|
||||||
|
settings: ContextSettings | None = None,
|
||||||
|
) -> ContextEngine:
|
||||||
|
"""
|
||||||
|
Create a context engine instance.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
mcp_manager: MCP client manager
|
||||||
|
redis: Redis connection
|
||||||
|
settings: Context settings
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Configured ContextEngine instance
|
||||||
|
"""
|
||||||
|
return ContextEngine(
|
||||||
|
mcp_manager=mcp_manager,
|
||||||
|
redis=redis,
|
||||||
|
settings=settings,
|
||||||
|
)
|
||||||
354
backend/app/services/context/exceptions.py
Normal file
354
backend/app/services/context/exceptions.py
Normal file
@@ -0,0 +1,354 @@
|
|||||||
|
"""
|
||||||
|
Context Management Engine Exceptions.
|
||||||
|
|
||||||
|
Provides a hierarchy of exceptions for context assembly,
|
||||||
|
token budget management, and related operations.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
|
||||||
|
class ContextError(Exception):
|
||||||
|
"""
|
||||||
|
Base exception for all context management errors.
|
||||||
|
|
||||||
|
All context-related exceptions should inherit from this class
|
||||||
|
to allow for catch-all handling when needed.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, message: str, details: dict[str, Any] | None = None) -> None:
|
||||||
|
"""
|
||||||
|
Initialize context error.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
message: Human-readable error message
|
||||||
|
details: Optional dict with additional error context
|
||||||
|
"""
|
||||||
|
self.message = message
|
||||||
|
self.details = details or {}
|
||||||
|
super().__init__(message)
|
||||||
|
|
||||||
|
def to_dict(self) -> dict[str, Any]:
|
||||||
|
"""Convert exception to dictionary for logging/serialization."""
|
||||||
|
return {
|
||||||
|
"error_type": self.__class__.__name__,
|
||||||
|
"message": self.message,
|
||||||
|
"details": self.details,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class BudgetExceededError(ContextError):
|
||||||
|
"""
|
||||||
|
Raised when token budget is exceeded.
|
||||||
|
|
||||||
|
This occurs when the assembled context would exceed the
|
||||||
|
allocated token budget for a specific context type or total.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
message: str = "Token budget exceeded",
|
||||||
|
allocated: int = 0,
|
||||||
|
requested: int = 0,
|
||||||
|
context_type: str | None = None,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Initialize budget exceeded error.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
message: Error message
|
||||||
|
allocated: Tokens allocated for this context type
|
||||||
|
requested: Tokens requested
|
||||||
|
context_type: Type of context that exceeded budget
|
||||||
|
"""
|
||||||
|
details: dict[str, Any] = {
|
||||||
|
"allocated": allocated,
|
||||||
|
"requested": requested,
|
||||||
|
"overage": requested - allocated,
|
||||||
|
}
|
||||||
|
if context_type:
|
||||||
|
details["context_type"] = context_type
|
||||||
|
|
||||||
|
super().__init__(message, details)
|
||||||
|
self.allocated = allocated
|
||||||
|
self.requested = requested
|
||||||
|
self.context_type = context_type
|
||||||
|
|
||||||
|
|
||||||
|
class TokenCountError(ContextError):
|
||||||
|
"""
|
||||||
|
Raised when token counting fails.
|
||||||
|
|
||||||
|
This typically occurs when the LLM Gateway token counting
|
||||||
|
service is unavailable or returns an error.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
message: str = "Failed to count tokens",
|
||||||
|
model: str | None = None,
|
||||||
|
text_length: int | None = None,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Initialize token count error.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
message: Error message
|
||||||
|
model: Model for which counting was attempted
|
||||||
|
text_length: Length of text that failed to count
|
||||||
|
"""
|
||||||
|
details: dict[str, Any] = {}
|
||||||
|
if model:
|
||||||
|
details["model"] = model
|
||||||
|
if text_length is not None:
|
||||||
|
details["text_length"] = text_length
|
||||||
|
|
||||||
|
super().__init__(message, details)
|
||||||
|
self.model = model
|
||||||
|
self.text_length = text_length
|
||||||
|
|
||||||
|
|
||||||
|
class CompressionError(ContextError):
|
||||||
|
"""
|
||||||
|
Raised when context compression fails.
|
||||||
|
|
||||||
|
This can occur when summarization or truncation cannot
|
||||||
|
reduce content to fit within the budget.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
message: str = "Failed to compress context",
|
||||||
|
original_tokens: int | None = None,
|
||||||
|
target_tokens: int | None = None,
|
||||||
|
achieved_tokens: int | None = None,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Initialize compression error.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
message: Error message
|
||||||
|
original_tokens: Tokens before compression
|
||||||
|
target_tokens: Target token count
|
||||||
|
achieved_tokens: Tokens achieved after compression attempt
|
||||||
|
"""
|
||||||
|
details: dict[str, Any] = {}
|
||||||
|
if original_tokens is not None:
|
||||||
|
details["original_tokens"] = original_tokens
|
||||||
|
if target_tokens is not None:
|
||||||
|
details["target_tokens"] = target_tokens
|
||||||
|
if achieved_tokens is not None:
|
||||||
|
details["achieved_tokens"] = achieved_tokens
|
||||||
|
|
||||||
|
super().__init__(message, details)
|
||||||
|
self.original_tokens = original_tokens
|
||||||
|
self.target_tokens = target_tokens
|
||||||
|
self.achieved_tokens = achieved_tokens
|
||||||
|
|
||||||
|
|
||||||
|
class AssemblyTimeoutError(ContextError):
|
||||||
|
"""
|
||||||
|
Raised when context assembly exceeds time limit.
|
||||||
|
|
||||||
|
Context assembly must complete within a configurable
|
||||||
|
time limit to maintain responsiveness.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
message: str = "Context assembly timed out",
|
||||||
|
timeout_ms: int = 0,
|
||||||
|
elapsed_ms: float = 0.0,
|
||||||
|
stage: str | None = None,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Initialize assembly timeout error.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
message: Error message
|
||||||
|
timeout_ms: Configured timeout in milliseconds
|
||||||
|
elapsed_ms: Actual elapsed time in milliseconds
|
||||||
|
stage: Pipeline stage where timeout occurred
|
||||||
|
"""
|
||||||
|
details: dict[str, Any] = {
|
||||||
|
"timeout_ms": timeout_ms,
|
||||||
|
"elapsed_ms": round(elapsed_ms, 2),
|
||||||
|
}
|
||||||
|
if stage:
|
||||||
|
details["stage"] = stage
|
||||||
|
|
||||||
|
super().__init__(message, details)
|
||||||
|
self.timeout_ms = timeout_ms
|
||||||
|
self.elapsed_ms = elapsed_ms
|
||||||
|
self.stage = stage
|
||||||
|
|
||||||
|
|
||||||
|
class ScoringError(ContextError):
|
||||||
|
"""
|
||||||
|
Raised when context scoring fails.
|
||||||
|
|
||||||
|
This occurs when relevance, recency, or priority scoring
|
||||||
|
encounters an error.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
message: str = "Failed to score context",
|
||||||
|
scorer_type: str | None = None,
|
||||||
|
context_id: str | None = None,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Initialize scoring error.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
message: Error message
|
||||||
|
scorer_type: Type of scorer that failed
|
||||||
|
context_id: ID of context being scored
|
||||||
|
"""
|
||||||
|
details: dict[str, Any] = {}
|
||||||
|
if scorer_type:
|
||||||
|
details["scorer_type"] = scorer_type
|
||||||
|
if context_id:
|
||||||
|
details["context_id"] = context_id
|
||||||
|
|
||||||
|
super().__init__(message, details)
|
||||||
|
self.scorer_type = scorer_type
|
||||||
|
self.context_id = context_id
|
||||||
|
|
||||||
|
|
||||||
|
class FormattingError(ContextError):
|
||||||
|
"""
|
||||||
|
Raised when context formatting fails.
|
||||||
|
|
||||||
|
This occurs when converting assembled context to
|
||||||
|
model-specific format fails.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
message: str = "Failed to format context",
|
||||||
|
model: str | None = None,
|
||||||
|
adapter: str | None = None,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Initialize formatting error.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
message: Error message
|
||||||
|
model: Target model
|
||||||
|
adapter: Adapter that failed
|
||||||
|
"""
|
||||||
|
details: dict[str, Any] = {}
|
||||||
|
if model:
|
||||||
|
details["model"] = model
|
||||||
|
if adapter:
|
||||||
|
details["adapter"] = adapter
|
||||||
|
|
||||||
|
super().__init__(message, details)
|
||||||
|
self.model = model
|
||||||
|
self.adapter = adapter
|
||||||
|
|
||||||
|
|
||||||
|
class CacheError(ContextError):
|
||||||
|
"""
|
||||||
|
Raised when cache operations fail.
|
||||||
|
|
||||||
|
This is typically non-fatal and should be handled
|
||||||
|
gracefully by falling back to recomputation.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
message: str = "Cache operation failed",
|
||||||
|
operation: str | None = None,
|
||||||
|
cache_key: str | None = None,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Initialize cache error.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
message: Error message
|
||||||
|
operation: Cache operation that failed (get, set, delete)
|
||||||
|
cache_key: Key involved in the failed operation
|
||||||
|
"""
|
||||||
|
details: dict[str, Any] = {}
|
||||||
|
if operation:
|
||||||
|
details["operation"] = operation
|
||||||
|
if cache_key:
|
||||||
|
details["cache_key"] = cache_key
|
||||||
|
|
||||||
|
super().__init__(message, details)
|
||||||
|
self.operation = operation
|
||||||
|
self.cache_key = cache_key
|
||||||
|
|
||||||
|
|
||||||
|
class ContextNotFoundError(ContextError):
|
||||||
|
"""
|
||||||
|
Raised when expected context is not found.
|
||||||
|
|
||||||
|
This occurs when required context sources return
|
||||||
|
no results or are unavailable.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
message: str = "Required context not found",
|
||||||
|
source: str | None = None,
|
||||||
|
query: str | None = None,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Initialize context not found error.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
message: Error message
|
||||||
|
source: Source that returned no results
|
||||||
|
query: Query used to search
|
||||||
|
"""
|
||||||
|
details: dict[str, Any] = {}
|
||||||
|
if source:
|
||||||
|
details["source"] = source
|
||||||
|
if query:
|
||||||
|
details["query"] = query
|
||||||
|
|
||||||
|
super().__init__(message, details)
|
||||||
|
self.source = source
|
||||||
|
self.query = query
|
||||||
|
|
||||||
|
|
||||||
|
class InvalidContextError(ContextError):
|
||||||
|
"""
|
||||||
|
Raised when context data is invalid.
|
||||||
|
|
||||||
|
This occurs when context content or metadata
|
||||||
|
fails validation.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
message: str = "Invalid context data",
|
||||||
|
field: str | None = None,
|
||||||
|
value: Any | None = None,
|
||||||
|
reason: str | None = None,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Initialize invalid context error.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
message: Error message
|
||||||
|
field: Field that is invalid
|
||||||
|
value: Invalid value (may be redacted for security)
|
||||||
|
reason: Reason for invalidity
|
||||||
|
"""
|
||||||
|
details: dict[str, Any] = {}
|
||||||
|
if field:
|
||||||
|
details["field"] = field
|
||||||
|
if value is not None:
|
||||||
|
# Avoid logging potentially sensitive values
|
||||||
|
details["value_type"] = type(value).__name__
|
||||||
|
if reason:
|
||||||
|
details["reason"] = reason
|
||||||
|
|
||||||
|
super().__init__(message, details)
|
||||||
|
self.field = field
|
||||||
|
self.value = value
|
||||||
|
self.reason = reason
|
||||||
12
backend/app/services/context/prioritization/__init__.py
Normal file
12
backend/app/services/context/prioritization/__init__.py
Normal file
@@ -0,0 +1,12 @@
|
|||||||
|
"""
|
||||||
|
Context Prioritization Module.
|
||||||
|
|
||||||
|
Provides context ranking and selection.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from .ranker import ContextRanker, RankingResult
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"ContextRanker",
|
||||||
|
"RankingResult",
|
||||||
|
]
|
||||||
374
backend/app/services/context/prioritization/ranker.py
Normal file
374
backend/app/services/context/prioritization/ranker.py
Normal file
@@ -0,0 +1,374 @@
|
|||||||
|
"""
|
||||||
|
Context Ranker for Context Management.
|
||||||
|
|
||||||
|
Ranks and selects contexts based on scores and budget constraints.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import TYPE_CHECKING, Any
|
||||||
|
|
||||||
|
from ..budget import TokenBudget, TokenCalculator
|
||||||
|
from ..config import ContextSettings, get_context_settings
|
||||||
|
from ..exceptions import BudgetExceededError
|
||||||
|
from ..scoring.composite import CompositeScorer, ScoredContext
|
||||||
|
from ..types import BaseContext, ContextPriority
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
pass
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class RankingResult:
|
||||||
|
"""Result of context ranking and selection."""
|
||||||
|
|
||||||
|
selected: list[ScoredContext]
|
||||||
|
excluded: list[ScoredContext]
|
||||||
|
total_tokens: int
|
||||||
|
selection_stats: dict[str, Any] = field(default_factory=dict)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def selected_contexts(self) -> list[BaseContext]:
|
||||||
|
"""Get just the context objects (not scored wrappers)."""
|
||||||
|
return [s.context for s in self.selected]
|
||||||
|
|
||||||
|
|
||||||
|
class ContextRanker:
|
||||||
|
"""
|
||||||
|
Ranks and selects contexts within budget constraints.
|
||||||
|
|
||||||
|
Uses greedy selection to maximize total score
|
||||||
|
while respecting token budgets per context type.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
scorer: CompositeScorer | None = None,
|
||||||
|
calculator: TokenCalculator | None = None,
|
||||||
|
settings: ContextSettings | None = None,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Initialize context ranker.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
scorer: Composite scorer for scoring contexts
|
||||||
|
calculator: Token calculator for counting tokens
|
||||||
|
settings: Context settings (uses global if None)
|
||||||
|
"""
|
||||||
|
self._settings = settings or get_context_settings()
|
||||||
|
self._scorer = scorer or CompositeScorer()
|
||||||
|
self._calculator = calculator or TokenCalculator()
|
||||||
|
|
||||||
|
def set_scorer(self, scorer: CompositeScorer) -> None:
|
||||||
|
"""Set the scorer."""
|
||||||
|
self._scorer = scorer
|
||||||
|
|
||||||
|
def set_calculator(self, calculator: TokenCalculator) -> None:
|
||||||
|
"""Set the token calculator."""
|
||||||
|
self._calculator = calculator
|
||||||
|
|
||||||
|
async def rank(
|
||||||
|
self,
|
||||||
|
contexts: list[BaseContext],
|
||||||
|
query: str,
|
||||||
|
budget: TokenBudget,
|
||||||
|
model: str | None = None,
|
||||||
|
ensure_required: bool = True,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> RankingResult:
|
||||||
|
"""
|
||||||
|
Rank and select contexts within budget.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
contexts: Contexts to rank
|
||||||
|
query: Query to rank against
|
||||||
|
budget: Token budget constraints
|
||||||
|
model: Model for token counting
|
||||||
|
ensure_required: If True, always include CRITICAL priority contexts
|
||||||
|
**kwargs: Additional scoring parameters
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
RankingResult with selected and excluded contexts
|
||||||
|
"""
|
||||||
|
if not contexts:
|
||||||
|
return RankingResult(
|
||||||
|
selected=[],
|
||||||
|
excluded=[],
|
||||||
|
total_tokens=0,
|
||||||
|
selection_stats={"total_contexts": 0},
|
||||||
|
)
|
||||||
|
|
||||||
|
# 1. Ensure all contexts have token counts
|
||||||
|
await self._ensure_token_counts(contexts, model)
|
||||||
|
|
||||||
|
# 2. Score all contexts
|
||||||
|
scored_contexts = await self._scorer.score_batch(contexts, query, **kwargs)
|
||||||
|
|
||||||
|
# 3. Separate required (CRITICAL priority) from optional
|
||||||
|
required: list[ScoredContext] = []
|
||||||
|
optional: list[ScoredContext] = []
|
||||||
|
|
||||||
|
if ensure_required:
|
||||||
|
for sc in scored_contexts:
|
||||||
|
# CRITICAL priority (150) contexts are always included
|
||||||
|
if sc.context.priority >= ContextPriority.CRITICAL.value:
|
||||||
|
required.append(sc)
|
||||||
|
else:
|
||||||
|
optional.append(sc)
|
||||||
|
else:
|
||||||
|
optional = list(scored_contexts)
|
||||||
|
|
||||||
|
# 4. Sort optional by score (highest first)
|
||||||
|
optional.sort(reverse=True)
|
||||||
|
|
||||||
|
# 5. Greedy selection
|
||||||
|
selected: list[ScoredContext] = []
|
||||||
|
excluded: list[ScoredContext] = []
|
||||||
|
total_tokens = 0
|
||||||
|
|
||||||
|
# Calculate the usable budget (total minus reserved portions)
|
||||||
|
usable_budget = budget.total - budget.response_reserve - budget.buffer
|
||||||
|
|
||||||
|
# Guard against invalid budget configuration
|
||||||
|
if usable_budget <= 0:
|
||||||
|
raise BudgetExceededError(
|
||||||
|
message=(
|
||||||
|
f"Invalid budget configuration: no usable tokens available. "
|
||||||
|
f"total={budget.total}, response_reserve={budget.response_reserve}, "
|
||||||
|
f"buffer={budget.buffer}"
|
||||||
|
),
|
||||||
|
allocated=budget.total,
|
||||||
|
requested=0,
|
||||||
|
context_type="CONFIGURATION_ERROR",
|
||||||
|
)
|
||||||
|
|
||||||
|
# First, try to fit required contexts
|
||||||
|
for sc in required:
|
||||||
|
token_count = self._get_valid_token_count(sc.context)
|
||||||
|
context_type = sc.context.get_type()
|
||||||
|
|
||||||
|
if budget.can_fit(context_type, token_count):
|
||||||
|
budget.allocate(context_type, token_count)
|
||||||
|
selected.append(sc)
|
||||||
|
total_tokens += token_count
|
||||||
|
else:
|
||||||
|
# Force-fit CRITICAL contexts if needed, but check total budget first
|
||||||
|
if total_tokens + token_count > usable_budget:
|
||||||
|
# Even CRITICAL contexts cannot exceed total model context window
|
||||||
|
raise BudgetExceededError(
|
||||||
|
message=(
|
||||||
|
f"CRITICAL contexts exceed total budget. "
|
||||||
|
f"Context '{sc.context.source}' ({token_count} tokens) "
|
||||||
|
f"would exceed usable budget of {usable_budget} tokens."
|
||||||
|
),
|
||||||
|
allocated=usable_budget,
|
||||||
|
requested=total_tokens + token_count,
|
||||||
|
context_type="CRITICAL_OVERFLOW",
|
||||||
|
)
|
||||||
|
|
||||||
|
budget.allocate(context_type, token_count, force=True)
|
||||||
|
selected.append(sc)
|
||||||
|
total_tokens += token_count
|
||||||
|
logger.warning(
|
||||||
|
f"Force-fitted CRITICAL context: {sc.context.source} "
|
||||||
|
f"({token_count} tokens)"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Then, greedily add optional contexts
|
||||||
|
for sc in optional:
|
||||||
|
token_count = self._get_valid_token_count(sc.context)
|
||||||
|
context_type = sc.context.get_type()
|
||||||
|
|
||||||
|
if budget.can_fit(context_type, token_count):
|
||||||
|
budget.allocate(context_type, token_count)
|
||||||
|
selected.append(sc)
|
||||||
|
total_tokens += token_count
|
||||||
|
else:
|
||||||
|
excluded.append(sc)
|
||||||
|
|
||||||
|
# Build stats
|
||||||
|
stats = {
|
||||||
|
"total_contexts": len(contexts),
|
||||||
|
"required_count": len(required),
|
||||||
|
"selected_count": len(selected),
|
||||||
|
"excluded_count": len(excluded),
|
||||||
|
"total_tokens": total_tokens,
|
||||||
|
"by_type": self._count_by_type(selected),
|
||||||
|
}
|
||||||
|
|
||||||
|
return RankingResult(
|
||||||
|
selected=selected,
|
||||||
|
excluded=excluded,
|
||||||
|
total_tokens=total_tokens,
|
||||||
|
selection_stats=stats,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def rank_simple(
|
||||||
|
self,
|
||||||
|
contexts: list[BaseContext],
|
||||||
|
query: str,
|
||||||
|
max_tokens: int,
|
||||||
|
model: str | None = None,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> list[BaseContext]:
|
||||||
|
"""
|
||||||
|
Simple ranking without budget per type.
|
||||||
|
|
||||||
|
Selects top contexts by score until max tokens reached.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
contexts: Contexts to rank
|
||||||
|
query: Query to rank against
|
||||||
|
max_tokens: Maximum total tokens
|
||||||
|
model: Model for token counting
|
||||||
|
**kwargs: Additional scoring parameters
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Selected contexts (in score order)
|
||||||
|
"""
|
||||||
|
if not contexts:
|
||||||
|
return []
|
||||||
|
|
||||||
|
# Ensure token counts
|
||||||
|
await self._ensure_token_counts(contexts, model)
|
||||||
|
|
||||||
|
# Score all contexts
|
||||||
|
scored_contexts = await self._scorer.score_batch(contexts, query, **kwargs)
|
||||||
|
|
||||||
|
# Sort by score (highest first)
|
||||||
|
scored_contexts.sort(reverse=True)
|
||||||
|
|
||||||
|
# Greedy selection
|
||||||
|
selected: list[BaseContext] = []
|
||||||
|
total_tokens = 0
|
||||||
|
|
||||||
|
for sc in scored_contexts:
|
||||||
|
token_count = self._get_valid_token_count(sc.context)
|
||||||
|
if total_tokens + token_count <= max_tokens:
|
||||||
|
selected.append(sc.context)
|
||||||
|
total_tokens += token_count
|
||||||
|
|
||||||
|
return selected
|
||||||
|
|
||||||
|
def _get_valid_token_count(self, context: BaseContext) -> int:
|
||||||
|
"""
|
||||||
|
Get validated token count from a context.
|
||||||
|
|
||||||
|
Ensures token_count is set (not None) and non-negative to prevent
|
||||||
|
budget bypass attacks where:
|
||||||
|
- None would be treated as 0 (allowing huge contexts to slip through)
|
||||||
|
- Negative values would corrupt budget tracking
|
||||||
|
|
||||||
|
Args:
|
||||||
|
context: Context to get token count from
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Valid non-negative token count
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If token_count is None or negative
|
||||||
|
"""
|
||||||
|
if context.token_count is None:
|
||||||
|
raise ValueError(
|
||||||
|
f"Context '{context.source}' has no token count. "
|
||||||
|
"Ensure _ensure_token_counts() is called before ranking."
|
||||||
|
)
|
||||||
|
if context.token_count < 0:
|
||||||
|
raise ValueError(
|
||||||
|
f"Context '{context.source}' has invalid negative token count: "
|
||||||
|
f"{context.token_count}"
|
||||||
|
)
|
||||||
|
return context.token_count
|
||||||
|
|
||||||
|
async def _ensure_token_counts(
|
||||||
|
self,
|
||||||
|
contexts: list[BaseContext],
|
||||||
|
model: str | None = None,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Ensure all contexts have token counts.
|
||||||
|
|
||||||
|
Counts tokens in parallel for contexts that don't have counts.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
contexts: Contexts to check
|
||||||
|
model: Model for token counting
|
||||||
|
"""
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
# Find contexts needing counts
|
||||||
|
contexts_needing_counts = [ctx for ctx in contexts if ctx.token_count is None]
|
||||||
|
|
||||||
|
if not contexts_needing_counts:
|
||||||
|
return
|
||||||
|
|
||||||
|
# Count all in parallel
|
||||||
|
tasks = [
|
||||||
|
self._calculator.count_tokens(ctx.content, model)
|
||||||
|
for ctx in contexts_needing_counts
|
||||||
|
]
|
||||||
|
counts = await asyncio.gather(*tasks)
|
||||||
|
|
||||||
|
# Assign counts back
|
||||||
|
for ctx, count in zip(contexts_needing_counts, counts, strict=True):
|
||||||
|
ctx.token_count = count
|
||||||
|
|
||||||
|
def _count_by_type(
|
||||||
|
self, scored_contexts: list[ScoredContext]
|
||||||
|
) -> dict[str, dict[str, int]]:
|
||||||
|
"""Count selected contexts by type."""
|
||||||
|
by_type: dict[str, dict[str, int]] = {}
|
||||||
|
|
||||||
|
for sc in scored_contexts:
|
||||||
|
type_name = sc.context.get_type().value
|
||||||
|
if type_name not in by_type:
|
||||||
|
by_type[type_name] = {"count": 0, "tokens": 0}
|
||||||
|
by_type[type_name]["count"] += 1
|
||||||
|
# Use validated token count (already validated during ranking)
|
||||||
|
by_type[type_name]["tokens"] += sc.context.token_count or 0
|
||||||
|
|
||||||
|
return by_type
|
||||||
|
|
||||||
|
async def rerank_for_diversity(
|
||||||
|
self,
|
||||||
|
scored_contexts: list[ScoredContext],
|
||||||
|
max_per_source: int | None = None,
|
||||||
|
) -> list[ScoredContext]:
|
||||||
|
"""
|
||||||
|
Rerank to ensure source diversity.
|
||||||
|
|
||||||
|
Prevents too many items from the same source.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
scored_contexts: Already scored contexts
|
||||||
|
max_per_source: Maximum items per source (uses settings if None)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Reranked contexts
|
||||||
|
"""
|
||||||
|
# Use provided value or fall back to settings
|
||||||
|
effective_max = (
|
||||||
|
max_per_source
|
||||||
|
if max_per_source is not None
|
||||||
|
else self._settings.diversity_max_per_source
|
||||||
|
)
|
||||||
|
|
||||||
|
source_counts: dict[str, int] = {}
|
||||||
|
result: list[ScoredContext] = []
|
||||||
|
deferred: list[ScoredContext] = []
|
||||||
|
|
||||||
|
for sc in scored_contexts:
|
||||||
|
source = sc.context.source
|
||||||
|
current_count = source_counts.get(source, 0)
|
||||||
|
|
||||||
|
if current_count < effective_max:
|
||||||
|
result.append(sc)
|
||||||
|
source_counts[source] = current_count + 1
|
||||||
|
else:
|
||||||
|
deferred.append(sc)
|
||||||
|
|
||||||
|
# Add deferred items at the end
|
||||||
|
result.extend(deferred)
|
||||||
|
return result
|
||||||
21
backend/app/services/context/scoring/__init__.py
Normal file
21
backend/app/services/context/scoring/__init__.py
Normal file
@@ -0,0 +1,21 @@
|
|||||||
|
"""
|
||||||
|
Context Scoring Module.
|
||||||
|
|
||||||
|
Provides scoring strategies for context prioritization.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from .base import BaseScorer, ScorerProtocol
|
||||||
|
from .composite import CompositeScorer, ScoredContext
|
||||||
|
from .priority import PriorityScorer
|
||||||
|
from .recency import RecencyScorer
|
||||||
|
from .relevance import RelevanceScorer
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"BaseScorer",
|
||||||
|
"CompositeScorer",
|
||||||
|
"PriorityScorer",
|
||||||
|
"RecencyScorer",
|
||||||
|
"RelevanceScorer",
|
||||||
|
"ScoredContext",
|
||||||
|
"ScorerProtocol",
|
||||||
|
]
|
||||||
99
backend/app/services/context/scoring/base.py
Normal file
99
backend/app/services/context/scoring/base.py
Normal file
@@ -0,0 +1,99 @@
|
|||||||
|
"""
|
||||||
|
Base Scorer Protocol and Types.
|
||||||
|
|
||||||
|
Defines the interface for context scoring implementations.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from typing import TYPE_CHECKING, Any, Protocol, runtime_checkable
|
||||||
|
|
||||||
|
from ..types import BaseContext
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
@runtime_checkable
|
||||||
|
class ScorerProtocol(Protocol):
|
||||||
|
"""Protocol for context scorers."""
|
||||||
|
|
||||||
|
async def score(
|
||||||
|
self,
|
||||||
|
context: BaseContext,
|
||||||
|
query: str,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> float:
|
||||||
|
"""
|
||||||
|
Score a context item.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
context: Context to score
|
||||||
|
query: Query to score against
|
||||||
|
**kwargs: Additional scoring parameters
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Score between 0.0 and 1.0
|
||||||
|
"""
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
|
class BaseScorer(ABC):
|
||||||
|
"""
|
||||||
|
Abstract base class for context scorers.
|
||||||
|
|
||||||
|
Provides common functionality and interface for
|
||||||
|
different scoring strategies.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, weight: float = 1.0) -> None:
|
||||||
|
"""
|
||||||
|
Initialize scorer.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
weight: Weight for this scorer in composite scoring
|
||||||
|
"""
|
||||||
|
self._weight = weight
|
||||||
|
|
||||||
|
@property
|
||||||
|
def weight(self) -> float:
|
||||||
|
"""Get scorer weight."""
|
||||||
|
return self._weight
|
||||||
|
|
||||||
|
@weight.setter
|
||||||
|
def weight(self, value: float) -> None:
|
||||||
|
"""Set scorer weight."""
|
||||||
|
if not 0.0 <= value <= 1.0:
|
||||||
|
raise ValueError("Weight must be between 0.0 and 1.0")
|
||||||
|
self._weight = value
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def score(
|
||||||
|
self,
|
||||||
|
context: BaseContext,
|
||||||
|
query: str,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> float:
|
||||||
|
"""
|
||||||
|
Score a context item.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
context: Context to score
|
||||||
|
query: Query to score against
|
||||||
|
**kwargs: Additional scoring parameters
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Score between 0.0 and 1.0
|
||||||
|
"""
|
||||||
|
...
|
||||||
|
|
||||||
|
def normalize_score(self, score: float) -> float:
|
||||||
|
"""
|
||||||
|
Normalize score to [0.0, 1.0] range.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
score: Raw score
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Normalized score
|
||||||
|
"""
|
||||||
|
return max(0.0, min(1.0, score))
|
||||||
368
backend/app/services/context/scoring/composite.py
Normal file
368
backend/app/services/context/scoring/composite.py
Normal file
@@ -0,0 +1,368 @@
|
|||||||
|
"""
|
||||||
|
Composite Scorer for Context Management.
|
||||||
|
|
||||||
|
Combines multiple scoring strategies with configurable weights.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
import time
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import TYPE_CHECKING, Any
|
||||||
|
|
||||||
|
from ..config import ContextSettings, get_context_settings
|
||||||
|
from ..types import BaseContext
|
||||||
|
from .priority import PriorityScorer
|
||||||
|
from .recency import RecencyScorer
|
||||||
|
from .relevance import RelevanceScorer
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from app.services.mcp.client_manager import MCPClientManager
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ScoredContext:
|
||||||
|
"""Context with computed scores."""
|
||||||
|
|
||||||
|
context: BaseContext
|
||||||
|
composite_score: float
|
||||||
|
relevance_score: float = 0.0
|
||||||
|
recency_score: float = 0.0
|
||||||
|
priority_score: float = 0.0
|
||||||
|
|
||||||
|
def __lt__(self, other: "ScoredContext") -> bool:
|
||||||
|
"""Enable sorting by composite score."""
|
||||||
|
return self.composite_score < other.composite_score
|
||||||
|
|
||||||
|
def __gt__(self, other: "ScoredContext") -> bool:
|
||||||
|
"""Enable sorting by composite score."""
|
||||||
|
return self.composite_score > other.composite_score
|
||||||
|
|
||||||
|
|
||||||
|
class CompositeScorer:
|
||||||
|
"""
|
||||||
|
Combines multiple scoring strategies.
|
||||||
|
|
||||||
|
Weights:
|
||||||
|
- relevance: How well content matches the query
|
||||||
|
- recency: How recent the content is
|
||||||
|
- priority: Explicit priority assignments
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
mcp_manager: "MCPClientManager | None" = None,
|
||||||
|
settings: ContextSettings | None = None,
|
||||||
|
relevance_weight: float | None = None,
|
||||||
|
recency_weight: float | None = None,
|
||||||
|
priority_weight: float | None = None,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Initialize composite scorer.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
mcp_manager: MCP manager for semantic scoring
|
||||||
|
settings: Context settings (uses default if None)
|
||||||
|
relevance_weight: Override relevance weight
|
||||||
|
recency_weight: Override recency weight
|
||||||
|
priority_weight: Override priority weight
|
||||||
|
"""
|
||||||
|
self._settings = settings or get_context_settings()
|
||||||
|
weights = self._settings.get_scoring_weights()
|
||||||
|
|
||||||
|
self._relevance_weight = (
|
||||||
|
relevance_weight if relevance_weight is not None else weights["relevance"]
|
||||||
|
)
|
||||||
|
self._recency_weight = (
|
||||||
|
recency_weight if recency_weight is not None else weights["recency"]
|
||||||
|
)
|
||||||
|
self._priority_weight = (
|
||||||
|
priority_weight if priority_weight is not None else weights["priority"]
|
||||||
|
)
|
||||||
|
|
||||||
|
# Initialize scorers
|
||||||
|
self._relevance_scorer = RelevanceScorer(
|
||||||
|
mcp_manager=mcp_manager,
|
||||||
|
weight=self._relevance_weight,
|
||||||
|
)
|
||||||
|
self._recency_scorer = RecencyScorer(weight=self._recency_weight)
|
||||||
|
self._priority_scorer = PriorityScorer(weight=self._priority_weight)
|
||||||
|
|
||||||
|
# Per-context locks to prevent race conditions during parallel scoring
|
||||||
|
# Uses dict with (lock, last_used_time) tuples for cleanup
|
||||||
|
self._context_locks: dict[str, tuple[asyncio.Lock, float]] = {}
|
||||||
|
self._locks_lock = asyncio.Lock() # Lock to protect _context_locks access
|
||||||
|
self._max_locks = 1000 # Maximum locks to keep (prevent memory growth)
|
||||||
|
self._lock_ttl = 60.0 # Seconds before a lock can be cleaned up
|
||||||
|
|
||||||
|
def set_mcp_manager(self, mcp_manager: "MCPClientManager") -> None:
|
||||||
|
"""Set MCP manager for semantic scoring."""
|
||||||
|
self._relevance_scorer.set_mcp_manager(mcp_manager)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def weights(self) -> dict[str, float]:
|
||||||
|
"""Get current scoring weights."""
|
||||||
|
return {
|
||||||
|
"relevance": self._relevance_weight,
|
||||||
|
"recency": self._recency_weight,
|
||||||
|
"priority": self._priority_weight,
|
||||||
|
}
|
||||||
|
|
||||||
|
def update_weights(
|
||||||
|
self,
|
||||||
|
relevance: float | None = None,
|
||||||
|
recency: float | None = None,
|
||||||
|
priority: float | None = None,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Update scoring weights.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
relevance: New relevance weight
|
||||||
|
recency: New recency weight
|
||||||
|
priority: New priority weight
|
||||||
|
"""
|
||||||
|
if relevance is not None:
|
||||||
|
self._relevance_weight = max(0.0, min(1.0, relevance))
|
||||||
|
self._relevance_scorer.weight = self._relevance_weight
|
||||||
|
|
||||||
|
if recency is not None:
|
||||||
|
self._recency_weight = max(0.0, min(1.0, recency))
|
||||||
|
self._recency_scorer.weight = self._recency_weight
|
||||||
|
|
||||||
|
if priority is not None:
|
||||||
|
self._priority_weight = max(0.0, min(1.0, priority))
|
||||||
|
self._priority_scorer.weight = self._priority_weight
|
||||||
|
|
||||||
|
async def _get_context_lock(self, context_id: str) -> asyncio.Lock:
|
||||||
|
"""
|
||||||
|
Get or create a lock for a specific context.
|
||||||
|
|
||||||
|
Thread-safe access to per-context locks prevents race conditions
|
||||||
|
when the same context is scored concurrently. Includes automatic
|
||||||
|
cleanup of old locks to prevent memory growth.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
context_id: The context ID to get a lock for
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
asyncio.Lock for the context
|
||||||
|
"""
|
||||||
|
now = time.time()
|
||||||
|
|
||||||
|
# Fast path: check if lock exists without acquiring main lock
|
||||||
|
# NOTE: We only READ here - no writes to avoid race conditions
|
||||||
|
# with cleanup. The timestamp will be updated in the slow path
|
||||||
|
# if the lock is still valid.
|
||||||
|
lock_entry = self._context_locks.get(context_id)
|
||||||
|
if lock_entry is not None:
|
||||||
|
lock, _ = lock_entry
|
||||||
|
# Return the lock but defer timestamp update to avoid race
|
||||||
|
# The lock is still valid; timestamp update is best-effort
|
||||||
|
return lock
|
||||||
|
|
||||||
|
# Slow path: create lock or update timestamp while holding main lock
|
||||||
|
async with self._locks_lock:
|
||||||
|
# Double-check after acquiring lock - entry may have been
|
||||||
|
# created by another coroutine or deleted by cleanup
|
||||||
|
lock_entry = self._context_locks.get(context_id)
|
||||||
|
if lock_entry is not None:
|
||||||
|
lock, _ = lock_entry
|
||||||
|
# Safe to update timestamp here since we hold the lock
|
||||||
|
self._context_locks[context_id] = (lock, now)
|
||||||
|
return lock
|
||||||
|
|
||||||
|
# Cleanup old locks if we have too many
|
||||||
|
if len(self._context_locks) >= self._max_locks:
|
||||||
|
self._cleanup_old_locks(now)
|
||||||
|
|
||||||
|
# Create new lock
|
||||||
|
new_lock = asyncio.Lock()
|
||||||
|
self._context_locks[context_id] = (new_lock, now)
|
||||||
|
return new_lock
|
||||||
|
|
||||||
|
def _cleanup_old_locks(self, now: float) -> None:
|
||||||
|
"""
|
||||||
|
Remove old locks that haven't been used recently.
|
||||||
|
|
||||||
|
Called while holding _locks_lock. Removes locks older than _lock_ttl,
|
||||||
|
but only if they're not currently held.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
now: Current timestamp for age calculation
|
||||||
|
"""
|
||||||
|
cutoff = now - self._lock_ttl
|
||||||
|
to_remove = []
|
||||||
|
|
||||||
|
for context_id, (lock, last_used) in self._context_locks.items():
|
||||||
|
# Only remove if old AND not currently held
|
||||||
|
if last_used < cutoff and not lock.locked():
|
||||||
|
to_remove.append(context_id)
|
||||||
|
|
||||||
|
# Remove oldest 50% if still over limit after TTL filtering
|
||||||
|
if len(self._context_locks) - len(to_remove) >= self._max_locks:
|
||||||
|
# Sort by last used time and mark oldest for removal
|
||||||
|
sorted_entries = sorted(
|
||||||
|
self._context_locks.items(),
|
||||||
|
key=lambda x: x[1][1], # Sort by last_used time
|
||||||
|
)
|
||||||
|
# Remove oldest 50% that aren't locked
|
||||||
|
target_remove = len(self._context_locks) // 2
|
||||||
|
for context_id, (lock, _) in sorted_entries:
|
||||||
|
if len(to_remove) >= target_remove:
|
||||||
|
break
|
||||||
|
if context_id not in to_remove and not lock.locked():
|
||||||
|
to_remove.append(context_id)
|
||||||
|
|
||||||
|
for context_id in to_remove:
|
||||||
|
del self._context_locks[context_id]
|
||||||
|
|
||||||
|
if to_remove:
|
||||||
|
logger.debug(f"Cleaned up {len(to_remove)} context locks")
|
||||||
|
|
||||||
|
async def score(
|
||||||
|
self,
|
||||||
|
context: BaseContext,
|
||||||
|
query: str,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> float:
|
||||||
|
"""
|
||||||
|
Compute composite score for a context.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
context: Context to score
|
||||||
|
query: Query to score against
|
||||||
|
**kwargs: Additional scoring parameters
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Composite score between 0.0 and 1.0
|
||||||
|
"""
|
||||||
|
scored = await self.score_with_details(context, query, **kwargs)
|
||||||
|
return scored.composite_score
|
||||||
|
|
||||||
|
async def score_with_details(
|
||||||
|
self,
|
||||||
|
context: BaseContext,
|
||||||
|
query: str,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> ScoredContext:
|
||||||
|
"""
|
||||||
|
Compute composite score with individual scores.
|
||||||
|
|
||||||
|
Uses per-context locking to prevent race conditions when the same
|
||||||
|
context is scored concurrently in parallel scoring operations.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
context: Context to score
|
||||||
|
query: Query to score against
|
||||||
|
**kwargs: Additional scoring parameters
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ScoredContext with all scores
|
||||||
|
"""
|
||||||
|
# Get lock for this specific context to prevent race conditions
|
||||||
|
# within concurrent scoring operations for the same query
|
||||||
|
context_lock = await self._get_context_lock(context.id)
|
||||||
|
|
||||||
|
async with context_lock:
|
||||||
|
# Compute individual scores in parallel
|
||||||
|
# Note: We do NOT cache scores on the context because scores are
|
||||||
|
# query-dependent. Caching without considering the query would
|
||||||
|
# return incorrect scores for different queries.
|
||||||
|
relevance_task = self._relevance_scorer.score(context, query, **kwargs)
|
||||||
|
recency_task = self._recency_scorer.score(context, query, **kwargs)
|
||||||
|
priority_task = self._priority_scorer.score(context, query, **kwargs)
|
||||||
|
|
||||||
|
relevance_score, recency_score, priority_score = await asyncio.gather(
|
||||||
|
relevance_task, recency_task, priority_task
|
||||||
|
)
|
||||||
|
|
||||||
|
# Compute weighted composite
|
||||||
|
total_weight = (
|
||||||
|
self._relevance_weight + self._recency_weight + self._priority_weight
|
||||||
|
)
|
||||||
|
|
||||||
|
if total_weight > 0:
|
||||||
|
composite = (
|
||||||
|
relevance_score * self._relevance_weight
|
||||||
|
+ recency_score * self._recency_weight
|
||||||
|
+ priority_score * self._priority_weight
|
||||||
|
) / total_weight
|
||||||
|
else:
|
||||||
|
composite = 0.0
|
||||||
|
|
||||||
|
return ScoredContext(
|
||||||
|
context=context,
|
||||||
|
composite_score=composite,
|
||||||
|
relevance_score=relevance_score,
|
||||||
|
recency_score=recency_score,
|
||||||
|
priority_score=priority_score,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def score_batch(
|
||||||
|
self,
|
||||||
|
contexts: list[BaseContext],
|
||||||
|
query: str,
|
||||||
|
parallel: bool = True,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> list[ScoredContext]:
|
||||||
|
"""
|
||||||
|
Score multiple contexts.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
contexts: Contexts to score
|
||||||
|
query: Query to score against
|
||||||
|
parallel: Whether to score in parallel
|
||||||
|
**kwargs: Additional scoring parameters
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of ScoredContext (same order as input)
|
||||||
|
"""
|
||||||
|
if parallel:
|
||||||
|
tasks = [self.score_with_details(ctx, query, **kwargs) for ctx in contexts]
|
||||||
|
return await asyncio.gather(*tasks)
|
||||||
|
else:
|
||||||
|
results = []
|
||||||
|
for ctx in contexts:
|
||||||
|
scored = await self.score_with_details(ctx, query, **kwargs)
|
||||||
|
results.append(scored)
|
||||||
|
return results
|
||||||
|
|
||||||
|
async def rank(
|
||||||
|
self,
|
||||||
|
contexts: list[BaseContext],
|
||||||
|
query: str,
|
||||||
|
limit: int | None = None,
|
||||||
|
min_score: float = 0.0,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> list[ScoredContext]:
|
||||||
|
"""
|
||||||
|
Score and rank contexts.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
contexts: Contexts to rank
|
||||||
|
query: Query to rank against
|
||||||
|
limit: Maximum number of results
|
||||||
|
min_score: Minimum score threshold
|
||||||
|
**kwargs: Additional scoring parameters
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Sorted list of ScoredContext (highest first)
|
||||||
|
"""
|
||||||
|
# Score all contexts
|
||||||
|
scored = await self.score_batch(contexts, query, **kwargs)
|
||||||
|
|
||||||
|
# Filter by minimum score
|
||||||
|
if min_score > 0:
|
||||||
|
scored = [s for s in scored if s.composite_score >= min_score]
|
||||||
|
|
||||||
|
# Sort by score (highest first)
|
||||||
|
scored.sort(reverse=True)
|
||||||
|
|
||||||
|
# Apply limit
|
||||||
|
if limit is not None:
|
||||||
|
scored = scored[:limit]
|
||||||
|
|
||||||
|
return scored
|
||||||
135
backend/app/services/context/scoring/priority.py
Normal file
135
backend/app/services/context/scoring/priority.py
Normal file
@@ -0,0 +1,135 @@
|
|||||||
|
"""
|
||||||
|
Priority Scorer for Context Management.
|
||||||
|
|
||||||
|
Scores context based on assigned priority levels.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import Any, ClassVar
|
||||||
|
|
||||||
|
from ..types import BaseContext, ContextType
|
||||||
|
from .base import BaseScorer
|
||||||
|
|
||||||
|
|
||||||
|
class PriorityScorer(BaseScorer):
|
||||||
|
"""
|
||||||
|
Scores context based on priority levels.
|
||||||
|
|
||||||
|
Converts priority enum values to normalized scores.
|
||||||
|
Also applies type-based priority bonuses.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Default priority bonuses by context type
|
||||||
|
DEFAULT_TYPE_BONUSES: ClassVar[dict[ContextType, float]] = {
|
||||||
|
ContextType.SYSTEM: 0.2, # System prompts get a boost
|
||||||
|
ContextType.TASK: 0.15, # Current task is important
|
||||||
|
ContextType.TOOL: 0.1, # Recent tool results matter
|
||||||
|
ContextType.KNOWLEDGE: 0.0, # Knowledge scored by relevance
|
||||||
|
ContextType.CONVERSATION: 0.0, # Conversation scored by recency
|
||||||
|
}
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
weight: float = 1.0,
|
||||||
|
type_bonuses: dict[ContextType, float] | None = None,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Initialize priority scorer.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
weight: Scorer weight for composite scoring
|
||||||
|
type_bonuses: Optional context-type priority bonuses
|
||||||
|
"""
|
||||||
|
super().__init__(weight)
|
||||||
|
self._type_bonuses = type_bonuses or self.DEFAULT_TYPE_BONUSES.copy()
|
||||||
|
|
||||||
|
async def score(
|
||||||
|
self,
|
||||||
|
context: BaseContext,
|
||||||
|
query: str,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> float:
|
||||||
|
"""
|
||||||
|
Score context based on priority.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
context: Context to score
|
||||||
|
query: Query (not used for priority, kept for interface)
|
||||||
|
**kwargs: Additional parameters
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Priority score between 0.0 and 1.0
|
||||||
|
"""
|
||||||
|
# Get base priority score
|
||||||
|
priority_value = context.priority
|
||||||
|
base_score = self._priority_to_score(priority_value)
|
||||||
|
|
||||||
|
# Apply type bonus
|
||||||
|
context_type = context.get_type()
|
||||||
|
bonus = self._type_bonuses.get(context_type, 0.0)
|
||||||
|
|
||||||
|
return self.normalize_score(base_score + bonus)
|
||||||
|
|
||||||
|
def _priority_to_score(self, priority: int) -> float:
|
||||||
|
"""
|
||||||
|
Convert priority value to normalized score.
|
||||||
|
|
||||||
|
Priority values (from ContextPriority):
|
||||||
|
- CRITICAL (100) -> 1.0
|
||||||
|
- HIGH (80) -> 0.8
|
||||||
|
- NORMAL (50) -> 0.5
|
||||||
|
- LOW (20) -> 0.2
|
||||||
|
- MINIMAL (0) -> 0.0
|
||||||
|
|
||||||
|
Args:
|
||||||
|
priority: Priority value (0-100)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Normalized score (0.0-1.0)
|
||||||
|
"""
|
||||||
|
# Clamp to valid range
|
||||||
|
clamped = max(0, min(100, priority))
|
||||||
|
return clamped / 100.0
|
||||||
|
|
||||||
|
def get_type_bonus(self, context_type: ContextType) -> float:
|
||||||
|
"""
|
||||||
|
Get priority bonus for a context type.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
context_type: Context type
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Bonus value
|
||||||
|
"""
|
||||||
|
return self._type_bonuses.get(context_type, 0.0)
|
||||||
|
|
||||||
|
def set_type_bonus(self, context_type: ContextType, bonus: float) -> None:
|
||||||
|
"""
|
||||||
|
Set priority bonus for a context type.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
context_type: Context type
|
||||||
|
bonus: Bonus value (0.0-1.0)
|
||||||
|
"""
|
||||||
|
if not 0.0 <= bonus <= 1.0:
|
||||||
|
raise ValueError("Bonus must be between 0.0 and 1.0")
|
||||||
|
self._type_bonuses[context_type] = bonus
|
||||||
|
|
||||||
|
async def score_batch(
|
||||||
|
self,
|
||||||
|
contexts: list[BaseContext],
|
||||||
|
query: str,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> list[float]:
|
||||||
|
"""
|
||||||
|
Score multiple contexts.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
contexts: Contexts to score
|
||||||
|
query: Query (not used)
|
||||||
|
**kwargs: Additional parameters
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of scores (same order as input)
|
||||||
|
"""
|
||||||
|
# Priority scoring is fast, no async needed
|
||||||
|
return [await self.score(ctx, query, **kwargs) for ctx in contexts]
|
||||||
141
backend/app/services/context/scoring/recency.py
Normal file
141
backend/app/services/context/scoring/recency.py
Normal file
@@ -0,0 +1,141 @@
|
|||||||
|
"""
|
||||||
|
Recency Scorer for Context Management.
|
||||||
|
|
||||||
|
Scores context based on how recent it is.
|
||||||
|
More recent content gets higher scores.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import math
|
||||||
|
from datetime import UTC, datetime
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from ..types import BaseContext, ContextType
|
||||||
|
from .base import BaseScorer
|
||||||
|
|
||||||
|
|
||||||
|
class RecencyScorer(BaseScorer):
|
||||||
|
"""
|
||||||
|
Scores context based on recency.
|
||||||
|
|
||||||
|
Uses exponential decay to score content based on age.
|
||||||
|
More recent content scores higher.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
weight: float = 1.0,
|
||||||
|
half_life_hours: float = 24.0,
|
||||||
|
type_half_lives: dict[ContextType, float] | None = None,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Initialize recency scorer.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
weight: Scorer weight for composite scoring
|
||||||
|
half_life_hours: Default hours until score decays to 0.5
|
||||||
|
type_half_lives: Optional context-type-specific half lives
|
||||||
|
"""
|
||||||
|
super().__init__(weight)
|
||||||
|
self._half_life_hours = half_life_hours
|
||||||
|
self._type_half_lives = type_half_lives or {}
|
||||||
|
|
||||||
|
# Set sensible defaults for context types
|
||||||
|
if ContextType.CONVERSATION not in self._type_half_lives:
|
||||||
|
self._type_half_lives[ContextType.CONVERSATION] = 1.0 # 1 hour
|
||||||
|
if ContextType.TOOL not in self._type_half_lives:
|
||||||
|
self._type_half_lives[ContextType.TOOL] = 0.5 # 30 minutes
|
||||||
|
if ContextType.KNOWLEDGE not in self._type_half_lives:
|
||||||
|
self._type_half_lives[ContextType.KNOWLEDGE] = 168.0 # 1 week
|
||||||
|
if ContextType.SYSTEM not in self._type_half_lives:
|
||||||
|
self._type_half_lives[ContextType.SYSTEM] = 720.0 # 30 days
|
||||||
|
if ContextType.TASK not in self._type_half_lives:
|
||||||
|
self._type_half_lives[ContextType.TASK] = 24.0 # 1 day
|
||||||
|
|
||||||
|
async def score(
|
||||||
|
self,
|
||||||
|
context: BaseContext,
|
||||||
|
query: str,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> float:
|
||||||
|
"""
|
||||||
|
Score context based on recency.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
context: Context to score
|
||||||
|
query: Query (not used for recency, kept for interface)
|
||||||
|
**kwargs: Additional parameters
|
||||||
|
- reference_time: Time to measure recency from (default: now)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Recency score between 0.0 and 1.0
|
||||||
|
"""
|
||||||
|
reference_time = kwargs.get("reference_time")
|
||||||
|
if reference_time is None:
|
||||||
|
reference_time = datetime.now(UTC)
|
||||||
|
elif reference_time.tzinfo is None:
|
||||||
|
reference_time = reference_time.replace(tzinfo=UTC)
|
||||||
|
|
||||||
|
# Ensure context timestamp is timezone-aware
|
||||||
|
context_time = context.timestamp
|
||||||
|
if context_time.tzinfo is None:
|
||||||
|
context_time = context_time.replace(tzinfo=UTC)
|
||||||
|
|
||||||
|
# Calculate age in hours
|
||||||
|
age = reference_time - context_time
|
||||||
|
age_hours = max(0, age.total_seconds() / 3600)
|
||||||
|
|
||||||
|
# Get half-life for this context type
|
||||||
|
context_type = context.get_type()
|
||||||
|
half_life = self._type_half_lives.get(context_type, self._half_life_hours)
|
||||||
|
|
||||||
|
# Exponential decay
|
||||||
|
decay_factor = math.exp(-math.log(2) * age_hours / half_life)
|
||||||
|
|
||||||
|
return self.normalize_score(decay_factor)
|
||||||
|
|
||||||
|
def get_half_life(self, context_type: ContextType) -> float:
|
||||||
|
"""
|
||||||
|
Get half-life for a context type.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
context_type: Context type to get half-life for
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Half-life in hours
|
||||||
|
"""
|
||||||
|
return self._type_half_lives.get(context_type, self._half_life_hours)
|
||||||
|
|
||||||
|
def set_half_life(self, context_type: ContextType, hours: float) -> None:
|
||||||
|
"""
|
||||||
|
Set half-life for a context type.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
context_type: Context type to set half-life for
|
||||||
|
hours: Half-life in hours
|
||||||
|
"""
|
||||||
|
if hours <= 0:
|
||||||
|
raise ValueError("Half-life must be positive")
|
||||||
|
self._type_half_lives[context_type] = hours
|
||||||
|
|
||||||
|
async def score_batch(
|
||||||
|
self,
|
||||||
|
contexts: list[BaseContext],
|
||||||
|
query: str,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> list[float]:
|
||||||
|
"""
|
||||||
|
Score multiple contexts.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
contexts: Contexts to score
|
||||||
|
query: Query (not used)
|
||||||
|
**kwargs: Additional parameters
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of scores (same order as input)
|
||||||
|
"""
|
||||||
|
scores = []
|
||||||
|
for context in contexts:
|
||||||
|
score = await self.score(context, query, **kwargs)
|
||||||
|
scores.append(score)
|
||||||
|
return scores
|
||||||
220
backend/app/services/context/scoring/relevance.py
Normal file
220
backend/app/services/context/scoring/relevance.py
Normal file
@@ -0,0 +1,220 @@
|
|||||||
|
"""
|
||||||
|
Relevance Scorer for Context Management.
|
||||||
|
|
||||||
|
Scores context based on semantic similarity to the query.
|
||||||
|
Uses Knowledge Base embeddings when available.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import re
|
||||||
|
from typing import TYPE_CHECKING, Any
|
||||||
|
|
||||||
|
from ..config import ContextSettings, get_context_settings
|
||||||
|
from ..types import BaseContext, KnowledgeContext
|
||||||
|
from .base import BaseScorer
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from app.services.mcp.client_manager import MCPClientManager
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class RelevanceScorer(BaseScorer):
|
||||||
|
"""
|
||||||
|
Scores context based on relevance to query.
|
||||||
|
|
||||||
|
Uses multiple strategies:
|
||||||
|
1. Pre-computed scores (from RAG results)
|
||||||
|
2. MCP-based semantic similarity (via Knowledge Base)
|
||||||
|
3. Keyword matching fallback
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
mcp_manager: "MCPClientManager | None" = None,
|
||||||
|
weight: float = 1.0,
|
||||||
|
keyword_fallback_weight: float | None = None,
|
||||||
|
semantic_max_chars: int | None = None,
|
||||||
|
settings: ContextSettings | None = None,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Initialize relevance scorer.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
mcp_manager: MCP manager for Knowledge Base calls
|
||||||
|
weight: Scorer weight for composite scoring
|
||||||
|
keyword_fallback_weight: Max score for keyword-based fallback (overrides settings)
|
||||||
|
semantic_max_chars: Max content length for semantic similarity (overrides settings)
|
||||||
|
settings: Context settings (uses global if None)
|
||||||
|
"""
|
||||||
|
super().__init__(weight)
|
||||||
|
self._settings = settings or get_context_settings()
|
||||||
|
self._mcp = mcp_manager
|
||||||
|
|
||||||
|
# Use provided values or fall back to settings
|
||||||
|
self._keyword_fallback_weight = (
|
||||||
|
keyword_fallback_weight
|
||||||
|
if keyword_fallback_weight is not None
|
||||||
|
else self._settings.relevance_keyword_fallback_weight
|
||||||
|
)
|
||||||
|
self._semantic_max_chars = (
|
||||||
|
semantic_max_chars
|
||||||
|
if semantic_max_chars is not None
|
||||||
|
else self._settings.relevance_semantic_max_chars
|
||||||
|
)
|
||||||
|
|
||||||
|
def set_mcp_manager(self, mcp_manager: "MCPClientManager") -> None:
|
||||||
|
"""Set MCP manager for semantic scoring."""
|
||||||
|
self._mcp = mcp_manager
|
||||||
|
|
||||||
|
async def score(
|
||||||
|
self,
|
||||||
|
context: BaseContext,
|
||||||
|
query: str,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> float:
|
||||||
|
"""
|
||||||
|
Score context relevance to query.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
context: Context to score
|
||||||
|
query: Query to score against
|
||||||
|
**kwargs: Additional parameters
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Relevance score between 0.0 and 1.0
|
||||||
|
"""
|
||||||
|
# 1. Check for pre-computed relevance score
|
||||||
|
if (
|
||||||
|
isinstance(context, KnowledgeContext)
|
||||||
|
and context.relevance_score is not None
|
||||||
|
):
|
||||||
|
return self.normalize_score(context.relevance_score)
|
||||||
|
|
||||||
|
# 2. Check metadata for score
|
||||||
|
if "relevance_score" in context.metadata:
|
||||||
|
return self.normalize_score(context.metadata["relevance_score"])
|
||||||
|
|
||||||
|
if "score" in context.metadata:
|
||||||
|
return self.normalize_score(context.metadata["score"])
|
||||||
|
|
||||||
|
# 3. Try MCP-based semantic similarity (if compute_similarity tool is available)
|
||||||
|
# Note: This requires the knowledge-base MCP server to implement compute_similarity
|
||||||
|
if self._mcp is not None:
|
||||||
|
try:
|
||||||
|
score = await self._compute_semantic_similarity(context, query)
|
||||||
|
if score is not None:
|
||||||
|
return score
|
||||||
|
except Exception as e:
|
||||||
|
# Log at debug level since this is expected if compute_similarity
|
||||||
|
# tool is not implemented in the Knowledge Base server
|
||||||
|
logger.debug(
|
||||||
|
f"Semantic scoring unavailable, using keyword fallback: {e}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 4. Fall back to keyword matching
|
||||||
|
return self._compute_keyword_score(context, query)
|
||||||
|
|
||||||
|
async def _compute_semantic_similarity(
|
||||||
|
self,
|
||||||
|
context: BaseContext,
|
||||||
|
query: str,
|
||||||
|
) -> float | None:
|
||||||
|
"""
|
||||||
|
Compute semantic similarity using Knowledge Base embeddings.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
context: Context to score
|
||||||
|
query: Query to compare
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Similarity score or None if unavailable
|
||||||
|
"""
|
||||||
|
if self._mcp is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Use Knowledge Base's search capability to compute similarity
|
||||||
|
result = await self._mcp.call_tool(
|
||||||
|
server="knowledge-base",
|
||||||
|
tool="compute_similarity",
|
||||||
|
args={
|
||||||
|
"text1": query,
|
||||||
|
"text2": context.content[
|
||||||
|
: self._semantic_max_chars
|
||||||
|
], # Limit content length
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
if result.success and isinstance(result.data, dict):
|
||||||
|
similarity = result.data.get("similarity")
|
||||||
|
if similarity is not None:
|
||||||
|
return self.normalize_score(float(similarity))
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.debug(f"Semantic similarity computation failed: {e}")
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _compute_keyword_score(
|
||||||
|
self,
|
||||||
|
context: BaseContext,
|
||||||
|
query: str,
|
||||||
|
) -> float:
|
||||||
|
"""
|
||||||
|
Compute relevance score based on keyword matching.
|
||||||
|
|
||||||
|
Simple but fast fallback when semantic search is unavailable.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
context: Context to score
|
||||||
|
query: Query to match
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Keyword-based relevance score
|
||||||
|
"""
|
||||||
|
if not query or not context.content:
|
||||||
|
return 0.0
|
||||||
|
|
||||||
|
# Extract keywords from query
|
||||||
|
query_lower = query.lower()
|
||||||
|
content_lower = context.content.lower()
|
||||||
|
|
||||||
|
# Simple word tokenization
|
||||||
|
query_words = set(re.findall(r"\b\w{3,}\b", query_lower))
|
||||||
|
content_words = set(re.findall(r"\b\w{3,}\b", content_lower))
|
||||||
|
|
||||||
|
if not query_words:
|
||||||
|
return 0.0
|
||||||
|
|
||||||
|
# Calculate overlap
|
||||||
|
common_words = query_words & content_words
|
||||||
|
overlap_ratio = len(common_words) / len(query_words)
|
||||||
|
|
||||||
|
# Apply fallback weight ceiling
|
||||||
|
return self.normalize_score(overlap_ratio * self._keyword_fallback_weight)
|
||||||
|
|
||||||
|
async def score_batch(
|
||||||
|
self,
|
||||||
|
contexts: list[BaseContext],
|
||||||
|
query: str,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> list[float]:
|
||||||
|
"""
|
||||||
|
Score multiple contexts in parallel.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
contexts: Contexts to score
|
||||||
|
query: Query to score against
|
||||||
|
**kwargs: Additional parameters
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of scores (same order as input)
|
||||||
|
"""
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
if not contexts:
|
||||||
|
return []
|
||||||
|
|
||||||
|
tasks = [self.score(context, query, **kwargs) for context in contexts]
|
||||||
|
return await asyncio.gather(*tasks)
|
||||||
43
backend/app/services/context/types/__init__.py
Normal file
43
backend/app/services/context/types/__init__.py
Normal file
@@ -0,0 +1,43 @@
|
|||||||
|
"""
|
||||||
|
Context Types Module.
|
||||||
|
|
||||||
|
Provides all context types used in the Context Management Engine.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from .base import (
|
||||||
|
AssembledContext,
|
||||||
|
BaseContext,
|
||||||
|
ContextPriority,
|
||||||
|
ContextType,
|
||||||
|
)
|
||||||
|
from .conversation import (
|
||||||
|
ConversationContext,
|
||||||
|
MessageRole,
|
||||||
|
)
|
||||||
|
from .knowledge import KnowledgeContext
|
||||||
|
from .system import SystemContext
|
||||||
|
from .task import (
|
||||||
|
TaskComplexity,
|
||||||
|
TaskContext,
|
||||||
|
TaskStatus,
|
||||||
|
)
|
||||||
|
from .tool import (
|
||||||
|
ToolContext,
|
||||||
|
ToolResultStatus,
|
||||||
|
)
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"AssembledContext",
|
||||||
|
"BaseContext",
|
||||||
|
"ContextPriority",
|
||||||
|
"ContextType",
|
||||||
|
"ConversationContext",
|
||||||
|
"KnowledgeContext",
|
||||||
|
"MessageRole",
|
||||||
|
"SystemContext",
|
||||||
|
"TaskComplexity",
|
||||||
|
"TaskContext",
|
||||||
|
"TaskStatus",
|
||||||
|
"ToolContext",
|
||||||
|
"ToolResultStatus",
|
||||||
|
]
|
||||||
347
backend/app/services/context/types/base.py
Normal file
347
backend/app/services/context/types/base.py
Normal file
@@ -0,0 +1,347 @@
|
|||||||
|
"""
|
||||||
|
Base Context Types and Enums.
|
||||||
|
|
||||||
|
Provides the foundation for all context types used in
|
||||||
|
the Context Management Engine.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from datetime import UTC, datetime
|
||||||
|
from enum import Enum
|
||||||
|
from typing import Any
|
||||||
|
from uuid import uuid4
|
||||||
|
|
||||||
|
|
||||||
|
class ContextType(str, Enum):
|
||||||
|
"""
|
||||||
|
Types of context that can be assembled.
|
||||||
|
|
||||||
|
Each type has specific handling, formatting, and
|
||||||
|
budget allocation rules.
|
||||||
|
"""
|
||||||
|
|
||||||
|
SYSTEM = "system"
|
||||||
|
TASK = "task"
|
||||||
|
KNOWLEDGE = "knowledge"
|
||||||
|
CONVERSATION = "conversation"
|
||||||
|
TOOL = "tool"
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_string(cls, value: str) -> "ContextType":
|
||||||
|
"""
|
||||||
|
Convert string to ContextType.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
value: String value
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ContextType enum value
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If value is not a valid context type
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
return cls(value.lower())
|
||||||
|
except ValueError:
|
||||||
|
valid = ", ".join(t.value for t in cls)
|
||||||
|
raise ValueError(f"Invalid context type '{value}'. Valid types: {valid}")
|
||||||
|
|
||||||
|
|
||||||
|
class ContextPriority(int, Enum):
|
||||||
|
"""
|
||||||
|
Priority levels for context ordering.
|
||||||
|
|
||||||
|
Higher values indicate higher priority.
|
||||||
|
"""
|
||||||
|
|
||||||
|
LOWEST = 0
|
||||||
|
LOW = 25
|
||||||
|
NORMAL = 50
|
||||||
|
HIGH = 75
|
||||||
|
HIGHEST = 100
|
||||||
|
CRITICAL = 150 # Never omit
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_int(cls, value: int) -> "ContextPriority":
|
||||||
|
"""
|
||||||
|
Get closest priority level for an integer.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
value: Integer priority value
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Closest ContextPriority enum value
|
||||||
|
"""
|
||||||
|
priorities = sorted(cls, key=lambda p: p.value)
|
||||||
|
for priority in reversed(priorities):
|
||||||
|
if value >= priority.value:
|
||||||
|
return priority
|
||||||
|
return cls.LOWEST
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(eq=False)
|
||||||
|
class BaseContext(ABC):
|
||||||
|
"""
|
||||||
|
Abstract base class for all context types.
|
||||||
|
|
||||||
|
Provides common fields and methods for context handling,
|
||||||
|
scoring, and serialization.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Required fields
|
||||||
|
content: str
|
||||||
|
source: str
|
||||||
|
|
||||||
|
# Optional fields with defaults
|
||||||
|
id: str = field(default_factory=lambda: str(uuid4()))
|
||||||
|
timestamp: datetime = field(default_factory=lambda: datetime.now(UTC))
|
||||||
|
priority: int = field(default=ContextPriority.NORMAL.value)
|
||||||
|
metadata: dict[str, Any] = field(default_factory=dict)
|
||||||
|
|
||||||
|
# Computed/cached fields
|
||||||
|
_token_count: int | None = field(default=None, repr=False)
|
||||||
|
_score: float | None = field(default=None, repr=False)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def token_count(self) -> int | None:
|
||||||
|
"""Get cached token count (None if not counted yet)."""
|
||||||
|
return self._token_count
|
||||||
|
|
||||||
|
@token_count.setter
|
||||||
|
def token_count(self, value: int) -> None:
|
||||||
|
"""Set token count."""
|
||||||
|
self._token_count = value
|
||||||
|
|
||||||
|
@property
|
||||||
|
def score(self) -> float | None:
|
||||||
|
"""Get cached score (None if not scored yet)."""
|
||||||
|
return self._score
|
||||||
|
|
||||||
|
@score.setter
|
||||||
|
def score(self, value: float) -> None:
|
||||||
|
"""Set score (clamped to 0.0-1.0)."""
|
||||||
|
self._score = max(0.0, min(1.0, value))
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get_type(self) -> ContextType:
|
||||||
|
"""
|
||||||
|
Get the type of this context.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ContextType enum value
|
||||||
|
"""
|
||||||
|
...
|
||||||
|
|
||||||
|
def get_age_seconds(self) -> float:
|
||||||
|
"""
|
||||||
|
Get age of context in seconds.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Age in seconds since creation
|
||||||
|
"""
|
||||||
|
now = datetime.now(UTC)
|
||||||
|
delta = now - self.timestamp
|
||||||
|
return delta.total_seconds()
|
||||||
|
|
||||||
|
def get_age_hours(self) -> float:
|
||||||
|
"""
|
||||||
|
Get age of context in hours.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Age in hours since creation
|
||||||
|
"""
|
||||||
|
return self.get_age_seconds() / 3600
|
||||||
|
|
||||||
|
def is_stale(self, max_age_hours: float = 168.0) -> bool:
|
||||||
|
"""
|
||||||
|
Check if context is stale.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
max_age_hours: Maximum age before considered stale (default 7 days)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if context is older than max_age_hours
|
||||||
|
"""
|
||||||
|
return self.get_age_hours() > max_age_hours
|
||||||
|
|
||||||
|
def to_dict(self) -> dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Convert context to dictionary for serialization.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary representation
|
||||||
|
"""
|
||||||
|
return {
|
||||||
|
"id": self.id,
|
||||||
|
"type": self.get_type().value,
|
||||||
|
"content": self.content,
|
||||||
|
"source": self.source,
|
||||||
|
"timestamp": self.timestamp.isoformat(),
|
||||||
|
"priority": self.priority,
|
||||||
|
"metadata": self.metadata,
|
||||||
|
"token_count": self._token_count,
|
||||||
|
"score": self._score,
|
||||||
|
}
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_dict(cls, data: dict[str, Any]) -> "BaseContext":
|
||||||
|
"""
|
||||||
|
Create context from dictionary.
|
||||||
|
|
||||||
|
Note: Subclasses should override this to return correct type.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
data: Dictionary with context data
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Context instance
|
||||||
|
"""
|
||||||
|
raise NotImplementedError("Subclasses must implement from_dict")
|
||||||
|
|
||||||
|
def truncate(self, max_tokens: int, suffix: str = "... [truncated]") -> str:
|
||||||
|
"""
|
||||||
|
Truncate content to fit within token limit.
|
||||||
|
|
||||||
|
This is a rough estimation based on characters.
|
||||||
|
For accurate truncation, use the TokenCalculator.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
max_tokens: Maximum tokens allowed
|
||||||
|
suffix: Suffix to append when truncated
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Truncated content
|
||||||
|
"""
|
||||||
|
if self._token_count is None or self._token_count <= max_tokens:
|
||||||
|
return self.content
|
||||||
|
|
||||||
|
# Rough estimation: 4 chars per token on average
|
||||||
|
estimated_chars = max_tokens * 4
|
||||||
|
suffix_chars = len(suffix)
|
||||||
|
|
||||||
|
if len(self.content) <= estimated_chars:
|
||||||
|
return self.content
|
||||||
|
|
||||||
|
truncated = self.content[: estimated_chars - suffix_chars]
|
||||||
|
# Try to break at word boundary
|
||||||
|
last_space = truncated.rfind(" ")
|
||||||
|
if last_space > estimated_chars * 0.8:
|
||||||
|
truncated = truncated[:last_space]
|
||||||
|
|
||||||
|
return truncated + suffix
|
||||||
|
|
||||||
|
def __hash__(self) -> int:
|
||||||
|
"""Hash based on ID for set/dict usage."""
|
||||||
|
return hash(self.id)
|
||||||
|
|
||||||
|
def __eq__(self, other: object) -> bool:
|
||||||
|
"""Equality based on ID."""
|
||||||
|
if not isinstance(other, BaseContext):
|
||||||
|
return False
|
||||||
|
return self.id == other.id
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class AssembledContext:
|
||||||
|
"""
|
||||||
|
Result of context assembly.
|
||||||
|
|
||||||
|
Contains the final formatted context ready for LLM consumption,
|
||||||
|
along with metadata about the assembly process.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Main content
|
||||||
|
content: str
|
||||||
|
total_tokens: int
|
||||||
|
|
||||||
|
# Assembly metadata
|
||||||
|
context_count: int
|
||||||
|
excluded_count: int = 0
|
||||||
|
assembly_time_ms: float = 0.0
|
||||||
|
model: str = ""
|
||||||
|
|
||||||
|
# Included contexts (optional - for inspection)
|
||||||
|
contexts: list["BaseContext"] = field(default_factory=list)
|
||||||
|
|
||||||
|
# Additional metadata from assembly
|
||||||
|
metadata: dict[str, Any] = field(default_factory=dict)
|
||||||
|
|
||||||
|
# Budget tracking
|
||||||
|
budget_total: int = 0
|
||||||
|
budget_used: int = 0
|
||||||
|
|
||||||
|
# Context breakdown
|
||||||
|
by_type: dict[str, int] = field(default_factory=dict)
|
||||||
|
|
||||||
|
# Cache info
|
||||||
|
cache_hit: bool = False
|
||||||
|
cache_key: str | None = None
|
||||||
|
|
||||||
|
# Aliases for backward compatibility
|
||||||
|
@property
|
||||||
|
def token_count(self) -> int:
|
||||||
|
"""Alias for total_tokens."""
|
||||||
|
return self.total_tokens
|
||||||
|
|
||||||
|
@property
|
||||||
|
def contexts_included(self) -> int:
|
||||||
|
"""Alias for context_count."""
|
||||||
|
return self.context_count
|
||||||
|
|
||||||
|
@property
|
||||||
|
def contexts_excluded(self) -> int:
|
||||||
|
"""Alias for excluded_count."""
|
||||||
|
return self.excluded_count
|
||||||
|
|
||||||
|
@property
|
||||||
|
def budget_utilization(self) -> float:
|
||||||
|
"""Get budget utilization percentage."""
|
||||||
|
if self.budget_total == 0:
|
||||||
|
return 0.0
|
||||||
|
return self.budget_used / self.budget_total
|
||||||
|
|
||||||
|
def to_dict(self) -> dict[str, Any]:
|
||||||
|
"""Convert to dictionary."""
|
||||||
|
return {
|
||||||
|
"content": self.content,
|
||||||
|
"total_tokens": self.total_tokens,
|
||||||
|
"context_count": self.context_count,
|
||||||
|
"excluded_count": self.excluded_count,
|
||||||
|
"assembly_time_ms": round(self.assembly_time_ms, 2),
|
||||||
|
"model": self.model,
|
||||||
|
"metadata": self.metadata,
|
||||||
|
"budget_total": self.budget_total,
|
||||||
|
"budget_used": self.budget_used,
|
||||||
|
"budget_utilization": round(self.budget_utilization, 3),
|
||||||
|
"by_type": self.by_type,
|
||||||
|
"cache_hit": self.cache_hit,
|
||||||
|
"cache_key": self.cache_key,
|
||||||
|
}
|
||||||
|
|
||||||
|
def to_json(self) -> str:
|
||||||
|
"""Convert to JSON string."""
|
||||||
|
import json
|
||||||
|
|
||||||
|
return json.dumps(self.to_dict())
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_json(cls, json_str: str) -> "AssembledContext":
|
||||||
|
"""Create from JSON string."""
|
||||||
|
import json
|
||||||
|
|
||||||
|
data = json.loads(json_str)
|
||||||
|
return cls(
|
||||||
|
content=data["content"],
|
||||||
|
total_tokens=data["total_tokens"],
|
||||||
|
context_count=data["context_count"],
|
||||||
|
excluded_count=data.get("excluded_count", 0),
|
||||||
|
assembly_time_ms=data.get("assembly_time_ms", 0.0),
|
||||||
|
model=data.get("model", ""),
|
||||||
|
metadata=data.get("metadata", {}),
|
||||||
|
budget_total=data.get("budget_total", 0),
|
||||||
|
budget_used=data.get("budget_used", 0),
|
||||||
|
by_type=data.get("by_type", {}),
|
||||||
|
cache_hit=data.get("cache_hit", False),
|
||||||
|
cache_key=data.get("cache_key"),
|
||||||
|
)
|
||||||
182
backend/app/services/context/types/conversation.py
Normal file
182
backend/app/services/context/types/conversation.py
Normal file
@@ -0,0 +1,182 @@
|
|||||||
|
"""
|
||||||
|
Conversation Context Type.
|
||||||
|
|
||||||
|
Represents conversation history for context continuity.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from datetime import UTC, datetime
|
||||||
|
from enum import Enum
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from .base import BaseContext, ContextPriority, ContextType
|
||||||
|
|
||||||
|
|
||||||
|
class MessageRole(str, Enum):
|
||||||
|
"""Roles for conversation messages."""
|
||||||
|
|
||||||
|
USER = "user"
|
||||||
|
ASSISTANT = "assistant"
|
||||||
|
SYSTEM = "system"
|
||||||
|
TOOL = "tool"
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_string(cls, value: str) -> "MessageRole":
|
||||||
|
"""Convert string to MessageRole."""
|
||||||
|
try:
|
||||||
|
return cls(value.lower())
|
||||||
|
except ValueError:
|
||||||
|
# Default to user for unknown roles
|
||||||
|
return cls.USER
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(eq=False)
|
||||||
|
class ConversationContext(BaseContext):
|
||||||
|
"""
|
||||||
|
Context from conversation history.
|
||||||
|
|
||||||
|
Represents a single turn in the conversation,
|
||||||
|
including user messages, assistant responses,
|
||||||
|
and tool results.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Conversation-specific fields
|
||||||
|
role: MessageRole = field(default=MessageRole.USER)
|
||||||
|
turn_index: int = field(default=0)
|
||||||
|
session_id: str | None = field(default=None)
|
||||||
|
parent_message_id: str | None = field(default=None)
|
||||||
|
|
||||||
|
def get_type(self) -> ContextType:
|
||||||
|
"""Return CONVERSATION context type."""
|
||||||
|
return ContextType.CONVERSATION
|
||||||
|
|
||||||
|
def to_dict(self) -> dict[str, Any]:
|
||||||
|
"""Convert to dictionary with conversation-specific fields."""
|
||||||
|
base = super().to_dict()
|
||||||
|
base.update(
|
||||||
|
{
|
||||||
|
"role": self.role.value,
|
||||||
|
"turn_index": self.turn_index,
|
||||||
|
"session_id": self.session_id,
|
||||||
|
"parent_message_id": self.parent_message_id,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
return base
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_dict(cls, data: dict[str, Any]) -> "ConversationContext":
|
||||||
|
"""Create ConversationContext from dictionary."""
|
||||||
|
role = data.get("role", "user")
|
||||||
|
if isinstance(role, str):
|
||||||
|
role = MessageRole.from_string(role)
|
||||||
|
|
||||||
|
return cls(
|
||||||
|
id=data.get("id", ""),
|
||||||
|
content=data["content"],
|
||||||
|
source=data.get("source", "conversation"),
|
||||||
|
timestamp=datetime.fromisoformat(data["timestamp"])
|
||||||
|
if isinstance(data.get("timestamp"), str)
|
||||||
|
else data.get("timestamp", datetime.now(UTC)),
|
||||||
|
priority=data.get("priority", ContextPriority.NORMAL.value),
|
||||||
|
metadata=data.get("metadata", {}),
|
||||||
|
role=role,
|
||||||
|
turn_index=data.get("turn_index", 0),
|
||||||
|
session_id=data.get("session_id"),
|
||||||
|
parent_message_id=data.get("parent_message_id"),
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_message(
|
||||||
|
cls,
|
||||||
|
content: str,
|
||||||
|
role: str | MessageRole,
|
||||||
|
turn_index: int = 0,
|
||||||
|
session_id: str | None = None,
|
||||||
|
timestamp: datetime | None = None,
|
||||||
|
) -> "ConversationContext":
|
||||||
|
"""
|
||||||
|
Create ConversationContext from a message.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
content: Message content
|
||||||
|
role: Message role (user, assistant, system, tool)
|
||||||
|
turn_index: Position in conversation
|
||||||
|
session_id: Session identifier
|
||||||
|
timestamp: Message timestamp
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ConversationContext instance
|
||||||
|
"""
|
||||||
|
if isinstance(role, str):
|
||||||
|
role = MessageRole.from_string(role)
|
||||||
|
|
||||||
|
# Recent messages have higher priority
|
||||||
|
priority = ContextPriority.NORMAL.value
|
||||||
|
|
||||||
|
return cls(
|
||||||
|
content=content,
|
||||||
|
source="conversation",
|
||||||
|
role=role,
|
||||||
|
turn_index=turn_index,
|
||||||
|
session_id=session_id,
|
||||||
|
timestamp=timestamp or datetime.now(UTC),
|
||||||
|
priority=priority,
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_history(
|
||||||
|
cls,
|
||||||
|
messages: list[dict[str, Any]],
|
||||||
|
session_id: str | None = None,
|
||||||
|
) -> list["ConversationContext"]:
|
||||||
|
"""
|
||||||
|
Create multiple ConversationContexts from message history.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
messages: List of message dicts with 'role' and 'content'
|
||||||
|
session_id: Session identifier
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of ConversationContext instances
|
||||||
|
"""
|
||||||
|
contexts = []
|
||||||
|
for i, msg in enumerate(messages):
|
||||||
|
ctx = cls.from_message(
|
||||||
|
content=msg.get("content", ""),
|
||||||
|
role=msg.get("role", "user"),
|
||||||
|
turn_index=i,
|
||||||
|
session_id=session_id,
|
||||||
|
timestamp=datetime.fromisoformat(msg["timestamp"])
|
||||||
|
if "timestamp" in msg
|
||||||
|
else None,
|
||||||
|
)
|
||||||
|
contexts.append(ctx)
|
||||||
|
return contexts
|
||||||
|
|
||||||
|
def is_user_message(self) -> bool:
|
||||||
|
"""Check if this is a user message."""
|
||||||
|
return self.role == MessageRole.USER
|
||||||
|
|
||||||
|
def is_assistant_message(self) -> bool:
|
||||||
|
"""Check if this is an assistant message."""
|
||||||
|
return self.role == MessageRole.ASSISTANT
|
||||||
|
|
||||||
|
def is_tool_result(self) -> bool:
|
||||||
|
"""Check if this is a tool result."""
|
||||||
|
return self.role == MessageRole.TOOL
|
||||||
|
|
||||||
|
def format_for_prompt(self) -> str:
|
||||||
|
"""
|
||||||
|
Format message for inclusion in prompt.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Formatted message string
|
||||||
|
"""
|
||||||
|
role_labels = {
|
||||||
|
MessageRole.USER: "User",
|
||||||
|
MessageRole.ASSISTANT: "Assistant",
|
||||||
|
MessageRole.SYSTEM: "System",
|
||||||
|
MessageRole.TOOL: "Tool Result",
|
||||||
|
}
|
||||||
|
label = role_labels.get(self.role, "Unknown")
|
||||||
|
return f"{label}: {self.content}"
|
||||||
152
backend/app/services/context/types/knowledge.py
Normal file
152
backend/app/services/context/types/knowledge.py
Normal file
@@ -0,0 +1,152 @@
|
|||||||
|
"""
|
||||||
|
Knowledge Context Type.
|
||||||
|
|
||||||
|
Represents RAG results from the Knowledge Base MCP server.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from datetime import UTC, datetime
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from .base import BaseContext, ContextPriority, ContextType
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(eq=False)
|
||||||
|
class KnowledgeContext(BaseContext):
|
||||||
|
"""
|
||||||
|
Context from knowledge base / RAG retrieval.
|
||||||
|
|
||||||
|
Knowledge context represents chunks retrieved from the
|
||||||
|
Knowledge Base MCP server, including:
|
||||||
|
- Code snippets
|
||||||
|
- Documentation
|
||||||
|
- Previous conversations
|
||||||
|
- External knowledge
|
||||||
|
|
||||||
|
Each chunk includes relevance scoring from the search.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Knowledge-specific fields
|
||||||
|
collection: str = field(default="default")
|
||||||
|
file_type: str | None = field(default=None)
|
||||||
|
chunk_index: int = field(default=0)
|
||||||
|
relevance_score: float = field(default=0.0)
|
||||||
|
search_query: str = field(default="")
|
||||||
|
|
||||||
|
def get_type(self) -> ContextType:
|
||||||
|
"""Return KNOWLEDGE context type."""
|
||||||
|
return ContextType.KNOWLEDGE
|
||||||
|
|
||||||
|
def to_dict(self) -> dict[str, Any]:
|
||||||
|
"""Convert to dictionary with knowledge-specific fields."""
|
||||||
|
base = super().to_dict()
|
||||||
|
base.update(
|
||||||
|
{
|
||||||
|
"collection": self.collection,
|
||||||
|
"file_type": self.file_type,
|
||||||
|
"chunk_index": self.chunk_index,
|
||||||
|
"relevance_score": self.relevance_score,
|
||||||
|
"search_query": self.search_query,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
return base
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_dict(cls, data: dict[str, Any]) -> "KnowledgeContext":
|
||||||
|
"""Create KnowledgeContext from dictionary."""
|
||||||
|
return cls(
|
||||||
|
id=data.get("id", ""),
|
||||||
|
content=data["content"],
|
||||||
|
source=data["source"],
|
||||||
|
timestamp=datetime.fromisoformat(data["timestamp"])
|
||||||
|
if isinstance(data.get("timestamp"), str)
|
||||||
|
else data.get("timestamp", datetime.now(UTC)),
|
||||||
|
priority=data.get("priority", ContextPriority.NORMAL.value),
|
||||||
|
metadata=data.get("metadata", {}),
|
||||||
|
collection=data.get("collection", "default"),
|
||||||
|
file_type=data.get("file_type"),
|
||||||
|
chunk_index=data.get("chunk_index", 0),
|
||||||
|
relevance_score=data.get("relevance_score", 0.0),
|
||||||
|
search_query=data.get("search_query", ""),
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_search_result(
|
||||||
|
cls,
|
||||||
|
result: dict[str, Any],
|
||||||
|
query: str,
|
||||||
|
) -> "KnowledgeContext":
|
||||||
|
"""
|
||||||
|
Create KnowledgeContext from a Knowledge Base search result.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
result: Search result from Knowledge Base MCP
|
||||||
|
query: Search query used
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
KnowledgeContext instance
|
||||||
|
"""
|
||||||
|
return cls(
|
||||||
|
content=result.get("content", ""),
|
||||||
|
source=result.get("source_path", "unknown"),
|
||||||
|
collection=result.get("collection", "default"),
|
||||||
|
file_type=result.get("file_type"),
|
||||||
|
chunk_index=result.get("chunk_index", 0),
|
||||||
|
relevance_score=result.get("score", 0.0),
|
||||||
|
search_query=query,
|
||||||
|
metadata={
|
||||||
|
"chunk_id": result.get("id"),
|
||||||
|
"content_hash": result.get("content_hash"),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_search_results(
|
||||||
|
cls,
|
||||||
|
results: list[dict[str, Any]],
|
||||||
|
query: str,
|
||||||
|
) -> list["KnowledgeContext"]:
|
||||||
|
"""
|
||||||
|
Create multiple KnowledgeContexts from search results.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
results: List of search results
|
||||||
|
query: Search query used
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of KnowledgeContext instances
|
||||||
|
"""
|
||||||
|
return [cls.from_search_result(r, query) for r in results]
|
||||||
|
|
||||||
|
def is_code(self) -> bool:
|
||||||
|
"""Check if this is code content."""
|
||||||
|
code_types = {
|
||||||
|
"python",
|
||||||
|
"javascript",
|
||||||
|
"typescript",
|
||||||
|
"go",
|
||||||
|
"rust",
|
||||||
|
"java",
|
||||||
|
"c",
|
||||||
|
"cpp",
|
||||||
|
}
|
||||||
|
return self.file_type is not None and self.file_type.lower() in code_types
|
||||||
|
|
||||||
|
def is_documentation(self) -> bool:
|
||||||
|
"""Check if this is documentation content."""
|
||||||
|
doc_types = {"markdown", "rst", "txt", "md"}
|
||||||
|
return self.file_type is not None and self.file_type.lower() in doc_types
|
||||||
|
|
||||||
|
def get_formatted_source(self) -> str:
|
||||||
|
"""
|
||||||
|
Get a formatted source string for display.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Formatted source string
|
||||||
|
"""
|
||||||
|
parts = [self.source]
|
||||||
|
if self.file_type:
|
||||||
|
parts.append(f"({self.file_type})")
|
||||||
|
if self.collection != "default":
|
||||||
|
parts.insert(0, f"[{self.collection}]")
|
||||||
|
return " ".join(parts)
|
||||||
138
backend/app/services/context/types/system.py
Normal file
138
backend/app/services/context/types/system.py
Normal file
@@ -0,0 +1,138 @@
|
|||||||
|
"""
|
||||||
|
System Context Type.
|
||||||
|
|
||||||
|
Represents system prompts, instructions, and agent personas.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from datetime import UTC, datetime
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from .base import BaseContext, ContextPriority, ContextType
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(eq=False)
|
||||||
|
class SystemContext(BaseContext):
|
||||||
|
"""
|
||||||
|
Context for system prompts and instructions.
|
||||||
|
|
||||||
|
System context typically includes:
|
||||||
|
- Agent persona and role definitions
|
||||||
|
- Behavioral instructions
|
||||||
|
- Safety guidelines
|
||||||
|
- Output format requirements
|
||||||
|
|
||||||
|
System context is usually high priority and should
|
||||||
|
rarely be truncated or omitted.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# System context specific fields
|
||||||
|
role: str = field(default="assistant")
|
||||||
|
instructions_type: str = field(default="general")
|
||||||
|
|
||||||
|
def __post_init__(self) -> None:
|
||||||
|
"""Set high priority for system context."""
|
||||||
|
# System context defaults to high priority
|
||||||
|
if self.priority == ContextPriority.NORMAL.value:
|
||||||
|
self.priority = ContextPriority.HIGH.value
|
||||||
|
|
||||||
|
def get_type(self) -> ContextType:
|
||||||
|
"""Return SYSTEM context type."""
|
||||||
|
return ContextType.SYSTEM
|
||||||
|
|
||||||
|
def to_dict(self) -> dict[str, Any]:
|
||||||
|
"""Convert to dictionary with system-specific fields."""
|
||||||
|
base = super().to_dict()
|
||||||
|
base.update(
|
||||||
|
{
|
||||||
|
"role": self.role,
|
||||||
|
"instructions_type": self.instructions_type,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
return base
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_dict(cls, data: dict[str, Any]) -> "SystemContext":
|
||||||
|
"""Create SystemContext from dictionary."""
|
||||||
|
return cls(
|
||||||
|
id=data.get("id", ""),
|
||||||
|
content=data["content"],
|
||||||
|
source=data["source"],
|
||||||
|
timestamp=datetime.fromisoformat(data["timestamp"])
|
||||||
|
if isinstance(data.get("timestamp"), str)
|
||||||
|
else data.get("timestamp", datetime.now(UTC)),
|
||||||
|
priority=data.get("priority", ContextPriority.HIGH.value),
|
||||||
|
metadata=data.get("metadata", {}),
|
||||||
|
role=data.get("role", "assistant"),
|
||||||
|
instructions_type=data.get("instructions_type", "general"),
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def create_persona(
|
||||||
|
cls,
|
||||||
|
name: str,
|
||||||
|
description: str,
|
||||||
|
capabilities: list[str] | None = None,
|
||||||
|
constraints: list[str] | None = None,
|
||||||
|
) -> "SystemContext":
|
||||||
|
"""
|
||||||
|
Create a persona system context.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name: Agent name/role
|
||||||
|
description: Role description
|
||||||
|
capabilities: List of things the agent can do
|
||||||
|
constraints: List of limitations
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
SystemContext with formatted persona
|
||||||
|
"""
|
||||||
|
parts = [f"You are {name}.", "", description]
|
||||||
|
|
||||||
|
if capabilities:
|
||||||
|
parts.append("")
|
||||||
|
parts.append("You can:")
|
||||||
|
for cap in capabilities:
|
||||||
|
parts.append(f"- {cap}")
|
||||||
|
|
||||||
|
if constraints:
|
||||||
|
parts.append("")
|
||||||
|
parts.append("You must not:")
|
||||||
|
for constraint in constraints:
|
||||||
|
parts.append(f"- {constraint}")
|
||||||
|
|
||||||
|
return cls(
|
||||||
|
content="\n".join(parts),
|
||||||
|
source="persona_builder",
|
||||||
|
role=name.lower().replace(" ", "_"),
|
||||||
|
instructions_type="persona",
|
||||||
|
priority=ContextPriority.HIGHEST.value,
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def create_instructions(
|
||||||
|
cls,
|
||||||
|
instructions: str | list[str],
|
||||||
|
source: str = "instructions",
|
||||||
|
) -> "SystemContext":
|
||||||
|
"""
|
||||||
|
Create an instructions system context.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
instructions: Instructions string or list of instruction strings
|
||||||
|
source: Source identifier
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
SystemContext with instructions
|
||||||
|
"""
|
||||||
|
if isinstance(instructions, list):
|
||||||
|
content = "\n".join(f"- {inst}" for inst in instructions)
|
||||||
|
else:
|
||||||
|
content = instructions
|
||||||
|
|
||||||
|
return cls(
|
||||||
|
content=content,
|
||||||
|
source=source,
|
||||||
|
instructions_type="instructions",
|
||||||
|
priority=ContextPriority.HIGH.value,
|
||||||
|
)
|
||||||
193
backend/app/services/context/types/task.py
Normal file
193
backend/app/services/context/types/task.py
Normal file
@@ -0,0 +1,193 @@
|
|||||||
|
"""
|
||||||
|
Task Context Type.
|
||||||
|
|
||||||
|
Represents the current task or objective for the agent.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from datetime import UTC, datetime
|
||||||
|
from enum import Enum
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from .base import BaseContext, ContextPriority, ContextType
|
||||||
|
|
||||||
|
|
||||||
|
class TaskStatus(str, Enum):
|
||||||
|
"""Status of a task."""
|
||||||
|
|
||||||
|
PENDING = "pending"
|
||||||
|
IN_PROGRESS = "in_progress"
|
||||||
|
BLOCKED = "blocked"
|
||||||
|
COMPLETED = "completed"
|
||||||
|
FAILED = "failed"
|
||||||
|
|
||||||
|
|
||||||
|
class TaskComplexity(str, Enum):
|
||||||
|
"""Complexity level of a task."""
|
||||||
|
|
||||||
|
TRIVIAL = "trivial"
|
||||||
|
SIMPLE = "simple"
|
||||||
|
MODERATE = "moderate"
|
||||||
|
COMPLEX = "complex"
|
||||||
|
VERY_COMPLEX = "very_complex"
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(eq=False)
|
||||||
|
class TaskContext(BaseContext):
|
||||||
|
"""
|
||||||
|
Context for the current task or objective.
|
||||||
|
|
||||||
|
Task context provides information about what the agent
|
||||||
|
should accomplish, including:
|
||||||
|
- Task description and goals
|
||||||
|
- Acceptance criteria
|
||||||
|
- Constraints and requirements
|
||||||
|
- Related issue/ticket information
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Task-specific fields
|
||||||
|
title: str = field(default="")
|
||||||
|
status: TaskStatus = field(default=TaskStatus.PENDING)
|
||||||
|
complexity: TaskComplexity = field(default=TaskComplexity.MODERATE)
|
||||||
|
issue_id: str | None = field(default=None)
|
||||||
|
project_id: str | None = field(default=None)
|
||||||
|
acceptance_criteria: list[str] = field(default_factory=list)
|
||||||
|
constraints: list[str] = field(default_factory=list)
|
||||||
|
parent_task_id: str | None = field(default=None)
|
||||||
|
|
||||||
|
# Note: TaskContext should typically have HIGH priority,
|
||||||
|
# but we don't auto-promote to allow explicit priority setting.
|
||||||
|
# Use TaskContext.create() for default HIGH priority behavior.
|
||||||
|
|
||||||
|
def get_type(self) -> ContextType:
|
||||||
|
"""Return TASK context type."""
|
||||||
|
return ContextType.TASK
|
||||||
|
|
||||||
|
def to_dict(self) -> dict[str, Any]:
|
||||||
|
"""Convert to dictionary with task-specific fields."""
|
||||||
|
base = super().to_dict()
|
||||||
|
base.update(
|
||||||
|
{
|
||||||
|
"title": self.title,
|
||||||
|
"status": self.status.value,
|
||||||
|
"complexity": self.complexity.value,
|
||||||
|
"issue_id": self.issue_id,
|
||||||
|
"project_id": self.project_id,
|
||||||
|
"acceptance_criteria": self.acceptance_criteria,
|
||||||
|
"constraints": self.constraints,
|
||||||
|
"parent_task_id": self.parent_task_id,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
return base
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_dict(cls, data: dict[str, Any]) -> "TaskContext":
|
||||||
|
"""Create TaskContext from dictionary."""
|
||||||
|
status = data.get("status", "pending")
|
||||||
|
if isinstance(status, str):
|
||||||
|
status = TaskStatus(status)
|
||||||
|
|
||||||
|
complexity = data.get("complexity", "moderate")
|
||||||
|
if isinstance(complexity, str):
|
||||||
|
complexity = TaskComplexity(complexity)
|
||||||
|
|
||||||
|
return cls(
|
||||||
|
id=data.get("id", ""),
|
||||||
|
content=data["content"],
|
||||||
|
source=data.get("source", "task"),
|
||||||
|
timestamp=datetime.fromisoformat(data["timestamp"])
|
||||||
|
if isinstance(data.get("timestamp"), str)
|
||||||
|
else data.get("timestamp", datetime.now(UTC)),
|
||||||
|
priority=data.get("priority", ContextPriority.HIGH.value),
|
||||||
|
metadata=data.get("metadata", {}),
|
||||||
|
title=data.get("title", ""),
|
||||||
|
status=status,
|
||||||
|
complexity=complexity,
|
||||||
|
issue_id=data.get("issue_id"),
|
||||||
|
project_id=data.get("project_id"),
|
||||||
|
acceptance_criteria=data.get("acceptance_criteria", []),
|
||||||
|
constraints=data.get("constraints", []),
|
||||||
|
parent_task_id=data.get("parent_task_id"),
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def create(
|
||||||
|
cls,
|
||||||
|
title: str,
|
||||||
|
description: str,
|
||||||
|
acceptance_criteria: list[str] | None = None,
|
||||||
|
constraints: list[str] | None = None,
|
||||||
|
issue_id: str | None = None,
|
||||||
|
project_id: str | None = None,
|
||||||
|
complexity: TaskComplexity | str = TaskComplexity.MODERATE,
|
||||||
|
) -> "TaskContext":
|
||||||
|
"""
|
||||||
|
Create a task context.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
title: Task title
|
||||||
|
description: Task description
|
||||||
|
acceptance_criteria: List of acceptance criteria
|
||||||
|
constraints: List of constraints
|
||||||
|
issue_id: Related issue ID
|
||||||
|
project_id: Project ID
|
||||||
|
complexity: Task complexity
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
TaskContext instance
|
||||||
|
"""
|
||||||
|
if isinstance(complexity, str):
|
||||||
|
complexity = TaskComplexity(complexity)
|
||||||
|
|
||||||
|
return cls(
|
||||||
|
content=description,
|
||||||
|
source=f"task:{issue_id}" if issue_id else "task",
|
||||||
|
title=title,
|
||||||
|
status=TaskStatus.IN_PROGRESS,
|
||||||
|
complexity=complexity,
|
||||||
|
issue_id=issue_id,
|
||||||
|
project_id=project_id,
|
||||||
|
acceptance_criteria=acceptance_criteria or [],
|
||||||
|
constraints=constraints or [],
|
||||||
|
)
|
||||||
|
|
||||||
|
def format_for_prompt(self) -> str:
|
||||||
|
"""
|
||||||
|
Format task for inclusion in prompt.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Formatted task string
|
||||||
|
"""
|
||||||
|
parts = []
|
||||||
|
|
||||||
|
if self.title:
|
||||||
|
parts.append(f"Task: {self.title}")
|
||||||
|
parts.append("")
|
||||||
|
|
||||||
|
parts.append(self.content)
|
||||||
|
|
||||||
|
if self.acceptance_criteria:
|
||||||
|
parts.append("")
|
||||||
|
parts.append("Acceptance Criteria:")
|
||||||
|
for criterion in self.acceptance_criteria:
|
||||||
|
parts.append(f"- {criterion}")
|
||||||
|
|
||||||
|
if self.constraints:
|
||||||
|
parts.append("")
|
||||||
|
parts.append("Constraints:")
|
||||||
|
for constraint in self.constraints:
|
||||||
|
parts.append(f"- {constraint}")
|
||||||
|
|
||||||
|
return "\n".join(parts)
|
||||||
|
|
||||||
|
def is_active(self) -> bool:
|
||||||
|
"""Check if task is currently active."""
|
||||||
|
return self.status in (TaskStatus.PENDING, TaskStatus.IN_PROGRESS)
|
||||||
|
|
||||||
|
def is_complete(self) -> bool:
|
||||||
|
"""Check if task is complete."""
|
||||||
|
return self.status == TaskStatus.COMPLETED
|
||||||
|
|
||||||
|
def is_blocked(self) -> bool:
|
||||||
|
"""Check if task is blocked."""
|
||||||
|
return self.status == TaskStatus.BLOCKED
|
||||||
211
backend/app/services/context/types/tool.py
Normal file
211
backend/app/services/context/types/tool.py
Normal file
@@ -0,0 +1,211 @@
|
|||||||
|
"""
|
||||||
|
Tool Context Type.
|
||||||
|
|
||||||
|
Represents available tools and recent tool execution results.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from datetime import UTC, datetime
|
||||||
|
from enum import Enum
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from .base import BaseContext, ContextPriority, ContextType
|
||||||
|
|
||||||
|
|
||||||
|
class ToolResultStatus(str, Enum):
|
||||||
|
"""Status of a tool execution result."""
|
||||||
|
|
||||||
|
SUCCESS = "success"
|
||||||
|
ERROR = "error"
|
||||||
|
TIMEOUT = "timeout"
|
||||||
|
CANCELLED = "cancelled"
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(eq=False)
|
||||||
|
class ToolContext(BaseContext):
|
||||||
|
"""
|
||||||
|
Context for tools and tool execution results.
|
||||||
|
|
||||||
|
Tool context includes:
|
||||||
|
- Tool descriptions and parameters
|
||||||
|
- Recent tool execution results
|
||||||
|
- Tool availability information
|
||||||
|
|
||||||
|
This helps the LLM understand what tools are available
|
||||||
|
and what results previous tool calls produced.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Tool-specific fields
|
||||||
|
tool_name: str = field(default="")
|
||||||
|
tool_description: str = field(default="")
|
||||||
|
is_result: bool = field(default=False)
|
||||||
|
result_status: ToolResultStatus | None = field(default=None)
|
||||||
|
execution_time_ms: float | None = field(default=None)
|
||||||
|
parameters: dict[str, Any] = field(default_factory=dict)
|
||||||
|
server_name: str | None = field(default=None)
|
||||||
|
|
||||||
|
def get_type(self) -> ContextType:
|
||||||
|
"""Return TOOL context type."""
|
||||||
|
return ContextType.TOOL
|
||||||
|
|
||||||
|
def to_dict(self) -> dict[str, Any]:
|
||||||
|
"""Convert to dictionary with tool-specific fields."""
|
||||||
|
base = super().to_dict()
|
||||||
|
base.update(
|
||||||
|
{
|
||||||
|
"tool_name": self.tool_name,
|
||||||
|
"tool_description": self.tool_description,
|
||||||
|
"is_result": self.is_result,
|
||||||
|
"result_status": self.result_status.value
|
||||||
|
if self.result_status
|
||||||
|
else None,
|
||||||
|
"execution_time_ms": self.execution_time_ms,
|
||||||
|
"parameters": self.parameters,
|
||||||
|
"server_name": self.server_name,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
return base
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_dict(cls, data: dict[str, Any]) -> "ToolContext":
|
||||||
|
"""Create ToolContext from dictionary."""
|
||||||
|
result_status = data.get("result_status")
|
||||||
|
if isinstance(result_status, str):
|
||||||
|
result_status = ToolResultStatus(result_status)
|
||||||
|
|
||||||
|
return cls(
|
||||||
|
id=data.get("id", ""),
|
||||||
|
content=data["content"],
|
||||||
|
source=data.get("source", "tool"),
|
||||||
|
timestamp=datetime.fromisoformat(data["timestamp"])
|
||||||
|
if isinstance(data.get("timestamp"), str)
|
||||||
|
else data.get("timestamp", datetime.now(UTC)),
|
||||||
|
priority=data.get("priority", ContextPriority.NORMAL.value),
|
||||||
|
metadata=data.get("metadata", {}),
|
||||||
|
tool_name=data.get("tool_name", ""),
|
||||||
|
tool_description=data.get("tool_description", ""),
|
||||||
|
is_result=data.get("is_result", False),
|
||||||
|
result_status=result_status,
|
||||||
|
execution_time_ms=data.get("execution_time_ms"),
|
||||||
|
parameters=data.get("parameters", {}),
|
||||||
|
server_name=data.get("server_name"),
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_tool_definition(
|
||||||
|
cls,
|
||||||
|
name: str,
|
||||||
|
description: str,
|
||||||
|
parameters: dict[str, Any] | None = None,
|
||||||
|
server_name: str | None = None,
|
||||||
|
) -> "ToolContext":
|
||||||
|
"""
|
||||||
|
Create a ToolContext from a tool definition.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name: Tool name
|
||||||
|
description: Tool description
|
||||||
|
parameters: Tool parameter schema
|
||||||
|
server_name: MCP server name
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ToolContext instance
|
||||||
|
"""
|
||||||
|
# Format content as tool documentation
|
||||||
|
content_parts = [f"Tool: {name}", "", description]
|
||||||
|
|
||||||
|
if parameters:
|
||||||
|
content_parts.append("")
|
||||||
|
content_parts.append("Parameters:")
|
||||||
|
for param_name, param_info in parameters.items():
|
||||||
|
param_type = param_info.get("type", "any")
|
||||||
|
param_desc = param_info.get("description", "")
|
||||||
|
required = param_info.get("required", False)
|
||||||
|
req_marker = " (required)" if required else ""
|
||||||
|
content_parts.append(f" - {param_name}: {param_type}{req_marker}")
|
||||||
|
if param_desc:
|
||||||
|
content_parts.append(f" {param_desc}")
|
||||||
|
|
||||||
|
return cls(
|
||||||
|
content="\n".join(content_parts),
|
||||||
|
source=f"tool:{server_name}:{name}" if server_name else f"tool:{name}",
|
||||||
|
tool_name=name,
|
||||||
|
tool_description=description,
|
||||||
|
is_result=False,
|
||||||
|
parameters=parameters or {},
|
||||||
|
server_name=server_name,
|
||||||
|
priority=ContextPriority.LOW.value,
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_tool_result(
|
||||||
|
cls,
|
||||||
|
tool_name: str,
|
||||||
|
result: Any,
|
||||||
|
status: ToolResultStatus = ToolResultStatus.SUCCESS,
|
||||||
|
execution_time_ms: float | None = None,
|
||||||
|
parameters: dict[str, Any] | None = None,
|
||||||
|
server_name: str | None = None,
|
||||||
|
) -> "ToolContext":
|
||||||
|
"""
|
||||||
|
Create a ToolContext from a tool execution result.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tool_name: Name of the tool that was executed
|
||||||
|
result: Result content (will be converted to string)
|
||||||
|
status: Execution status
|
||||||
|
execution_time_ms: Execution time in milliseconds
|
||||||
|
parameters: Parameters that were passed to the tool
|
||||||
|
server_name: MCP server name
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ToolContext instance
|
||||||
|
"""
|
||||||
|
# Convert result to string content
|
||||||
|
if isinstance(result, str):
|
||||||
|
content = result
|
||||||
|
elif isinstance(result, dict):
|
||||||
|
import json
|
||||||
|
|
||||||
|
try:
|
||||||
|
content = json.dumps(result, indent=2)
|
||||||
|
except (TypeError, ValueError):
|
||||||
|
content = str(result)
|
||||||
|
else:
|
||||||
|
content = str(result)
|
||||||
|
|
||||||
|
return cls(
|
||||||
|
content=content,
|
||||||
|
source=f"tool_result:{server_name}:{tool_name}"
|
||||||
|
if server_name
|
||||||
|
else f"tool_result:{tool_name}",
|
||||||
|
tool_name=tool_name,
|
||||||
|
is_result=True,
|
||||||
|
result_status=status,
|
||||||
|
execution_time_ms=execution_time_ms,
|
||||||
|
parameters=parameters or {},
|
||||||
|
server_name=server_name,
|
||||||
|
priority=ContextPriority.HIGH.value, # Recent results are high priority
|
||||||
|
)
|
||||||
|
|
||||||
|
def is_successful(self) -> bool:
|
||||||
|
"""Check if this is a successful tool result."""
|
||||||
|
return self.is_result and self.result_status == ToolResultStatus.SUCCESS
|
||||||
|
|
||||||
|
def is_error(self) -> bool:
|
||||||
|
"""Check if this is an error result."""
|
||||||
|
return self.is_result and self.result_status == ToolResultStatus.ERROR
|
||||||
|
|
||||||
|
def format_for_prompt(self) -> str:
|
||||||
|
"""
|
||||||
|
Format tool context for inclusion in prompt.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Formatted tool string
|
||||||
|
"""
|
||||||
|
if self.is_result:
|
||||||
|
status_str = self.result_status.value if self.result_status else "unknown"
|
||||||
|
header = f"Tool Result ({self.tool_name}, {status_str}):"
|
||||||
|
return f"{header}\n{self.content}"
|
||||||
|
else:
|
||||||
|
return self.content
|
||||||
@@ -24,6 +24,9 @@ from ..models import (
|
|||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Sentinel for distinguishing "no argument passed" from "explicitly passing None"
|
||||||
|
_UNSET = object()
|
||||||
|
|
||||||
|
|
||||||
class AuditLogger:
|
class AuditLogger:
|
||||||
"""
|
"""
|
||||||
@@ -142,8 +145,10 @@ class AuditLogger:
|
|||||||
# Add hash chain for tamper detection
|
# Add hash chain for tamper detection
|
||||||
if self._enable_hash_chain:
|
if self._enable_hash_chain:
|
||||||
event_hash = self._compute_hash(event)
|
event_hash = self._compute_hash(event)
|
||||||
sanitized_details["_hash"] = event_hash
|
# Modify event.details directly (not sanitized_details)
|
||||||
sanitized_details["_prev_hash"] = self._last_hash
|
# to ensure the hash is stored on the actual event
|
||||||
|
event.details["_hash"] = event_hash
|
||||||
|
event.details["_prev_hash"] = self._last_hash
|
||||||
self._last_hash = event_hash
|
self._last_hash = event_hash
|
||||||
|
|
||||||
self._buffer.append(event)
|
self._buffer.append(event)
|
||||||
@@ -415,7 +420,8 @@ class AuditLogger:
|
|||||||
)
|
)
|
||||||
|
|
||||||
if stored_hash:
|
if stored_hash:
|
||||||
computed = self._compute_hash(event)
|
# Pass prev_hash to compute hash with correct chain position
|
||||||
|
computed = self._compute_hash(event, prev_hash=prev_hash)
|
||||||
if computed != stored_hash:
|
if computed != stored_hash:
|
||||||
issues.append(
|
issues.append(
|
||||||
f"Hash mismatch at event {event.id}: "
|
f"Hash mismatch at event {event.id}: "
|
||||||
@@ -462,9 +468,23 @@ class AuditLogger:
|
|||||||
|
|
||||||
return sanitized
|
return sanitized
|
||||||
|
|
||||||
def _compute_hash(self, event: AuditEvent) -> str:
|
def _compute_hash(
|
||||||
"""Compute hash for an event (excluding hash fields)."""
|
self, event: AuditEvent, prev_hash: str | None | object = _UNSET
|
||||||
data = {
|
) -> str:
|
||||||
|
"""Compute hash for an event (excluding hash fields).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
event: The audit event to hash.
|
||||||
|
prev_hash: Optional previous hash to use instead of self._last_hash.
|
||||||
|
Pass this during verification to use the correct chain.
|
||||||
|
Use None explicitly to indicate no previous hash.
|
||||||
|
"""
|
||||||
|
# Use passed prev_hash if explicitly provided, otherwise use instance state
|
||||||
|
effective_prev: str | None = (
|
||||||
|
self._last_hash if prev_hash is _UNSET else prev_hash # type: ignore[assignment]
|
||||||
|
)
|
||||||
|
|
||||||
|
data: dict[str, str | dict[str, str] | None] = {
|
||||||
"id": event.id,
|
"id": event.id,
|
||||||
"event_type": event.event_type.value,
|
"event_type": event.event_type.value,
|
||||||
"timestamp": event.timestamp.isoformat(),
|
"timestamp": event.timestamp.isoformat(),
|
||||||
@@ -480,8 +500,8 @@ class AuditLogger:
|
|||||||
"correlation_id": event.correlation_id,
|
"correlation_id": event.correlation_id,
|
||||||
}
|
}
|
||||||
|
|
||||||
if self._last_hash:
|
if effective_prev:
|
||||||
data["_prev_hash"] = self._last_hash
|
data["_prev_hash"] = effective_prev
|
||||||
|
|
||||||
serialized = json.dumps(data, sort_keys=True, default=str)
|
serialized = json.dumps(data, sort_keys=True, default=str)
|
||||||
return hashlib.sha256(serialized.encode()).hexdigest()
|
return hashlib.sha256(serialized.encode()).hexdigest()
|
||||||
|
|||||||
1
backend/tests/services/context/__init__.py
Normal file
1
backend/tests/services/context/__init__.py
Normal file
@@ -0,0 +1 @@
|
|||||||
|
"""Tests for Context Management Engine."""
|
||||||
518
backend/tests/services/context/test_adapters.py
Normal file
518
backend/tests/services/context/test_adapters.py
Normal file
@@ -0,0 +1,518 @@
|
|||||||
|
"""Tests for model adapters."""
|
||||||
|
|
||||||
|
from app.services.context.adapters import (
|
||||||
|
ClaudeAdapter,
|
||||||
|
DefaultAdapter,
|
||||||
|
OpenAIAdapter,
|
||||||
|
get_adapter,
|
||||||
|
)
|
||||||
|
from app.services.context.types import (
|
||||||
|
ContextType,
|
||||||
|
ConversationContext,
|
||||||
|
KnowledgeContext,
|
||||||
|
MessageRole,
|
||||||
|
SystemContext,
|
||||||
|
TaskContext,
|
||||||
|
ToolContext,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestGetAdapter:
|
||||||
|
"""Tests for get_adapter function."""
|
||||||
|
|
||||||
|
def test_claude_models(self) -> None:
|
||||||
|
"""Test that Claude models get ClaudeAdapter."""
|
||||||
|
assert isinstance(get_adapter("claude-3-sonnet"), ClaudeAdapter)
|
||||||
|
assert isinstance(get_adapter("claude-3-opus"), ClaudeAdapter)
|
||||||
|
assert isinstance(get_adapter("claude-3-haiku"), ClaudeAdapter)
|
||||||
|
assert isinstance(get_adapter("claude-2"), ClaudeAdapter)
|
||||||
|
assert isinstance(get_adapter("anthropic/claude-3-sonnet"), ClaudeAdapter)
|
||||||
|
|
||||||
|
def test_openai_models(self) -> None:
|
||||||
|
"""Test that OpenAI models get OpenAIAdapter."""
|
||||||
|
assert isinstance(get_adapter("gpt-4"), OpenAIAdapter)
|
||||||
|
assert isinstance(get_adapter("gpt-4-turbo"), OpenAIAdapter)
|
||||||
|
assert isinstance(get_adapter("gpt-3.5-turbo"), OpenAIAdapter)
|
||||||
|
assert isinstance(get_adapter("openai/gpt-4"), OpenAIAdapter)
|
||||||
|
assert isinstance(get_adapter("o1-mini"), OpenAIAdapter)
|
||||||
|
assert isinstance(get_adapter("o3-mini"), OpenAIAdapter)
|
||||||
|
|
||||||
|
def test_unknown_models(self) -> None:
|
||||||
|
"""Test that unknown models get DefaultAdapter."""
|
||||||
|
assert isinstance(get_adapter("llama-2"), DefaultAdapter)
|
||||||
|
assert isinstance(get_adapter("mistral-7b"), DefaultAdapter)
|
||||||
|
assert isinstance(get_adapter("custom-model"), DefaultAdapter)
|
||||||
|
|
||||||
|
|
||||||
|
class TestModelAdapterBase:
|
||||||
|
"""Tests for ModelAdapter base class."""
|
||||||
|
|
||||||
|
def test_get_type_order(self) -> None:
|
||||||
|
"""Test default type ordering."""
|
||||||
|
adapter = DefaultAdapter()
|
||||||
|
order = adapter.get_type_order()
|
||||||
|
|
||||||
|
assert order == [
|
||||||
|
ContextType.SYSTEM,
|
||||||
|
ContextType.TASK,
|
||||||
|
ContextType.KNOWLEDGE,
|
||||||
|
ContextType.CONVERSATION,
|
||||||
|
ContextType.TOOL,
|
||||||
|
]
|
||||||
|
|
||||||
|
def test_group_by_type(self) -> None:
|
||||||
|
"""Test grouping contexts by type."""
|
||||||
|
adapter = DefaultAdapter()
|
||||||
|
|
||||||
|
contexts = [
|
||||||
|
SystemContext(content="System", source="system"),
|
||||||
|
TaskContext(content="Task", source="task"),
|
||||||
|
KnowledgeContext(content="Knowledge", source="docs"),
|
||||||
|
SystemContext(content="System 2", source="system"),
|
||||||
|
]
|
||||||
|
|
||||||
|
grouped = adapter.group_by_type(contexts)
|
||||||
|
|
||||||
|
assert len(grouped[ContextType.SYSTEM]) == 2
|
||||||
|
assert len(grouped[ContextType.TASK]) == 1
|
||||||
|
assert len(grouped[ContextType.KNOWLEDGE]) == 1
|
||||||
|
assert ContextType.CONVERSATION not in grouped
|
||||||
|
|
||||||
|
def test_matches_model_default(self) -> None:
|
||||||
|
"""Test that DefaultAdapter matches all models."""
|
||||||
|
assert DefaultAdapter.matches_model("anything")
|
||||||
|
assert DefaultAdapter.matches_model("claude-3")
|
||||||
|
assert DefaultAdapter.matches_model("gpt-4")
|
||||||
|
|
||||||
|
|
||||||
|
class TestDefaultAdapter:
|
||||||
|
"""Tests for DefaultAdapter."""
|
||||||
|
|
||||||
|
def test_format_empty(self) -> None:
|
||||||
|
"""Test formatting empty context list."""
|
||||||
|
adapter = DefaultAdapter()
|
||||||
|
result = adapter.format([])
|
||||||
|
assert result == ""
|
||||||
|
|
||||||
|
def test_format_system(self) -> None:
|
||||||
|
"""Test formatting system context."""
|
||||||
|
adapter = DefaultAdapter()
|
||||||
|
contexts = [
|
||||||
|
SystemContext(content="You are helpful.", source="system"),
|
||||||
|
]
|
||||||
|
result = adapter.format(contexts)
|
||||||
|
assert "You are helpful." in result
|
||||||
|
|
||||||
|
def test_format_task(self) -> None:
|
||||||
|
"""Test formatting task context."""
|
||||||
|
adapter = DefaultAdapter()
|
||||||
|
contexts = [
|
||||||
|
TaskContext(content="Write a function.", source="task"),
|
||||||
|
]
|
||||||
|
result = adapter.format(contexts)
|
||||||
|
assert "Task:" in result
|
||||||
|
assert "Write a function." in result
|
||||||
|
|
||||||
|
def test_format_knowledge(self) -> None:
|
||||||
|
"""Test formatting knowledge context."""
|
||||||
|
adapter = DefaultAdapter()
|
||||||
|
contexts = [
|
||||||
|
KnowledgeContext(content="Documentation here.", source="docs"),
|
||||||
|
]
|
||||||
|
result = adapter.format(contexts)
|
||||||
|
assert "Reference Information:" in result
|
||||||
|
assert "Documentation here." in result
|
||||||
|
|
||||||
|
def test_format_conversation(self) -> None:
|
||||||
|
"""Test formatting conversation context."""
|
||||||
|
adapter = DefaultAdapter()
|
||||||
|
contexts = [
|
||||||
|
ConversationContext(
|
||||||
|
content="Hello!",
|
||||||
|
source="chat",
|
||||||
|
role=MessageRole.USER,
|
||||||
|
),
|
||||||
|
]
|
||||||
|
result = adapter.format(contexts)
|
||||||
|
assert "Previous Conversation:" in result
|
||||||
|
assert "Hello!" in result
|
||||||
|
|
||||||
|
def test_format_tool(self) -> None:
|
||||||
|
"""Test formatting tool context."""
|
||||||
|
adapter = DefaultAdapter()
|
||||||
|
contexts = [
|
||||||
|
ToolContext(
|
||||||
|
content="Result: success",
|
||||||
|
source="tool",
|
||||||
|
metadata={"tool_name": "search"},
|
||||||
|
),
|
||||||
|
]
|
||||||
|
result = adapter.format(contexts)
|
||||||
|
assert "Tool Results:" in result
|
||||||
|
assert "Result: success" in result
|
||||||
|
|
||||||
|
|
||||||
|
class TestClaudeAdapter:
|
||||||
|
"""Tests for ClaudeAdapter."""
|
||||||
|
|
||||||
|
def test_matches_model(self) -> None:
|
||||||
|
"""Test model matching."""
|
||||||
|
assert ClaudeAdapter.matches_model("claude-3-sonnet")
|
||||||
|
assert ClaudeAdapter.matches_model("claude-3-opus")
|
||||||
|
assert ClaudeAdapter.matches_model("anthropic/claude-3-haiku")
|
||||||
|
assert not ClaudeAdapter.matches_model("gpt-4")
|
||||||
|
assert not ClaudeAdapter.matches_model("llama-2")
|
||||||
|
|
||||||
|
def test_format_empty(self) -> None:
|
||||||
|
"""Test formatting empty context list."""
|
||||||
|
adapter = ClaudeAdapter()
|
||||||
|
result = adapter.format([])
|
||||||
|
assert result == ""
|
||||||
|
|
||||||
|
def test_format_system_uses_xml(self) -> None:
|
||||||
|
"""Test that system context uses XML tags."""
|
||||||
|
adapter = ClaudeAdapter()
|
||||||
|
contexts = [
|
||||||
|
SystemContext(content="You are helpful.", source="system"),
|
||||||
|
]
|
||||||
|
result = adapter.format(contexts)
|
||||||
|
assert "<system_instructions>" in result
|
||||||
|
assert "</system_instructions>" in result
|
||||||
|
assert "You are helpful." in result
|
||||||
|
|
||||||
|
def test_format_task_uses_xml(self) -> None:
|
||||||
|
"""Test that task context uses XML tags."""
|
||||||
|
adapter = ClaudeAdapter()
|
||||||
|
contexts = [
|
||||||
|
TaskContext(content="Write a function.", source="task"),
|
||||||
|
]
|
||||||
|
result = adapter.format(contexts)
|
||||||
|
assert "<current_task>" in result
|
||||||
|
assert "</current_task>" in result
|
||||||
|
assert "Write a function." in result
|
||||||
|
|
||||||
|
def test_format_knowledge_uses_document_tags(self) -> None:
|
||||||
|
"""Test that knowledge uses document XML tags."""
|
||||||
|
adapter = ClaudeAdapter()
|
||||||
|
contexts = [
|
||||||
|
KnowledgeContext(
|
||||||
|
content="Documentation here.",
|
||||||
|
source="docs/api.md",
|
||||||
|
relevance_score=0.9,
|
||||||
|
),
|
||||||
|
]
|
||||||
|
result = adapter.format(contexts)
|
||||||
|
assert "<reference_documents>" in result
|
||||||
|
assert "</reference_documents>" in result
|
||||||
|
assert '<document source="docs/api.md"' in result
|
||||||
|
assert "</document>" in result
|
||||||
|
assert "Documentation here." in result
|
||||||
|
|
||||||
|
def test_format_knowledge_with_score(self) -> None:
|
||||||
|
"""Test that knowledge includes relevance score."""
|
||||||
|
adapter = ClaudeAdapter()
|
||||||
|
contexts = [
|
||||||
|
KnowledgeContext(
|
||||||
|
content="Doc content.",
|
||||||
|
source="docs/api.md",
|
||||||
|
metadata={"relevance_score": 0.95},
|
||||||
|
),
|
||||||
|
]
|
||||||
|
result = adapter.format(contexts)
|
||||||
|
assert 'relevance="0.95"' in result
|
||||||
|
|
||||||
|
def test_format_conversation_uses_message_tags(self) -> None:
|
||||||
|
"""Test that conversation uses message XML tags."""
|
||||||
|
adapter = ClaudeAdapter()
|
||||||
|
contexts = [
|
||||||
|
ConversationContext(
|
||||||
|
content="Hello!",
|
||||||
|
source="chat",
|
||||||
|
role=MessageRole.USER,
|
||||||
|
metadata={"role": "user"},
|
||||||
|
),
|
||||||
|
ConversationContext(
|
||||||
|
content="Hi there!",
|
||||||
|
source="chat",
|
||||||
|
role=MessageRole.ASSISTANT,
|
||||||
|
metadata={"role": "assistant"},
|
||||||
|
),
|
||||||
|
]
|
||||||
|
result = adapter.format(contexts)
|
||||||
|
assert "<conversation_history>" in result
|
||||||
|
assert "</conversation_history>" in result
|
||||||
|
assert '<message role="user">' in result
|
||||||
|
assert '<message role="assistant">' in result
|
||||||
|
assert "Hello!" in result
|
||||||
|
assert "Hi there!" in result
|
||||||
|
|
||||||
|
def test_format_tool_uses_tool_result_tags(self) -> None:
|
||||||
|
"""Test that tool results use tool_result XML tags."""
|
||||||
|
adapter = ClaudeAdapter()
|
||||||
|
contexts = [
|
||||||
|
ToolContext(
|
||||||
|
content='{"status": "ok"}',
|
||||||
|
source="tool",
|
||||||
|
metadata={"tool_name": "search", "status": "success"},
|
||||||
|
),
|
||||||
|
]
|
||||||
|
result = adapter.format(contexts)
|
||||||
|
assert "<tool_results>" in result
|
||||||
|
assert "</tool_results>" in result
|
||||||
|
assert '<tool_result name="search"' in result
|
||||||
|
assert 'status="success"' in result
|
||||||
|
assert "</tool_result>" in result
|
||||||
|
|
||||||
|
def test_format_multiple_types_in_order(self) -> None:
|
||||||
|
"""Test that multiple types are formatted in correct order."""
|
||||||
|
adapter = ClaudeAdapter()
|
||||||
|
contexts = [
|
||||||
|
KnowledgeContext(content="Knowledge", source="docs"),
|
||||||
|
SystemContext(content="System", source="system"),
|
||||||
|
TaskContext(content="Task", source="task"),
|
||||||
|
]
|
||||||
|
result = adapter.format(contexts)
|
||||||
|
|
||||||
|
# Find positions
|
||||||
|
system_pos = result.find("<system_instructions>")
|
||||||
|
task_pos = result.find("<current_task>")
|
||||||
|
knowledge_pos = result.find("<reference_documents>")
|
||||||
|
|
||||||
|
# Verify order
|
||||||
|
assert system_pos < task_pos < knowledge_pos
|
||||||
|
|
||||||
|
def test_escape_xml_in_source(self) -> None:
|
||||||
|
"""Test that XML special chars are escaped in source."""
|
||||||
|
adapter = ClaudeAdapter()
|
||||||
|
contexts = [
|
||||||
|
KnowledgeContext(
|
||||||
|
content="Doc content.",
|
||||||
|
source='path/with"quotes&stuff.md',
|
||||||
|
),
|
||||||
|
]
|
||||||
|
result = adapter.format(contexts)
|
||||||
|
assert """ in result
|
||||||
|
assert "&" in result
|
||||||
|
|
||||||
|
|
||||||
|
class TestOpenAIAdapter:
|
||||||
|
"""Tests for OpenAIAdapter."""
|
||||||
|
|
||||||
|
def test_matches_model(self) -> None:
|
||||||
|
"""Test model matching."""
|
||||||
|
assert OpenAIAdapter.matches_model("gpt-4")
|
||||||
|
assert OpenAIAdapter.matches_model("gpt-4-turbo")
|
||||||
|
assert OpenAIAdapter.matches_model("gpt-3.5-turbo")
|
||||||
|
assert OpenAIAdapter.matches_model("openai/gpt-4")
|
||||||
|
assert OpenAIAdapter.matches_model("o1-preview")
|
||||||
|
assert OpenAIAdapter.matches_model("o3-mini")
|
||||||
|
assert not OpenAIAdapter.matches_model("claude-3")
|
||||||
|
assert not OpenAIAdapter.matches_model("llama-2")
|
||||||
|
|
||||||
|
def test_format_empty(self) -> None:
|
||||||
|
"""Test formatting empty context list."""
|
||||||
|
adapter = OpenAIAdapter()
|
||||||
|
result = adapter.format([])
|
||||||
|
assert result == ""
|
||||||
|
|
||||||
|
def test_format_system_plain(self) -> None:
|
||||||
|
"""Test that system content is plain."""
|
||||||
|
adapter = OpenAIAdapter()
|
||||||
|
contexts = [
|
||||||
|
SystemContext(content="You are helpful.", source="system"),
|
||||||
|
]
|
||||||
|
result = adapter.format(contexts)
|
||||||
|
# System content should be plain without headers
|
||||||
|
assert "You are helpful." in result
|
||||||
|
assert "##" not in result # No markdown headers for system
|
||||||
|
|
||||||
|
def test_format_task_uses_markdown(self) -> None:
|
||||||
|
"""Test that task uses markdown headers."""
|
||||||
|
adapter = OpenAIAdapter()
|
||||||
|
contexts = [
|
||||||
|
TaskContext(content="Write a function.", source="task"),
|
||||||
|
]
|
||||||
|
result = adapter.format(contexts)
|
||||||
|
assert "## Current Task" in result
|
||||||
|
assert "Write a function." in result
|
||||||
|
|
||||||
|
def test_format_knowledge_uses_markdown(self) -> None:
|
||||||
|
"""Test that knowledge uses markdown with source headers."""
|
||||||
|
adapter = OpenAIAdapter()
|
||||||
|
contexts = [
|
||||||
|
KnowledgeContext(
|
||||||
|
content="Documentation here.",
|
||||||
|
source="docs/api.md",
|
||||||
|
relevance_score=0.9,
|
||||||
|
),
|
||||||
|
]
|
||||||
|
result = adapter.format(contexts)
|
||||||
|
assert "## Reference Documents" in result
|
||||||
|
assert "### Source: docs/api.md" in result
|
||||||
|
assert "Documentation here." in result
|
||||||
|
|
||||||
|
def test_format_knowledge_with_score(self) -> None:
|
||||||
|
"""Test that knowledge includes relevance score."""
|
||||||
|
adapter = OpenAIAdapter()
|
||||||
|
contexts = [
|
||||||
|
KnowledgeContext(
|
||||||
|
content="Doc content.",
|
||||||
|
source="docs/api.md",
|
||||||
|
metadata={"relevance_score": 0.95},
|
||||||
|
),
|
||||||
|
]
|
||||||
|
result = adapter.format(contexts)
|
||||||
|
assert "(relevance: 0.95)" in result
|
||||||
|
|
||||||
|
def test_format_conversation_uses_bold_roles(self) -> None:
|
||||||
|
"""Test that conversation uses bold role labels."""
|
||||||
|
adapter = OpenAIAdapter()
|
||||||
|
contexts = [
|
||||||
|
ConversationContext(
|
||||||
|
content="Hello!",
|
||||||
|
source="chat",
|
||||||
|
role=MessageRole.USER,
|
||||||
|
metadata={"role": "user"},
|
||||||
|
),
|
||||||
|
ConversationContext(
|
||||||
|
content="Hi there!",
|
||||||
|
source="chat",
|
||||||
|
role=MessageRole.ASSISTANT,
|
||||||
|
metadata={"role": "assistant"},
|
||||||
|
),
|
||||||
|
]
|
||||||
|
result = adapter.format(contexts)
|
||||||
|
assert "**USER**:" in result
|
||||||
|
assert "**ASSISTANT**:" in result
|
||||||
|
assert "Hello!" in result
|
||||||
|
assert "Hi there!" in result
|
||||||
|
|
||||||
|
def test_format_tool_uses_code_blocks(self) -> None:
|
||||||
|
"""Test that tool results use code blocks."""
|
||||||
|
adapter = OpenAIAdapter()
|
||||||
|
contexts = [
|
||||||
|
ToolContext(
|
||||||
|
content='{"status": "ok"}',
|
||||||
|
source="tool",
|
||||||
|
metadata={"tool_name": "search", "status": "success"},
|
||||||
|
),
|
||||||
|
]
|
||||||
|
result = adapter.format(contexts)
|
||||||
|
assert "## Recent Tool Results" in result
|
||||||
|
assert "### Tool: search (success)" in result
|
||||||
|
assert "```" in result # Code block
|
||||||
|
assert '{"status": "ok"}' in result
|
||||||
|
|
||||||
|
def test_format_multiple_types_in_order(self) -> None:
|
||||||
|
"""Test that multiple types are formatted in correct order."""
|
||||||
|
adapter = OpenAIAdapter()
|
||||||
|
contexts = [
|
||||||
|
KnowledgeContext(content="Knowledge", source="docs"),
|
||||||
|
SystemContext(content="System", source="system"),
|
||||||
|
TaskContext(content="Task", source="task"),
|
||||||
|
]
|
||||||
|
result = adapter.format(contexts)
|
||||||
|
|
||||||
|
# System comes first (no header), then task, then knowledge
|
||||||
|
system_pos = result.find("System")
|
||||||
|
task_pos = result.find("## Current Task")
|
||||||
|
knowledge_pos = result.find("## Reference Documents")
|
||||||
|
|
||||||
|
assert system_pos < task_pos < knowledge_pos
|
||||||
|
|
||||||
|
|
||||||
|
class TestAdapterIntegration:
|
||||||
|
"""Integration tests for adapters."""
|
||||||
|
|
||||||
|
def test_full_context_formatting_claude(self) -> None:
|
||||||
|
"""Test formatting a full set of contexts for Claude."""
|
||||||
|
adapter = ClaudeAdapter()
|
||||||
|
|
||||||
|
contexts = [
|
||||||
|
SystemContext(
|
||||||
|
content="You are an expert Python developer.",
|
||||||
|
source="system",
|
||||||
|
),
|
||||||
|
TaskContext(
|
||||||
|
content="Implement user authentication.",
|
||||||
|
source="task:AUTH-123",
|
||||||
|
),
|
||||||
|
KnowledgeContext(
|
||||||
|
content="JWT tokens provide stateless authentication...",
|
||||||
|
source="docs/auth/jwt.md",
|
||||||
|
relevance_score=0.9,
|
||||||
|
),
|
||||||
|
ConversationContext(
|
||||||
|
content="Can you help me implement JWT auth?",
|
||||||
|
source="chat",
|
||||||
|
role=MessageRole.USER,
|
||||||
|
metadata={"role": "user"},
|
||||||
|
),
|
||||||
|
ToolContext(
|
||||||
|
content='{"file": "auth.py", "status": "created"}',
|
||||||
|
source="tool",
|
||||||
|
metadata={"tool_name": "file_create"},
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
result = adapter.format(contexts)
|
||||||
|
|
||||||
|
# Verify all sections present
|
||||||
|
assert "<system_instructions>" in result
|
||||||
|
assert "<current_task>" in result
|
||||||
|
assert "<reference_documents>" in result
|
||||||
|
assert "<conversation_history>" in result
|
||||||
|
assert "<tool_results>" in result
|
||||||
|
|
||||||
|
# Verify content
|
||||||
|
assert "expert Python developer" in result
|
||||||
|
assert "user authentication" in result
|
||||||
|
assert "JWT tokens" in result
|
||||||
|
assert "help me implement" in result
|
||||||
|
assert "file_create" in result
|
||||||
|
|
||||||
|
def test_full_context_formatting_openai(self) -> None:
|
||||||
|
"""Test formatting a full set of contexts for OpenAI."""
|
||||||
|
adapter = OpenAIAdapter()
|
||||||
|
|
||||||
|
contexts = [
|
||||||
|
SystemContext(
|
||||||
|
content="You are an expert Python developer.",
|
||||||
|
source="system",
|
||||||
|
),
|
||||||
|
TaskContext(
|
||||||
|
content="Implement user authentication.",
|
||||||
|
source="task:AUTH-123",
|
||||||
|
),
|
||||||
|
KnowledgeContext(
|
||||||
|
content="JWT tokens provide stateless authentication...",
|
||||||
|
source="docs/auth/jwt.md",
|
||||||
|
relevance_score=0.9,
|
||||||
|
),
|
||||||
|
ConversationContext(
|
||||||
|
content="Can you help me implement JWT auth?",
|
||||||
|
source="chat",
|
||||||
|
role=MessageRole.USER,
|
||||||
|
metadata={"role": "user"},
|
||||||
|
),
|
||||||
|
ToolContext(
|
||||||
|
content='{"file": "auth.py", "status": "created"}',
|
||||||
|
source="tool",
|
||||||
|
metadata={"tool_name": "file_create"},
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
result = adapter.format(contexts)
|
||||||
|
|
||||||
|
# Verify all sections present
|
||||||
|
assert "## Current Task" in result
|
||||||
|
assert "## Reference Documents" in result
|
||||||
|
assert "## Recent Tool Results" in result
|
||||||
|
assert "**USER**:" in result
|
||||||
|
|
||||||
|
# Verify content
|
||||||
|
assert "expert Python developer" in result
|
||||||
|
assert "user authentication" in result
|
||||||
|
assert "JWT tokens" in result
|
||||||
|
assert "help me implement" in result
|
||||||
|
assert "file_create" in result
|
||||||
508
backend/tests/services/context/test_assembly.py
Normal file
508
backend/tests/services/context/test_assembly.py
Normal file
@@ -0,0 +1,508 @@
|
|||||||
|
"""Tests for context assembly pipeline."""
|
||||||
|
|
||||||
|
from datetime import UTC, datetime
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from app.services.context.assembly import ContextPipeline, PipelineMetrics
|
||||||
|
from app.services.context.budget import TokenBudget
|
||||||
|
from app.services.context.types import (
|
||||||
|
AssembledContext,
|
||||||
|
ConversationContext,
|
||||||
|
KnowledgeContext,
|
||||||
|
MessageRole,
|
||||||
|
SystemContext,
|
||||||
|
TaskContext,
|
||||||
|
ToolContext,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestPipelineMetrics:
|
||||||
|
"""Tests for PipelineMetrics dataclass."""
|
||||||
|
|
||||||
|
def test_creation(self) -> None:
|
||||||
|
"""Test metrics creation."""
|
||||||
|
metrics = PipelineMetrics()
|
||||||
|
|
||||||
|
assert metrics.total_contexts == 0
|
||||||
|
assert metrics.selected_contexts == 0
|
||||||
|
assert metrics.assembly_time_ms == 0.0
|
||||||
|
|
||||||
|
def test_to_dict(self) -> None:
|
||||||
|
"""Test conversion to dictionary."""
|
||||||
|
metrics = PipelineMetrics(
|
||||||
|
total_contexts=10,
|
||||||
|
selected_contexts=8,
|
||||||
|
excluded_contexts=2,
|
||||||
|
total_tokens=500,
|
||||||
|
assembly_time_ms=25.5,
|
||||||
|
)
|
||||||
|
metrics.end_time = datetime.now(UTC)
|
||||||
|
|
||||||
|
data = metrics.to_dict()
|
||||||
|
|
||||||
|
assert data["total_contexts"] == 10
|
||||||
|
assert data["selected_contexts"] == 8
|
||||||
|
assert data["excluded_contexts"] == 2
|
||||||
|
assert data["total_tokens"] == 500
|
||||||
|
assert data["assembly_time_ms"] == 25.5
|
||||||
|
assert "start_time" in data
|
||||||
|
assert "end_time" in data
|
||||||
|
|
||||||
|
|
||||||
|
class TestContextPipeline:
|
||||||
|
"""Tests for ContextPipeline."""
|
||||||
|
|
||||||
|
def test_creation(self) -> None:
|
||||||
|
"""Test pipeline creation."""
|
||||||
|
pipeline = ContextPipeline()
|
||||||
|
|
||||||
|
assert pipeline._calculator is not None
|
||||||
|
assert pipeline._scorer is not None
|
||||||
|
assert pipeline._ranker is not None
|
||||||
|
assert pipeline._compressor is not None
|
||||||
|
assert pipeline._allocator is not None
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_assemble_empty_contexts(self) -> None:
|
||||||
|
"""Test assembling empty context list."""
|
||||||
|
pipeline = ContextPipeline()
|
||||||
|
|
||||||
|
result = await pipeline.assemble(
|
||||||
|
contexts=[],
|
||||||
|
query="test query",
|
||||||
|
model="claude-3-sonnet",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert isinstance(result, AssembledContext)
|
||||||
|
assert result.context_count == 0
|
||||||
|
assert result.total_tokens == 0
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_assemble_single_context(self) -> None:
|
||||||
|
"""Test assembling single context."""
|
||||||
|
pipeline = ContextPipeline()
|
||||||
|
|
||||||
|
contexts = [
|
||||||
|
SystemContext(
|
||||||
|
content="You are a helpful assistant.",
|
||||||
|
source="system",
|
||||||
|
)
|
||||||
|
]
|
||||||
|
|
||||||
|
result = await pipeline.assemble(
|
||||||
|
contexts=contexts,
|
||||||
|
query="help me",
|
||||||
|
model="claude-3-sonnet",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result.context_count == 1
|
||||||
|
assert result.total_tokens > 0
|
||||||
|
assert "helpful assistant" in result.content
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_assemble_multiple_types(self) -> None:
|
||||||
|
"""Test assembling multiple context types."""
|
||||||
|
pipeline = ContextPipeline()
|
||||||
|
|
||||||
|
contexts = [
|
||||||
|
SystemContext(
|
||||||
|
content="You are a coding assistant.",
|
||||||
|
source="system",
|
||||||
|
),
|
||||||
|
TaskContext(
|
||||||
|
content="Implement a login feature.",
|
||||||
|
source="task",
|
||||||
|
),
|
||||||
|
KnowledgeContext(
|
||||||
|
content="Authentication best practices include...",
|
||||||
|
source="docs/auth.md",
|
||||||
|
relevance_score=0.8,
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
result = await pipeline.assemble(
|
||||||
|
contexts=contexts,
|
||||||
|
query="implement login",
|
||||||
|
model="claude-3-sonnet",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result.context_count >= 1
|
||||||
|
assert result.total_tokens > 0
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_assemble_with_custom_budget(self) -> None:
|
||||||
|
"""Test assembling with custom budget."""
|
||||||
|
pipeline = ContextPipeline()
|
||||||
|
budget = TokenBudget(
|
||||||
|
total=1000,
|
||||||
|
system=200,
|
||||||
|
task=200,
|
||||||
|
knowledge=400,
|
||||||
|
conversation=100,
|
||||||
|
tools=50,
|
||||||
|
response_reserve=50,
|
||||||
|
)
|
||||||
|
|
||||||
|
contexts = [
|
||||||
|
SystemContext(content="System prompt", source="system"),
|
||||||
|
TaskContext(content="Task description", source="task"),
|
||||||
|
]
|
||||||
|
|
||||||
|
result = await pipeline.assemble(
|
||||||
|
contexts=contexts,
|
||||||
|
query="test",
|
||||||
|
model="gpt-4",
|
||||||
|
custom_budget=budget,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result.context_count >= 1
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_assemble_with_max_tokens(self) -> None:
|
||||||
|
"""Test assembling with max_tokens limit."""
|
||||||
|
pipeline = ContextPipeline()
|
||||||
|
|
||||||
|
contexts = [
|
||||||
|
SystemContext(content="System prompt", source="system"),
|
||||||
|
]
|
||||||
|
|
||||||
|
result = await pipeline.assemble(
|
||||||
|
contexts=contexts,
|
||||||
|
query="test",
|
||||||
|
model="gpt-4",
|
||||||
|
max_tokens=5000,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert "budget" in result.metadata
|
||||||
|
assert result.metadata["budget"]["total"] == 5000
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_assemble_format_output(self) -> None:
|
||||||
|
"""Test formatted vs unformatted output."""
|
||||||
|
pipeline = ContextPipeline()
|
||||||
|
|
||||||
|
contexts = [
|
||||||
|
SystemContext(content="System prompt", source="system"),
|
||||||
|
]
|
||||||
|
|
||||||
|
# Formatted (default)
|
||||||
|
result_formatted = await pipeline.assemble(
|
||||||
|
contexts=contexts,
|
||||||
|
query="test",
|
||||||
|
model="claude-3-sonnet",
|
||||||
|
format_output=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Unformatted
|
||||||
|
result_raw = await pipeline.assemble(
|
||||||
|
contexts=contexts,
|
||||||
|
query="test",
|
||||||
|
model="claude-3-sonnet",
|
||||||
|
format_output=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Formatted should have XML tags for Claude
|
||||||
|
assert "<system_instructions>" in result_formatted.content
|
||||||
|
# Raw should not
|
||||||
|
assert "<system_instructions>" not in result_raw.content
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_assemble_metrics(self) -> None:
|
||||||
|
"""Test that metrics are populated."""
|
||||||
|
pipeline = ContextPipeline()
|
||||||
|
|
||||||
|
contexts = [
|
||||||
|
SystemContext(content="System", source="system"),
|
||||||
|
TaskContext(content="Task", source="task"),
|
||||||
|
KnowledgeContext(
|
||||||
|
content="Knowledge",
|
||||||
|
source="docs",
|
||||||
|
relevance_score=0.9,
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
result = await pipeline.assemble(
|
||||||
|
contexts=contexts,
|
||||||
|
query="test",
|
||||||
|
model="claude-3-sonnet",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert "metrics" in result.metadata
|
||||||
|
metrics = result.metadata["metrics"]
|
||||||
|
|
||||||
|
assert metrics["total_contexts"] == 3
|
||||||
|
assert metrics["assembly_time_ms"] > 0
|
||||||
|
assert "scoring_time_ms" in metrics
|
||||||
|
assert "formatting_time_ms" in metrics
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_assemble_with_compression_disabled(self) -> None:
|
||||||
|
"""Test assembling with compression disabled."""
|
||||||
|
pipeline = ContextPipeline()
|
||||||
|
|
||||||
|
contexts = [
|
||||||
|
KnowledgeContext(content="A" * 1000, source="docs"),
|
||||||
|
]
|
||||||
|
|
||||||
|
result = await pipeline.assemble(
|
||||||
|
contexts=contexts,
|
||||||
|
query="test",
|
||||||
|
model="gpt-4",
|
||||||
|
compress=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Should still work, just no compression applied
|
||||||
|
assert result.context_count >= 0
|
||||||
|
|
||||||
|
|
||||||
|
class TestContextPipelineFormatting:
|
||||||
|
"""Tests for context formatting."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_format_claude_uses_xml(self) -> None:
|
||||||
|
"""Test that Claude models use XML formatting."""
|
||||||
|
pipeline = ContextPipeline()
|
||||||
|
|
||||||
|
contexts = [
|
||||||
|
SystemContext(content="System prompt", source="system"),
|
||||||
|
TaskContext(content="Task", source="task"),
|
||||||
|
KnowledgeContext(
|
||||||
|
content="Knowledge",
|
||||||
|
source="docs",
|
||||||
|
relevance_score=0.9,
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
result = await pipeline.assemble(
|
||||||
|
contexts=contexts,
|
||||||
|
query="test",
|
||||||
|
model="claude-3-sonnet",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Claude should have XML tags
|
||||||
|
assert "<system_instructions>" in result.content or result.context_count == 0
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_format_openai_uses_markdown(self) -> None:
|
||||||
|
"""Test that OpenAI models use markdown formatting."""
|
||||||
|
pipeline = ContextPipeline()
|
||||||
|
|
||||||
|
contexts = [
|
||||||
|
TaskContext(content="Task description", source="task"),
|
||||||
|
]
|
||||||
|
|
||||||
|
result = await pipeline.assemble(
|
||||||
|
contexts=contexts,
|
||||||
|
query="test",
|
||||||
|
model="gpt-4",
|
||||||
|
)
|
||||||
|
|
||||||
|
# OpenAI should have markdown headers
|
||||||
|
if result.context_count > 0 and "Task" in result.content:
|
||||||
|
assert "## Current Task" in result.content
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_format_knowledge_claude(self) -> None:
|
||||||
|
"""Test knowledge formatting for Claude."""
|
||||||
|
pipeline = ContextPipeline()
|
||||||
|
|
||||||
|
contexts = [
|
||||||
|
KnowledgeContext(
|
||||||
|
content="Document content here",
|
||||||
|
source="docs/file.md",
|
||||||
|
relevance_score=0.9,
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
result = await pipeline.assemble(
|
||||||
|
contexts=contexts,
|
||||||
|
query="test",
|
||||||
|
model="claude-3-sonnet",
|
||||||
|
)
|
||||||
|
|
||||||
|
if result.context_count > 0:
|
||||||
|
assert "<reference_documents>" in result.content
|
||||||
|
assert "<document" in result.content
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_format_conversation(self) -> None:
|
||||||
|
"""Test conversation formatting."""
|
||||||
|
pipeline = ContextPipeline()
|
||||||
|
|
||||||
|
contexts = [
|
||||||
|
ConversationContext(
|
||||||
|
content="Hello, how are you?",
|
||||||
|
source="chat",
|
||||||
|
role=MessageRole.USER,
|
||||||
|
metadata={"role": "user"},
|
||||||
|
),
|
||||||
|
ConversationContext(
|
||||||
|
content="I'm doing great!",
|
||||||
|
source="chat",
|
||||||
|
role=MessageRole.ASSISTANT,
|
||||||
|
metadata={"role": "assistant"},
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
result = await pipeline.assemble(
|
||||||
|
contexts=contexts,
|
||||||
|
query="test",
|
||||||
|
model="claude-3-sonnet",
|
||||||
|
)
|
||||||
|
|
||||||
|
if result.context_count > 0:
|
||||||
|
assert "<conversation_history>" in result.content
|
||||||
|
assert (
|
||||||
|
'<message role="user">' in result.content
|
||||||
|
or 'role="user"' in result.content
|
||||||
|
)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_format_tool_results(self) -> None:
|
||||||
|
"""Test tool result formatting."""
|
||||||
|
pipeline = ContextPipeline()
|
||||||
|
|
||||||
|
contexts = [
|
||||||
|
ToolContext(
|
||||||
|
content="Tool output here",
|
||||||
|
source="tool",
|
||||||
|
metadata={"tool_name": "search"},
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
result = await pipeline.assemble(
|
||||||
|
contexts=contexts,
|
||||||
|
query="test",
|
||||||
|
model="claude-3-sonnet",
|
||||||
|
)
|
||||||
|
|
||||||
|
if result.context_count > 0:
|
||||||
|
assert "<tool_results>" in result.content
|
||||||
|
|
||||||
|
|
||||||
|
class TestContextPipelineIntegration:
|
||||||
|
"""Integration tests for full pipeline."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_full_pipeline_workflow(self) -> None:
|
||||||
|
"""Test complete pipeline workflow."""
|
||||||
|
pipeline = ContextPipeline()
|
||||||
|
|
||||||
|
# Create realistic context mix
|
||||||
|
contexts = [
|
||||||
|
SystemContext(
|
||||||
|
content="You are an expert Python developer.",
|
||||||
|
source="system",
|
||||||
|
),
|
||||||
|
TaskContext(
|
||||||
|
content="Implement a user authentication system.",
|
||||||
|
source="task:AUTH-123",
|
||||||
|
),
|
||||||
|
KnowledgeContext(
|
||||||
|
content="JWT tokens provide stateless authentication...",
|
||||||
|
source="docs/auth/jwt.md",
|
||||||
|
relevance_score=0.9,
|
||||||
|
),
|
||||||
|
KnowledgeContext(
|
||||||
|
content="OAuth 2.0 is an authorization framework...",
|
||||||
|
source="docs/auth/oauth.md",
|
||||||
|
relevance_score=0.7,
|
||||||
|
),
|
||||||
|
ConversationContext(
|
||||||
|
content="Can you help me implement JWT auth?",
|
||||||
|
source="chat",
|
||||||
|
role=MessageRole.USER,
|
||||||
|
metadata={"role": "user"},
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
result = await pipeline.assemble(
|
||||||
|
contexts=contexts,
|
||||||
|
query="implement JWT authentication",
|
||||||
|
model="claude-3-sonnet",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify result
|
||||||
|
assert isinstance(result, AssembledContext)
|
||||||
|
assert result.context_count > 0
|
||||||
|
assert result.total_tokens > 0
|
||||||
|
assert result.assembly_time_ms > 0
|
||||||
|
assert result.model == "claude-3-sonnet"
|
||||||
|
assert len(result.content) > 0
|
||||||
|
|
||||||
|
# Verify metrics
|
||||||
|
assert "metrics" in result.metadata
|
||||||
|
assert "query" in result.metadata
|
||||||
|
assert "budget" in result.metadata
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_context_type_ordering(self) -> None:
|
||||||
|
"""Test that contexts are ordered by type correctly."""
|
||||||
|
pipeline = ContextPipeline()
|
||||||
|
|
||||||
|
# Add in random order
|
||||||
|
contexts = [
|
||||||
|
KnowledgeContext(content="Knowledge", source="docs", relevance_score=0.9),
|
||||||
|
ToolContext(content="Tool", source="tool", metadata={"tool_name": "test"}),
|
||||||
|
SystemContext(content="System", source="system"),
|
||||||
|
ConversationContext(
|
||||||
|
content="Chat",
|
||||||
|
source="chat",
|
||||||
|
role=MessageRole.USER,
|
||||||
|
metadata={"role": "user"},
|
||||||
|
),
|
||||||
|
TaskContext(content="Task", source="task"),
|
||||||
|
]
|
||||||
|
|
||||||
|
result = await pipeline.assemble(
|
||||||
|
contexts=contexts,
|
||||||
|
query="test",
|
||||||
|
model="claude-3-sonnet",
|
||||||
|
)
|
||||||
|
|
||||||
|
# For Claude, verify order: System -> Task -> Knowledge -> Conversation -> Tool
|
||||||
|
content = result.content
|
||||||
|
if result.context_count > 0:
|
||||||
|
# Find positions (if they exist)
|
||||||
|
system_pos = content.find("system_instructions")
|
||||||
|
task_pos = content.find("current_task")
|
||||||
|
knowledge_pos = content.find("reference_documents")
|
||||||
|
conversation_pos = content.find("conversation_history")
|
||||||
|
tool_pos = content.find("tool_results")
|
||||||
|
|
||||||
|
# Verify ordering (only check if both exist)
|
||||||
|
if system_pos >= 0 and task_pos >= 0:
|
||||||
|
assert system_pos < task_pos
|
||||||
|
if task_pos >= 0 and knowledge_pos >= 0:
|
||||||
|
assert task_pos < knowledge_pos
|
||||||
|
if knowledge_pos >= 0 and conversation_pos >= 0:
|
||||||
|
assert knowledge_pos < conversation_pos
|
||||||
|
if conversation_pos >= 0 and tool_pos >= 0:
|
||||||
|
assert conversation_pos < tool_pos
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_excluded_contexts_tracked(self) -> None:
|
||||||
|
"""Test that excluded contexts are tracked in result."""
|
||||||
|
pipeline = ContextPipeline()
|
||||||
|
|
||||||
|
# Create many contexts to force some exclusions
|
||||||
|
contexts = [
|
||||||
|
KnowledgeContext(
|
||||||
|
content="A" * 500, # Large content
|
||||||
|
source=f"docs/{i}",
|
||||||
|
relevance_score=0.1 + (i * 0.05),
|
||||||
|
)
|
||||||
|
for i in range(10)
|
||||||
|
]
|
||||||
|
|
||||||
|
result = await pipeline.assemble(
|
||||||
|
contexts=contexts,
|
||||||
|
query="test",
|
||||||
|
model="gpt-4", # Smaller context window
|
||||||
|
max_tokens=1000, # Limited budget
|
||||||
|
)
|
||||||
|
|
||||||
|
# Should have excluded some
|
||||||
|
assert result.excluded_count >= 0
|
||||||
|
assert result.context_count + result.excluded_count <= len(contexts)
|
||||||
533
backend/tests/services/context/test_budget.py
Normal file
533
backend/tests/services/context/test_budget.py
Normal file
@@ -0,0 +1,533 @@
|
|||||||
|
"""Tests for token budget management."""
|
||||||
|
|
||||||
|
from unittest.mock import AsyncMock, MagicMock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from app.services.context.budget import (
|
||||||
|
BudgetAllocator,
|
||||||
|
TokenBudget,
|
||||||
|
TokenCalculator,
|
||||||
|
)
|
||||||
|
from app.services.context.config import ContextSettings
|
||||||
|
from app.services.context.exceptions import BudgetExceededError
|
||||||
|
from app.services.context.types import ContextType
|
||||||
|
|
||||||
|
|
||||||
|
class TestTokenBudget:
|
||||||
|
"""Tests for TokenBudget dataclass."""
|
||||||
|
|
||||||
|
def test_creation(self) -> None:
|
||||||
|
"""Test basic budget creation."""
|
||||||
|
budget = TokenBudget(total=10000)
|
||||||
|
assert budget.total == 10000
|
||||||
|
assert budget.system == 0
|
||||||
|
assert budget.total_used() == 0
|
||||||
|
|
||||||
|
def test_creation_with_allocations(self) -> None:
|
||||||
|
"""Test budget creation with allocations."""
|
||||||
|
budget = TokenBudget(
|
||||||
|
total=10000,
|
||||||
|
system=500,
|
||||||
|
task=1000,
|
||||||
|
knowledge=4000,
|
||||||
|
conversation=2000,
|
||||||
|
tools=500,
|
||||||
|
response_reserve=1500,
|
||||||
|
buffer=500,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert budget.system == 500
|
||||||
|
assert budget.knowledge == 4000
|
||||||
|
assert budget.response_reserve == 1500
|
||||||
|
|
||||||
|
def test_get_allocation(self) -> None:
|
||||||
|
"""Test getting allocation for a type."""
|
||||||
|
budget = TokenBudget(
|
||||||
|
total=10000,
|
||||||
|
system=500,
|
||||||
|
knowledge=4000,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert budget.get_allocation(ContextType.SYSTEM) == 500
|
||||||
|
assert budget.get_allocation(ContextType.KNOWLEDGE) == 4000
|
||||||
|
assert budget.get_allocation("system") == 500
|
||||||
|
|
||||||
|
def test_remaining(self) -> None:
|
||||||
|
"""Test remaining budget calculation."""
|
||||||
|
budget = TokenBudget(
|
||||||
|
total=10000,
|
||||||
|
system=500,
|
||||||
|
knowledge=4000,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Initially full
|
||||||
|
assert budget.remaining(ContextType.SYSTEM) == 500
|
||||||
|
assert budget.remaining(ContextType.KNOWLEDGE) == 4000
|
||||||
|
|
||||||
|
# After allocation
|
||||||
|
budget.allocate(ContextType.SYSTEM, 200)
|
||||||
|
assert budget.remaining(ContextType.SYSTEM) == 300
|
||||||
|
|
||||||
|
def test_can_fit(self) -> None:
|
||||||
|
"""Test can_fit check."""
|
||||||
|
budget = TokenBudget(
|
||||||
|
total=10000,
|
||||||
|
system=500,
|
||||||
|
knowledge=4000,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert budget.can_fit(ContextType.SYSTEM, 500) is True
|
||||||
|
assert budget.can_fit(ContextType.SYSTEM, 501) is False
|
||||||
|
assert budget.can_fit(ContextType.KNOWLEDGE, 4000) is True
|
||||||
|
|
||||||
|
def test_allocate_success(self) -> None:
|
||||||
|
"""Test successful allocation."""
|
||||||
|
budget = TokenBudget(
|
||||||
|
total=10000,
|
||||||
|
system=500,
|
||||||
|
)
|
||||||
|
|
||||||
|
result = budget.allocate(ContextType.SYSTEM, 200)
|
||||||
|
assert result is True
|
||||||
|
assert budget.get_used(ContextType.SYSTEM) == 200
|
||||||
|
assert budget.remaining(ContextType.SYSTEM) == 300
|
||||||
|
|
||||||
|
def test_allocate_exceeds_budget(self) -> None:
|
||||||
|
"""Test allocation exceeding budget."""
|
||||||
|
budget = TokenBudget(
|
||||||
|
total=10000,
|
||||||
|
system=500,
|
||||||
|
)
|
||||||
|
|
||||||
|
with pytest.raises(BudgetExceededError) as exc_info:
|
||||||
|
budget.allocate(ContextType.SYSTEM, 600)
|
||||||
|
|
||||||
|
assert exc_info.value.allocated == 500
|
||||||
|
assert exc_info.value.requested == 600
|
||||||
|
|
||||||
|
def test_allocate_force(self) -> None:
|
||||||
|
"""Test forced allocation exceeding budget."""
|
||||||
|
budget = TokenBudget(
|
||||||
|
total=10000,
|
||||||
|
system=500,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Force should allow exceeding
|
||||||
|
result = budget.allocate(ContextType.SYSTEM, 600, force=True)
|
||||||
|
assert result is True
|
||||||
|
assert budget.get_used(ContextType.SYSTEM) == 600
|
||||||
|
|
||||||
|
def test_deallocate(self) -> None:
|
||||||
|
"""Test deallocation."""
|
||||||
|
budget = TokenBudget(
|
||||||
|
total=10000,
|
||||||
|
system=500,
|
||||||
|
)
|
||||||
|
|
||||||
|
budget.allocate(ContextType.SYSTEM, 300)
|
||||||
|
assert budget.get_used(ContextType.SYSTEM) == 300
|
||||||
|
|
||||||
|
budget.deallocate(ContextType.SYSTEM, 100)
|
||||||
|
assert budget.get_used(ContextType.SYSTEM) == 200
|
||||||
|
|
||||||
|
def test_deallocate_below_zero(self) -> None:
|
||||||
|
"""Test deallocation doesn't go below zero."""
|
||||||
|
budget = TokenBudget(
|
||||||
|
total=10000,
|
||||||
|
system=500,
|
||||||
|
)
|
||||||
|
|
||||||
|
budget.allocate(ContextType.SYSTEM, 100)
|
||||||
|
budget.deallocate(ContextType.SYSTEM, 200)
|
||||||
|
assert budget.get_used(ContextType.SYSTEM) == 0
|
||||||
|
|
||||||
|
def test_total_remaining(self) -> None:
|
||||||
|
"""Test total remaining calculation."""
|
||||||
|
budget = TokenBudget(
|
||||||
|
total=10000,
|
||||||
|
system=500,
|
||||||
|
knowledge=4000,
|
||||||
|
response_reserve=1500,
|
||||||
|
buffer=500,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Usable = total - response_reserve - buffer = 10000 - 1500 - 500 = 8000
|
||||||
|
assert budget.total_remaining() == 8000
|
||||||
|
|
||||||
|
# After allocation
|
||||||
|
budget.allocate(ContextType.SYSTEM, 200)
|
||||||
|
assert budget.total_remaining() == 7800
|
||||||
|
|
||||||
|
def test_utilization(self) -> None:
|
||||||
|
"""Test utilization calculation."""
|
||||||
|
budget = TokenBudget(
|
||||||
|
total=10000,
|
||||||
|
system=500,
|
||||||
|
response_reserve=1500,
|
||||||
|
buffer=500,
|
||||||
|
)
|
||||||
|
|
||||||
|
# No usage = 0%
|
||||||
|
assert budget.utilization(ContextType.SYSTEM) == 0.0
|
||||||
|
|
||||||
|
# Half used = 50%
|
||||||
|
budget.allocate(ContextType.SYSTEM, 250)
|
||||||
|
assert budget.utilization(ContextType.SYSTEM) == 0.5
|
||||||
|
|
||||||
|
# Total utilization
|
||||||
|
assert budget.utilization() == 250 / 8000 # 250 / (10000 - 1500 - 500)
|
||||||
|
|
||||||
|
def test_reset(self) -> None:
|
||||||
|
"""Test reset clears usage."""
|
||||||
|
budget = TokenBudget(
|
||||||
|
total=10000,
|
||||||
|
system=500,
|
||||||
|
)
|
||||||
|
|
||||||
|
budget.allocate(ContextType.SYSTEM, 300)
|
||||||
|
assert budget.get_used(ContextType.SYSTEM) == 300
|
||||||
|
|
||||||
|
budget.reset()
|
||||||
|
assert budget.get_used(ContextType.SYSTEM) == 0
|
||||||
|
assert budget.total_used() == 0
|
||||||
|
|
||||||
|
def test_to_dict(self) -> None:
|
||||||
|
"""Test to_dict conversion."""
|
||||||
|
budget = TokenBudget(
|
||||||
|
total=10000,
|
||||||
|
system=500,
|
||||||
|
task=1000,
|
||||||
|
knowledge=4000,
|
||||||
|
)
|
||||||
|
|
||||||
|
budget.allocate(ContextType.SYSTEM, 200)
|
||||||
|
|
||||||
|
data = budget.to_dict()
|
||||||
|
assert data["total"] == 10000
|
||||||
|
assert data["allocations"]["system"] == 500
|
||||||
|
assert data["used"]["system"] == 200
|
||||||
|
assert data["remaining"]["system"] == 300
|
||||||
|
|
||||||
|
|
||||||
|
class TestBudgetAllocator:
|
||||||
|
"""Tests for BudgetAllocator."""
|
||||||
|
|
||||||
|
def test_create_budget(self) -> None:
|
||||||
|
"""Test budget creation with default allocations."""
|
||||||
|
allocator = BudgetAllocator()
|
||||||
|
budget = allocator.create_budget(100000)
|
||||||
|
|
||||||
|
assert budget.total == 100000
|
||||||
|
assert budget.system == 5000 # 5%
|
||||||
|
assert budget.task == 10000 # 10%
|
||||||
|
assert budget.knowledge == 40000 # 40%
|
||||||
|
assert budget.conversation == 20000 # 20%
|
||||||
|
assert budget.tools == 5000 # 5%
|
||||||
|
assert budget.response_reserve == 15000 # 15%
|
||||||
|
assert budget.buffer == 5000 # 5%
|
||||||
|
|
||||||
|
def test_create_budget_custom_allocations(self) -> None:
|
||||||
|
"""Test budget creation with custom allocations."""
|
||||||
|
allocator = BudgetAllocator()
|
||||||
|
budget = allocator.create_budget(
|
||||||
|
100000,
|
||||||
|
custom_allocations={
|
||||||
|
"system": 0.10,
|
||||||
|
"task": 0.10,
|
||||||
|
"knowledge": 0.30,
|
||||||
|
"conversation": 0.25,
|
||||||
|
"tools": 0.05,
|
||||||
|
"response": 0.15,
|
||||||
|
"buffer": 0.05,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert budget.system == 10000 # 10%
|
||||||
|
assert budget.knowledge == 30000 # 30%
|
||||||
|
|
||||||
|
def test_create_budget_for_model(self) -> None:
|
||||||
|
"""Test budget creation for specific model."""
|
||||||
|
allocator = BudgetAllocator()
|
||||||
|
|
||||||
|
# Claude models have 200k context
|
||||||
|
budget = allocator.create_budget_for_model("claude-3-sonnet")
|
||||||
|
assert budget.total == 200000
|
||||||
|
|
||||||
|
# GPT-4 has 8k context
|
||||||
|
budget = allocator.create_budget_for_model("gpt-4")
|
||||||
|
assert budget.total == 8192
|
||||||
|
|
||||||
|
# GPT-4-turbo has 128k context
|
||||||
|
budget = allocator.create_budget_for_model("gpt-4-turbo")
|
||||||
|
assert budget.total == 128000
|
||||||
|
|
||||||
|
def test_get_model_context_size(self) -> None:
|
||||||
|
"""Test model context size lookup."""
|
||||||
|
allocator = BudgetAllocator()
|
||||||
|
|
||||||
|
# Known models
|
||||||
|
assert allocator.get_model_context_size("claude-3-opus") == 200000
|
||||||
|
assert allocator.get_model_context_size("gpt-4") == 8192
|
||||||
|
assert allocator.get_model_context_size("gemini-1.5-pro") == 2000000
|
||||||
|
|
||||||
|
# Unknown model gets default
|
||||||
|
assert allocator.get_model_context_size("unknown-model") == 8192
|
||||||
|
|
||||||
|
def test_adjust_budget(self) -> None:
|
||||||
|
"""Test budget adjustment."""
|
||||||
|
allocator = BudgetAllocator()
|
||||||
|
budget = allocator.create_budget(10000)
|
||||||
|
|
||||||
|
original_system = budget.system
|
||||||
|
original_buffer = budget.buffer
|
||||||
|
|
||||||
|
# Increase system by taking from buffer
|
||||||
|
budget = allocator.adjust_budget(budget, ContextType.SYSTEM, 200)
|
||||||
|
|
||||||
|
assert budget.system == original_system + 200
|
||||||
|
assert budget.buffer == original_buffer - 200
|
||||||
|
|
||||||
|
def test_adjust_budget_limited_by_buffer(self) -> None:
|
||||||
|
"""Test that adjustment is limited by buffer size."""
|
||||||
|
allocator = BudgetAllocator()
|
||||||
|
budget = allocator.create_budget(10000)
|
||||||
|
|
||||||
|
original_buffer = budget.buffer
|
||||||
|
|
||||||
|
# Try to increase more than buffer allows
|
||||||
|
budget = allocator.adjust_budget(budget, ContextType.SYSTEM, 10000)
|
||||||
|
|
||||||
|
# Should only increase by buffer amount
|
||||||
|
assert budget.buffer == 0
|
||||||
|
assert budget.system <= original_buffer + budget.system
|
||||||
|
|
||||||
|
def test_rebalance_budget(self) -> None:
|
||||||
|
"""Test budget rebalancing."""
|
||||||
|
allocator = BudgetAllocator()
|
||||||
|
budget = allocator.create_budget(10000)
|
||||||
|
|
||||||
|
# Use most of knowledge budget
|
||||||
|
budget.allocate(ContextType.KNOWLEDGE, 3500)
|
||||||
|
|
||||||
|
# Rebalance prioritizing knowledge
|
||||||
|
budget = allocator.rebalance_budget(
|
||||||
|
budget,
|
||||||
|
prioritize=[ContextType.KNOWLEDGE],
|
||||||
|
)
|
||||||
|
|
||||||
|
# Knowledge should have gotten more tokens
|
||||||
|
# (This is a fuzzy test - just check it runs)
|
||||||
|
assert budget is not None
|
||||||
|
|
||||||
|
|
||||||
|
class TestTokenCalculator:
|
||||||
|
"""Tests for TokenCalculator."""
|
||||||
|
|
||||||
|
def test_estimate_tokens(self) -> None:
|
||||||
|
"""Test token estimation."""
|
||||||
|
calc = TokenCalculator()
|
||||||
|
|
||||||
|
# Empty string
|
||||||
|
assert calc.estimate_tokens("") == 0
|
||||||
|
|
||||||
|
# Short text (~4 chars per token)
|
||||||
|
text = "This is a test message"
|
||||||
|
estimate = calc.estimate_tokens(text)
|
||||||
|
assert 4 <= estimate <= 8
|
||||||
|
|
||||||
|
def test_estimate_tokens_model_specific(self) -> None:
|
||||||
|
"""Test model-specific estimation ratios."""
|
||||||
|
calc = TokenCalculator()
|
||||||
|
text = "a" * 100
|
||||||
|
|
||||||
|
# Claude uses 3.5 chars per token
|
||||||
|
claude_estimate = calc.estimate_tokens(text, "claude-3-sonnet")
|
||||||
|
# GPT uses 4.0 chars per token
|
||||||
|
gpt_estimate = calc.estimate_tokens(text, "gpt-4")
|
||||||
|
|
||||||
|
# Claude should estimate more tokens (smaller ratio)
|
||||||
|
assert claude_estimate >= gpt_estimate
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_count_tokens_no_mcp(self) -> None:
|
||||||
|
"""Test token counting without MCP (fallback to estimation)."""
|
||||||
|
calc = TokenCalculator()
|
||||||
|
|
||||||
|
text = "This is a test"
|
||||||
|
count = await calc.count_tokens(text)
|
||||||
|
|
||||||
|
# Should use estimation
|
||||||
|
assert count > 0
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_count_tokens_with_mcp_success(self) -> None:
|
||||||
|
"""Test token counting with MCP integration."""
|
||||||
|
# Mock MCP manager
|
||||||
|
mock_mcp = MagicMock()
|
||||||
|
mock_result = MagicMock()
|
||||||
|
mock_result.success = True
|
||||||
|
mock_result.data = {"token_count": 42}
|
||||||
|
mock_mcp.call_tool = AsyncMock(return_value=mock_result)
|
||||||
|
|
||||||
|
calc = TokenCalculator(mcp_manager=mock_mcp)
|
||||||
|
count = await calc.count_tokens("test text")
|
||||||
|
|
||||||
|
assert count == 42
|
||||||
|
mock_mcp.call_tool.assert_called_once()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_count_tokens_with_mcp_failure(self) -> None:
|
||||||
|
"""Test fallback when MCP fails."""
|
||||||
|
# Mock MCP manager that fails
|
||||||
|
mock_mcp = MagicMock()
|
||||||
|
mock_mcp.call_tool = AsyncMock(side_effect=Exception("Connection failed"))
|
||||||
|
|
||||||
|
calc = TokenCalculator(mcp_manager=mock_mcp)
|
||||||
|
count = await calc.count_tokens("test text")
|
||||||
|
|
||||||
|
# Should fall back to estimation
|
||||||
|
assert count > 0
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_count_tokens_caching(self) -> None:
|
||||||
|
"""Test that token counts are cached."""
|
||||||
|
mock_mcp = MagicMock()
|
||||||
|
mock_result = MagicMock()
|
||||||
|
mock_result.success = True
|
||||||
|
mock_result.data = {"token_count": 42}
|
||||||
|
mock_mcp.call_tool = AsyncMock(return_value=mock_result)
|
||||||
|
|
||||||
|
calc = TokenCalculator(mcp_manager=mock_mcp)
|
||||||
|
|
||||||
|
# First call
|
||||||
|
count1 = await calc.count_tokens("test text")
|
||||||
|
# Second call (should use cache)
|
||||||
|
count2 = await calc.count_tokens("test text")
|
||||||
|
|
||||||
|
assert count1 == count2 == 42
|
||||||
|
# MCP should only be called once
|
||||||
|
assert mock_mcp.call_tool.call_count == 1
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_count_tokens_batch(self) -> None:
|
||||||
|
"""Test batch token counting."""
|
||||||
|
calc = TokenCalculator()
|
||||||
|
|
||||||
|
texts = ["Hello", "World", "Test message here"]
|
||||||
|
counts = await calc.count_tokens_batch(texts)
|
||||||
|
|
||||||
|
assert len(counts) == 3
|
||||||
|
assert all(c > 0 for c in counts)
|
||||||
|
|
||||||
|
def test_cache_stats(self) -> None:
|
||||||
|
"""Test cache statistics."""
|
||||||
|
calc = TokenCalculator()
|
||||||
|
|
||||||
|
stats = calc.get_cache_stats()
|
||||||
|
assert stats["enabled"] is True
|
||||||
|
assert stats["size"] == 0
|
||||||
|
assert stats["hits"] == 0
|
||||||
|
assert stats["misses"] == 0
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_cache_hit_rate(self) -> None:
|
||||||
|
"""Test cache hit rate tracking."""
|
||||||
|
calc = TokenCalculator()
|
||||||
|
|
||||||
|
# Make some calls
|
||||||
|
await calc.count_tokens("text1")
|
||||||
|
await calc.count_tokens("text2")
|
||||||
|
await calc.count_tokens("text1") # Cache hit
|
||||||
|
|
||||||
|
stats = calc.get_cache_stats()
|
||||||
|
assert stats["hits"] == 1
|
||||||
|
assert stats["misses"] == 2
|
||||||
|
|
||||||
|
def test_clear_cache(self) -> None:
|
||||||
|
"""Test cache clearing."""
|
||||||
|
calc = TokenCalculator()
|
||||||
|
calc._cache["test"] = 100
|
||||||
|
calc._cache_hits = 5
|
||||||
|
|
||||||
|
calc.clear_cache()
|
||||||
|
|
||||||
|
assert len(calc._cache) == 0
|
||||||
|
assert calc._cache_hits == 0
|
||||||
|
|
||||||
|
def test_set_mcp_manager(self) -> None:
|
||||||
|
"""Test setting MCP manager after initialization."""
|
||||||
|
calc = TokenCalculator()
|
||||||
|
assert calc._mcp is None
|
||||||
|
|
||||||
|
mock_mcp = MagicMock()
|
||||||
|
calc.set_mcp_manager(mock_mcp)
|
||||||
|
|
||||||
|
assert calc._mcp is mock_mcp
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_parse_token_count_formats(self) -> None:
|
||||||
|
"""Test parsing different token count response formats."""
|
||||||
|
calc = TokenCalculator()
|
||||||
|
|
||||||
|
# Dict with token_count
|
||||||
|
assert calc._parse_token_count({"token_count": 42}) == 42
|
||||||
|
|
||||||
|
# Dict with tokens
|
||||||
|
assert calc._parse_token_count({"tokens": 42}) == 42
|
||||||
|
|
||||||
|
# Dict with count
|
||||||
|
assert calc._parse_token_count({"count": 42}) == 42
|
||||||
|
|
||||||
|
# Direct int
|
||||||
|
assert calc._parse_token_count(42) == 42
|
||||||
|
|
||||||
|
# JSON string
|
||||||
|
assert calc._parse_token_count('{"token_count": 42}') == 42
|
||||||
|
|
||||||
|
# Invalid
|
||||||
|
assert calc._parse_token_count("invalid") is None
|
||||||
|
|
||||||
|
|
||||||
|
class TestBudgetIntegration:
|
||||||
|
"""Integration tests for budget management."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_full_budget_workflow(self) -> None:
|
||||||
|
"""Test complete budget allocation workflow."""
|
||||||
|
# Create settings and allocator
|
||||||
|
settings = ContextSettings()
|
||||||
|
allocator = BudgetAllocator(settings)
|
||||||
|
|
||||||
|
# Create budget for Claude
|
||||||
|
budget = allocator.create_budget_for_model("claude-3-sonnet")
|
||||||
|
assert budget.total == 200000
|
||||||
|
|
||||||
|
# Create calculator (without MCP for test)
|
||||||
|
calc = TokenCalculator()
|
||||||
|
|
||||||
|
# Simulate context allocation
|
||||||
|
system_text = "You are a helpful assistant." * 10
|
||||||
|
system_tokens = await calc.count_tokens(system_text)
|
||||||
|
|
||||||
|
# Allocate
|
||||||
|
assert budget.can_fit(ContextType.SYSTEM, system_tokens)
|
||||||
|
budget.allocate(ContextType.SYSTEM, system_tokens)
|
||||||
|
|
||||||
|
# Check state
|
||||||
|
assert budget.get_used(ContextType.SYSTEM) == system_tokens
|
||||||
|
assert budget.remaining(ContextType.SYSTEM) == budget.system - system_tokens
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_budget_overflow_handling(self) -> None:
|
||||||
|
"""Test handling budget overflow."""
|
||||||
|
allocator = BudgetAllocator()
|
||||||
|
budget = allocator.create_budget(1000) # Small budget
|
||||||
|
|
||||||
|
# Try to allocate more than available
|
||||||
|
with pytest.raises(BudgetExceededError):
|
||||||
|
budget.allocate(ContextType.KNOWLEDGE, 500)
|
||||||
|
|
||||||
|
# Force allocation should work
|
||||||
|
budget.allocate(ContextType.KNOWLEDGE, 500, force=True)
|
||||||
|
assert budget.get_used(ContextType.KNOWLEDGE) == 500
|
||||||
479
backend/tests/services/context/test_cache.py
Normal file
479
backend/tests/services/context/test_cache.py
Normal file
@@ -0,0 +1,479 @@
|
|||||||
|
"""Tests for context cache module."""
|
||||||
|
|
||||||
|
from unittest.mock import AsyncMock, MagicMock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from app.services.context.cache import ContextCache
|
||||||
|
from app.services.context.config import ContextSettings
|
||||||
|
from app.services.context.exceptions import CacheError
|
||||||
|
from app.services.context.types import (
|
||||||
|
AssembledContext,
|
||||||
|
ContextPriority,
|
||||||
|
KnowledgeContext,
|
||||||
|
SystemContext,
|
||||||
|
TaskContext,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestContextCacheBasics:
|
||||||
|
"""Basic tests for ContextCache."""
|
||||||
|
|
||||||
|
def test_creation(self) -> None:
|
||||||
|
"""Test cache creation without Redis."""
|
||||||
|
cache = ContextCache()
|
||||||
|
assert cache._redis is None
|
||||||
|
assert not cache.is_enabled
|
||||||
|
|
||||||
|
def test_creation_with_settings(self) -> None:
|
||||||
|
"""Test cache creation with custom settings."""
|
||||||
|
settings = ContextSettings(
|
||||||
|
cache_prefix="test",
|
||||||
|
cache_ttl_seconds=60,
|
||||||
|
)
|
||||||
|
cache = ContextCache(settings=settings)
|
||||||
|
assert cache._prefix == "test"
|
||||||
|
assert cache._ttl == 60
|
||||||
|
|
||||||
|
def test_set_redis(self) -> None:
|
||||||
|
"""Test setting Redis connection."""
|
||||||
|
cache = ContextCache()
|
||||||
|
mock_redis = MagicMock()
|
||||||
|
cache.set_redis(mock_redis)
|
||||||
|
assert cache._redis is mock_redis
|
||||||
|
|
||||||
|
def test_is_enabled(self) -> None:
|
||||||
|
"""Test is_enabled property."""
|
||||||
|
settings = ContextSettings(cache_enabled=True)
|
||||||
|
cache = ContextCache(settings=settings)
|
||||||
|
assert not cache.is_enabled # No Redis
|
||||||
|
|
||||||
|
cache.set_redis(MagicMock())
|
||||||
|
assert cache.is_enabled
|
||||||
|
|
||||||
|
# Disabled in settings
|
||||||
|
settings2 = ContextSettings(cache_enabled=False)
|
||||||
|
cache2 = ContextCache(redis=MagicMock(), settings=settings2)
|
||||||
|
assert not cache2.is_enabled
|
||||||
|
|
||||||
|
def test_cache_key(self) -> None:
|
||||||
|
"""Test cache key generation."""
|
||||||
|
cache = ContextCache()
|
||||||
|
key = cache._cache_key("assembled", "abc123")
|
||||||
|
assert key == "ctx:assembled:abc123"
|
||||||
|
|
||||||
|
def test_hash_content(self) -> None:
|
||||||
|
"""Test content hashing."""
|
||||||
|
hash1 = ContextCache._hash_content("hello world")
|
||||||
|
hash2 = ContextCache._hash_content("hello world")
|
||||||
|
hash3 = ContextCache._hash_content("different")
|
||||||
|
|
||||||
|
assert hash1 == hash2
|
||||||
|
assert hash1 != hash3
|
||||||
|
assert len(hash1) == 32
|
||||||
|
|
||||||
|
|
||||||
|
class TestFingerprintComputation:
|
||||||
|
"""Tests for fingerprint computation."""
|
||||||
|
|
||||||
|
def test_compute_fingerprint(self) -> None:
|
||||||
|
"""Test fingerprint computation."""
|
||||||
|
cache = ContextCache()
|
||||||
|
|
||||||
|
contexts = [
|
||||||
|
SystemContext(content="System", source="system"),
|
||||||
|
TaskContext(content="Task", source="task"),
|
||||||
|
]
|
||||||
|
|
||||||
|
fp1 = cache.compute_fingerprint(contexts, "query", "claude-3")
|
||||||
|
fp2 = cache.compute_fingerprint(contexts, "query", "claude-3")
|
||||||
|
fp3 = cache.compute_fingerprint(contexts, "different", "claude-3")
|
||||||
|
|
||||||
|
assert fp1 == fp2 # Same inputs = same fingerprint
|
||||||
|
assert fp1 != fp3 # Different query = different fingerprint
|
||||||
|
assert len(fp1) == 32
|
||||||
|
|
||||||
|
def test_fingerprint_includes_priority(self) -> None:
|
||||||
|
"""Test that fingerprint changes with priority."""
|
||||||
|
cache = ContextCache()
|
||||||
|
|
||||||
|
# Use KnowledgeContext since SystemContext has __post_init__ that may override
|
||||||
|
ctx1 = [
|
||||||
|
KnowledgeContext(
|
||||||
|
content="Knowledge",
|
||||||
|
source="docs",
|
||||||
|
priority=ContextPriority.NORMAL.value,
|
||||||
|
)
|
||||||
|
]
|
||||||
|
ctx2 = [
|
||||||
|
KnowledgeContext(
|
||||||
|
content="Knowledge",
|
||||||
|
source="docs",
|
||||||
|
priority=ContextPriority.HIGH.value,
|
||||||
|
)
|
||||||
|
]
|
||||||
|
|
||||||
|
fp1 = cache.compute_fingerprint(ctx1, "query", "claude-3")
|
||||||
|
fp2 = cache.compute_fingerprint(ctx2, "query", "claude-3")
|
||||||
|
|
||||||
|
assert fp1 != fp2
|
||||||
|
|
||||||
|
def test_fingerprint_includes_model(self) -> None:
|
||||||
|
"""Test that fingerprint changes with model."""
|
||||||
|
cache = ContextCache()
|
||||||
|
contexts = [SystemContext(content="System", source="system")]
|
||||||
|
|
||||||
|
fp1 = cache.compute_fingerprint(contexts, "query", "claude-3")
|
||||||
|
fp2 = cache.compute_fingerprint(contexts, "query", "gpt-4")
|
||||||
|
|
||||||
|
assert fp1 != fp2
|
||||||
|
|
||||||
|
|
||||||
|
class TestMemoryCache:
|
||||||
|
"""Tests for in-memory caching."""
|
||||||
|
|
||||||
|
def test_memory_cache_fallback(self) -> None:
|
||||||
|
"""Test memory cache when Redis unavailable."""
|
||||||
|
cache = ContextCache()
|
||||||
|
|
||||||
|
# Should use memory cache
|
||||||
|
cache._set_memory("test-key", "42")
|
||||||
|
assert "test-key" in cache._memory_cache
|
||||||
|
assert cache._memory_cache["test-key"][0] == "42"
|
||||||
|
|
||||||
|
def test_memory_cache_eviction(self) -> None:
|
||||||
|
"""Test memory cache eviction."""
|
||||||
|
cache = ContextCache()
|
||||||
|
cache._max_memory_items = 10
|
||||||
|
|
||||||
|
# Fill cache
|
||||||
|
for i in range(15):
|
||||||
|
cache._set_memory(f"key-{i}", f"value-{i}")
|
||||||
|
|
||||||
|
# Should have evicted some items
|
||||||
|
assert len(cache._memory_cache) < 15
|
||||||
|
|
||||||
|
|
||||||
|
class TestAssembledContextCache:
|
||||||
|
"""Tests for assembled context caching."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_assembled_no_redis(self) -> None:
|
||||||
|
"""Test get_assembled without Redis returns None."""
|
||||||
|
cache = ContextCache()
|
||||||
|
result = await cache.get_assembled("fingerprint")
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_assembled_not_found(self) -> None:
|
||||||
|
"""Test get_assembled when key not found."""
|
||||||
|
mock_redis = AsyncMock()
|
||||||
|
mock_redis.get.return_value = None
|
||||||
|
|
||||||
|
settings = ContextSettings(cache_enabled=True)
|
||||||
|
cache = ContextCache(redis=mock_redis, settings=settings)
|
||||||
|
|
||||||
|
result = await cache.get_assembled("fingerprint")
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_assembled_found(self) -> None:
|
||||||
|
"""Test get_assembled when key found."""
|
||||||
|
# Create a context
|
||||||
|
ctx = AssembledContext(
|
||||||
|
content="Test content",
|
||||||
|
total_tokens=100,
|
||||||
|
context_count=2,
|
||||||
|
)
|
||||||
|
|
||||||
|
mock_redis = AsyncMock()
|
||||||
|
mock_redis.get.return_value = ctx.to_json()
|
||||||
|
|
||||||
|
settings = ContextSettings(cache_enabled=True)
|
||||||
|
cache = ContextCache(redis=mock_redis, settings=settings)
|
||||||
|
|
||||||
|
result = await cache.get_assembled("fingerprint")
|
||||||
|
|
||||||
|
assert result is not None
|
||||||
|
assert result.content == "Test content"
|
||||||
|
assert result.total_tokens == 100
|
||||||
|
assert result.cache_hit is True
|
||||||
|
assert result.cache_key == "fingerprint"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_set_assembled(self) -> None:
|
||||||
|
"""Test set_assembled."""
|
||||||
|
mock_redis = AsyncMock()
|
||||||
|
|
||||||
|
settings = ContextSettings(cache_enabled=True, cache_ttl_seconds=60)
|
||||||
|
cache = ContextCache(redis=mock_redis, settings=settings)
|
||||||
|
|
||||||
|
ctx = AssembledContext(
|
||||||
|
content="Test content",
|
||||||
|
total_tokens=100,
|
||||||
|
context_count=2,
|
||||||
|
)
|
||||||
|
|
||||||
|
await cache.set_assembled("fingerprint", ctx)
|
||||||
|
|
||||||
|
mock_redis.setex.assert_called_once()
|
||||||
|
call_args = mock_redis.setex.call_args
|
||||||
|
assert call_args[0][0] == "ctx:assembled:fingerprint"
|
||||||
|
assert call_args[0][1] == 60 # TTL
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_set_assembled_custom_ttl(self) -> None:
|
||||||
|
"""Test set_assembled with custom TTL."""
|
||||||
|
mock_redis = AsyncMock()
|
||||||
|
|
||||||
|
settings = ContextSettings(cache_enabled=True)
|
||||||
|
cache = ContextCache(redis=mock_redis, settings=settings)
|
||||||
|
|
||||||
|
ctx = AssembledContext(
|
||||||
|
content="Test",
|
||||||
|
total_tokens=10,
|
||||||
|
context_count=1,
|
||||||
|
)
|
||||||
|
|
||||||
|
await cache.set_assembled("fp", ctx, ttl=120)
|
||||||
|
|
||||||
|
call_args = mock_redis.setex.call_args
|
||||||
|
assert call_args[0][1] == 120
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_cache_error_on_get(self) -> None:
|
||||||
|
"""Test CacheError raised on Redis error."""
|
||||||
|
mock_redis = AsyncMock()
|
||||||
|
mock_redis.get.side_effect = Exception("Redis error")
|
||||||
|
|
||||||
|
settings = ContextSettings(cache_enabled=True)
|
||||||
|
cache = ContextCache(redis=mock_redis, settings=settings)
|
||||||
|
|
||||||
|
with pytest.raises(CacheError):
|
||||||
|
await cache.get_assembled("fingerprint")
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_cache_error_on_set(self) -> None:
|
||||||
|
"""Test CacheError raised on Redis error."""
|
||||||
|
mock_redis = AsyncMock()
|
||||||
|
mock_redis.setex.side_effect = Exception("Redis error")
|
||||||
|
|
||||||
|
settings = ContextSettings(cache_enabled=True)
|
||||||
|
cache = ContextCache(redis=mock_redis, settings=settings)
|
||||||
|
|
||||||
|
ctx = AssembledContext(
|
||||||
|
content="Test",
|
||||||
|
total_tokens=10,
|
||||||
|
context_count=1,
|
||||||
|
)
|
||||||
|
|
||||||
|
with pytest.raises(CacheError):
|
||||||
|
await cache.set_assembled("fp", ctx)
|
||||||
|
|
||||||
|
|
||||||
|
class TestTokenCountCache:
|
||||||
|
"""Tests for token count caching."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_token_count_memory_fallback(self) -> None:
|
||||||
|
"""Test get_token_count uses memory cache."""
|
||||||
|
cache = ContextCache()
|
||||||
|
|
||||||
|
# Set in memory
|
||||||
|
key = cache._cache_key("tokens", "default", cache._hash_content("hello"))
|
||||||
|
cache._set_memory(key, "42")
|
||||||
|
|
||||||
|
result = await cache.get_token_count("hello")
|
||||||
|
assert result == 42
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_set_token_count_memory(self) -> None:
|
||||||
|
"""Test set_token_count stores in memory."""
|
||||||
|
cache = ContextCache()
|
||||||
|
|
||||||
|
await cache.set_token_count("hello", 42)
|
||||||
|
|
||||||
|
result = await cache.get_token_count("hello")
|
||||||
|
assert result == 42
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_set_token_count_with_model(self) -> None:
|
||||||
|
"""Test set_token_count with model-specific tokenization."""
|
||||||
|
mock_redis = AsyncMock()
|
||||||
|
settings = ContextSettings(cache_enabled=True)
|
||||||
|
cache = ContextCache(redis=mock_redis, settings=settings)
|
||||||
|
|
||||||
|
await cache.set_token_count("hello", 42, model="claude-3")
|
||||||
|
await cache.set_token_count("hello", 50, model="gpt-4")
|
||||||
|
|
||||||
|
# Different models should have different keys
|
||||||
|
assert mock_redis.setex.call_count == 2
|
||||||
|
calls = mock_redis.setex.call_args_list
|
||||||
|
|
||||||
|
key1 = calls[0][0][0]
|
||||||
|
key2 = calls[1][0][0]
|
||||||
|
assert "claude-3" in key1
|
||||||
|
assert "gpt-4" in key2
|
||||||
|
|
||||||
|
|
||||||
|
class TestScoreCache:
|
||||||
|
"""Tests for score caching."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_score_memory_fallback(self) -> None:
|
||||||
|
"""Test get_score uses memory cache."""
|
||||||
|
cache = ContextCache()
|
||||||
|
|
||||||
|
# Set in memory
|
||||||
|
query_hash = cache._hash_content("query")[:16]
|
||||||
|
key = cache._cache_key("score", "relevance", "ctx-123", query_hash)
|
||||||
|
cache._set_memory(key, "0.85")
|
||||||
|
|
||||||
|
result = await cache.get_score("relevance", "ctx-123", "query")
|
||||||
|
assert result == 0.85
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_set_score_memory(self) -> None:
|
||||||
|
"""Test set_score stores in memory."""
|
||||||
|
cache = ContextCache()
|
||||||
|
|
||||||
|
await cache.set_score("relevance", "ctx-123", "query", 0.85)
|
||||||
|
|
||||||
|
result = await cache.get_score("relevance", "ctx-123", "query")
|
||||||
|
assert result == 0.85
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_set_score_with_redis(self) -> None:
|
||||||
|
"""Test set_score with Redis."""
|
||||||
|
mock_redis = AsyncMock()
|
||||||
|
settings = ContextSettings(cache_enabled=True)
|
||||||
|
cache = ContextCache(redis=mock_redis, settings=settings)
|
||||||
|
|
||||||
|
await cache.set_score("relevance", "ctx-123", "query", 0.85)
|
||||||
|
|
||||||
|
mock_redis.setex.assert_called_once()
|
||||||
|
|
||||||
|
|
||||||
|
class TestCacheInvalidation:
|
||||||
|
"""Tests for cache invalidation."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_invalidate_pattern(self) -> None:
|
||||||
|
"""Test invalidate with pattern."""
|
||||||
|
mock_redis = AsyncMock()
|
||||||
|
|
||||||
|
# Set up scan_iter to return matching keys
|
||||||
|
async def mock_scan_iter(match=None):
|
||||||
|
for key in ["ctx:assembled:1", "ctx:assembled:2"]:
|
||||||
|
yield key
|
||||||
|
|
||||||
|
mock_redis.scan_iter = mock_scan_iter
|
||||||
|
|
||||||
|
settings = ContextSettings(cache_enabled=True)
|
||||||
|
cache = ContextCache(redis=mock_redis, settings=settings)
|
||||||
|
|
||||||
|
deleted = await cache.invalidate("assembled:*")
|
||||||
|
|
||||||
|
assert deleted == 2
|
||||||
|
assert mock_redis.delete.call_count == 2
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_clear_all(self) -> None:
|
||||||
|
"""Test clear_all."""
|
||||||
|
mock_redis = AsyncMock()
|
||||||
|
|
||||||
|
async def mock_scan_iter(match=None):
|
||||||
|
for key in ["ctx:1", "ctx:2", "ctx:3"]:
|
||||||
|
yield key
|
||||||
|
|
||||||
|
mock_redis.scan_iter = mock_scan_iter
|
||||||
|
|
||||||
|
settings = ContextSettings(cache_enabled=True)
|
||||||
|
cache = ContextCache(redis=mock_redis, settings=settings)
|
||||||
|
|
||||||
|
# Add to memory cache
|
||||||
|
cache._set_memory("test", "value")
|
||||||
|
assert len(cache._memory_cache) > 0
|
||||||
|
|
||||||
|
deleted = await cache.clear_all()
|
||||||
|
|
||||||
|
assert deleted == 3
|
||||||
|
assert len(cache._memory_cache) == 0
|
||||||
|
|
||||||
|
|
||||||
|
class TestCacheStats:
|
||||||
|
"""Tests for cache statistics."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_stats_no_redis(self) -> None:
|
||||||
|
"""Test get_stats without Redis."""
|
||||||
|
cache = ContextCache()
|
||||||
|
cache._set_memory("key", "value")
|
||||||
|
|
||||||
|
stats = await cache.get_stats()
|
||||||
|
|
||||||
|
assert stats["enabled"] is True
|
||||||
|
assert stats["redis_available"] is False
|
||||||
|
assert stats["memory_items"] == 1
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_stats_with_redis(self) -> None:
|
||||||
|
"""Test get_stats with Redis."""
|
||||||
|
mock_redis = AsyncMock()
|
||||||
|
mock_redis.info.return_value = {"used_memory_human": "1.5M"}
|
||||||
|
|
||||||
|
settings = ContextSettings(cache_enabled=True, cache_ttl_seconds=300)
|
||||||
|
cache = ContextCache(redis=mock_redis, settings=settings)
|
||||||
|
|
||||||
|
stats = await cache.get_stats()
|
||||||
|
|
||||||
|
assert stats["enabled"] is True
|
||||||
|
assert stats["redis_available"] is True
|
||||||
|
assert stats["ttl_seconds"] == 300
|
||||||
|
assert stats["redis_memory_used"] == "1.5M"
|
||||||
|
|
||||||
|
|
||||||
|
class TestCacheIntegration:
|
||||||
|
"""Integration tests for cache."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_full_workflow(self) -> None:
|
||||||
|
"""Test complete cache workflow."""
|
||||||
|
mock_redis = AsyncMock()
|
||||||
|
mock_redis.get.return_value = None
|
||||||
|
|
||||||
|
settings = ContextSettings(cache_enabled=True)
|
||||||
|
cache = ContextCache(redis=mock_redis, settings=settings)
|
||||||
|
|
||||||
|
contexts = [
|
||||||
|
SystemContext(content="System", source="system"),
|
||||||
|
KnowledgeContext(content="Knowledge", source="docs"),
|
||||||
|
]
|
||||||
|
|
||||||
|
# Compute fingerprint
|
||||||
|
fp = cache.compute_fingerprint(contexts, "query", "claude-3")
|
||||||
|
assert len(fp) == 32
|
||||||
|
|
||||||
|
# Check cache (miss)
|
||||||
|
result = await cache.get_assembled(fp)
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
# Create and cache assembled context
|
||||||
|
assembled = AssembledContext(
|
||||||
|
content="Assembled content",
|
||||||
|
total_tokens=100,
|
||||||
|
context_count=2,
|
||||||
|
model="claude-3",
|
||||||
|
)
|
||||||
|
await cache.set_assembled(fp, assembled)
|
||||||
|
|
||||||
|
# Verify setex was called
|
||||||
|
mock_redis.setex.assert_called_once()
|
||||||
|
|
||||||
|
# Mock cache hit
|
||||||
|
mock_redis.get.return_value = assembled.to_json()
|
||||||
|
result = await cache.get_assembled(fp)
|
||||||
|
|
||||||
|
assert result is not None
|
||||||
|
assert result.cache_hit is True
|
||||||
|
assert result.content == "Assembled content"
|
||||||
294
backend/tests/services/context/test_compression.py
Normal file
294
backend/tests/services/context/test_compression.py
Normal file
@@ -0,0 +1,294 @@
|
|||||||
|
"""Tests for context compression module."""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from app.services.context.budget import BudgetAllocator
|
||||||
|
from app.services.context.compression import (
|
||||||
|
ContextCompressor,
|
||||||
|
TruncationResult,
|
||||||
|
TruncationStrategy,
|
||||||
|
)
|
||||||
|
from app.services.context.types import (
|
||||||
|
ContextType,
|
||||||
|
KnowledgeContext,
|
||||||
|
TaskContext,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestTruncationResult:
|
||||||
|
"""Tests for TruncationResult dataclass."""
|
||||||
|
|
||||||
|
def test_creation(self) -> None:
|
||||||
|
"""Test basic creation."""
|
||||||
|
result = TruncationResult(
|
||||||
|
original_tokens=100,
|
||||||
|
truncated_tokens=50,
|
||||||
|
content="Truncated content",
|
||||||
|
truncated=True,
|
||||||
|
truncation_ratio=0.5,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result.original_tokens == 100
|
||||||
|
assert result.truncated_tokens == 50
|
||||||
|
assert result.truncated is True
|
||||||
|
assert result.truncation_ratio == 0.5
|
||||||
|
|
||||||
|
def test_tokens_saved(self) -> None:
|
||||||
|
"""Test tokens_saved property."""
|
||||||
|
result = TruncationResult(
|
||||||
|
original_tokens=100,
|
||||||
|
truncated_tokens=40,
|
||||||
|
content="Test",
|
||||||
|
truncated=True,
|
||||||
|
truncation_ratio=0.6,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result.tokens_saved == 60
|
||||||
|
|
||||||
|
def test_no_truncation(self) -> None:
|
||||||
|
"""Test when no truncation needed."""
|
||||||
|
result = TruncationResult(
|
||||||
|
original_tokens=50,
|
||||||
|
truncated_tokens=50,
|
||||||
|
content="Full content",
|
||||||
|
truncated=False,
|
||||||
|
truncation_ratio=0.0,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result.tokens_saved == 0
|
||||||
|
assert result.truncated is False
|
||||||
|
|
||||||
|
|
||||||
|
class TestTruncationStrategy:
|
||||||
|
"""Tests for TruncationStrategy."""
|
||||||
|
|
||||||
|
def test_creation(self) -> None:
|
||||||
|
"""Test strategy creation."""
|
||||||
|
strategy = TruncationStrategy()
|
||||||
|
assert strategy._preserve_ratio_start == 0.7
|
||||||
|
assert strategy._min_content_length == 100
|
||||||
|
|
||||||
|
def test_creation_with_params(self) -> None:
|
||||||
|
"""Test strategy creation with custom params."""
|
||||||
|
strategy = TruncationStrategy(
|
||||||
|
preserve_ratio_start=0.5,
|
||||||
|
min_content_length=50,
|
||||||
|
)
|
||||||
|
assert strategy._preserve_ratio_start == 0.5
|
||||||
|
assert strategy._min_content_length == 50
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_truncate_empty_content(self) -> None:
|
||||||
|
"""Test truncating empty content."""
|
||||||
|
strategy = TruncationStrategy()
|
||||||
|
|
||||||
|
result = await strategy.truncate_to_tokens("", max_tokens=100)
|
||||||
|
|
||||||
|
assert result.original_tokens == 0
|
||||||
|
assert result.truncated_tokens == 0
|
||||||
|
assert result.content == ""
|
||||||
|
assert result.truncated is False
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_truncate_content_within_limit(self) -> None:
|
||||||
|
"""Test content that fits within limit."""
|
||||||
|
strategy = TruncationStrategy()
|
||||||
|
content = "Short content"
|
||||||
|
|
||||||
|
result = await strategy.truncate_to_tokens(content, max_tokens=100)
|
||||||
|
|
||||||
|
assert result.content == content
|
||||||
|
assert result.truncated is False
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_truncate_end_strategy(self) -> None:
|
||||||
|
"""Test end truncation strategy."""
|
||||||
|
strategy = TruncationStrategy()
|
||||||
|
content = "A" * 1000 # Long content
|
||||||
|
|
||||||
|
result = await strategy.truncate_to_tokens(
|
||||||
|
content, max_tokens=50, strategy="end"
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result.truncated is True
|
||||||
|
assert len(result.content) < len(content)
|
||||||
|
assert strategy.truncation_marker in result.content
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_truncate_middle_strategy(self) -> None:
|
||||||
|
"""Test middle truncation strategy."""
|
||||||
|
strategy = TruncationStrategy(preserve_ratio_start=0.6)
|
||||||
|
content = "START " + "A" * 500 + " END"
|
||||||
|
|
||||||
|
result = await strategy.truncate_to_tokens(
|
||||||
|
content, max_tokens=50, strategy="middle"
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result.truncated is True
|
||||||
|
assert strategy.truncation_marker in result.content
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_truncate_sentence_strategy(self) -> None:
|
||||||
|
"""Test sentence-aware truncation strategy."""
|
||||||
|
strategy = TruncationStrategy()
|
||||||
|
content = "First sentence. Second sentence. Third sentence. Fourth sentence."
|
||||||
|
|
||||||
|
result = await strategy.truncate_to_tokens(
|
||||||
|
content, max_tokens=10, strategy="sentence"
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result.truncated is True
|
||||||
|
# Should cut at sentence boundary
|
||||||
|
assert (
|
||||||
|
result.content.endswith(".") or strategy.truncation_marker in result.content
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestContextCompressor:
|
||||||
|
"""Tests for ContextCompressor."""
|
||||||
|
|
||||||
|
def test_creation(self) -> None:
|
||||||
|
"""Test compressor creation."""
|
||||||
|
compressor = ContextCompressor()
|
||||||
|
assert compressor._truncation is not None
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_compress_context_within_limit(self) -> None:
|
||||||
|
"""Test compressing context that already fits."""
|
||||||
|
compressor = ContextCompressor()
|
||||||
|
|
||||||
|
context = KnowledgeContext(
|
||||||
|
content="Short content",
|
||||||
|
source="docs",
|
||||||
|
)
|
||||||
|
context.token_count = 5
|
||||||
|
|
||||||
|
result = await compressor.compress_context(context, max_tokens=100)
|
||||||
|
|
||||||
|
# Should return same context unmodified
|
||||||
|
assert result.content == "Short content"
|
||||||
|
assert result.metadata.get("truncated") is not True
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_compress_context_exceeds_limit(self) -> None:
|
||||||
|
"""Test compressing context that exceeds limit."""
|
||||||
|
compressor = ContextCompressor()
|
||||||
|
|
||||||
|
context = KnowledgeContext(
|
||||||
|
content="A" * 500,
|
||||||
|
source="docs",
|
||||||
|
)
|
||||||
|
context.token_count = 125 # Approximately 500/4
|
||||||
|
|
||||||
|
result = await compressor.compress_context(context, max_tokens=20)
|
||||||
|
|
||||||
|
assert result.metadata.get("truncated") is True
|
||||||
|
assert result.metadata.get("original_tokens") == 125
|
||||||
|
assert len(result.content) < 500
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_compress_contexts_batch(self) -> None:
|
||||||
|
"""Test compressing multiple contexts."""
|
||||||
|
compressor = ContextCompressor()
|
||||||
|
allocator = BudgetAllocator()
|
||||||
|
budget = allocator.create_budget(1000)
|
||||||
|
|
||||||
|
contexts = [
|
||||||
|
KnowledgeContext(content="A" * 200, source="docs"),
|
||||||
|
KnowledgeContext(content="B" * 200, source="docs"),
|
||||||
|
TaskContext(content="C" * 200, source="task"),
|
||||||
|
]
|
||||||
|
|
||||||
|
result = await compressor.compress_contexts(contexts, budget)
|
||||||
|
|
||||||
|
assert len(result) == 3
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_strategy_selection_by_type(self) -> None:
|
||||||
|
"""Test that correct strategy is selected for each type."""
|
||||||
|
compressor = ContextCompressor()
|
||||||
|
|
||||||
|
assert compressor._get_strategy_for_type(ContextType.SYSTEM) == "end"
|
||||||
|
assert compressor._get_strategy_for_type(ContextType.TASK) == "end"
|
||||||
|
assert compressor._get_strategy_for_type(ContextType.KNOWLEDGE) == "sentence"
|
||||||
|
assert compressor._get_strategy_for_type(ContextType.CONVERSATION) == "end"
|
||||||
|
assert compressor._get_strategy_for_type(ContextType.TOOL) == "middle"
|
||||||
|
|
||||||
|
|
||||||
|
class TestTruncationEdgeCases:
|
||||||
|
"""Tests for edge cases in truncation to prevent regressions."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_truncation_ratio_with_zero_original_tokens(self) -> None:
|
||||||
|
"""Test that truncation ratio handles zero original tokens without division by zero."""
|
||||||
|
strategy = TruncationStrategy()
|
||||||
|
|
||||||
|
# Empty content should not raise ZeroDivisionError
|
||||||
|
result = await strategy.truncate_to_tokens("", max_tokens=100)
|
||||||
|
|
||||||
|
assert result.truncation_ratio == 0.0
|
||||||
|
assert result.original_tokens == 0
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_truncate_end_with_zero_available_tokens(self) -> None:
|
||||||
|
"""Test truncation when marker tokens exceed max_tokens."""
|
||||||
|
strategy = TruncationStrategy()
|
||||||
|
content = "Some content to truncate"
|
||||||
|
|
||||||
|
# max_tokens less than marker tokens should return just marker
|
||||||
|
result = await strategy.truncate_to_tokens(
|
||||||
|
content, max_tokens=1, strategy="end"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Should handle gracefully without crashing
|
||||||
|
assert strategy.truncation_marker in result.content or result.content == content
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_truncate_with_content_that_has_zero_tokens(self) -> None:
|
||||||
|
"""Test truncation when content estimates to zero tokens."""
|
||||||
|
strategy = TruncationStrategy()
|
||||||
|
|
||||||
|
# Very short content that might estimate to 0 tokens
|
||||||
|
result = await strategy.truncate_to_tokens("a", max_tokens=100)
|
||||||
|
|
||||||
|
# Should not raise ZeroDivisionError
|
||||||
|
assert result.content in ("a", "a" + strategy.truncation_marker)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_content_for_tokens_zero_target(self) -> None:
|
||||||
|
"""Test _get_content_for_tokens with zero target tokens."""
|
||||||
|
strategy = TruncationStrategy()
|
||||||
|
|
||||||
|
result = await strategy._get_content_for_tokens(
|
||||||
|
content="Some content",
|
||||||
|
target_tokens=0,
|
||||||
|
from_start=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result == ""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_sentence_truncation_with_no_sentences(self) -> None:
|
||||||
|
"""Test sentence truncation with content that has no sentence boundaries."""
|
||||||
|
strategy = TruncationStrategy()
|
||||||
|
content = "this is content without any sentence ending punctuation"
|
||||||
|
|
||||||
|
result = await strategy.truncate_to_tokens(
|
||||||
|
content, max_tokens=5, strategy="sentence"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Should handle gracefully
|
||||||
|
assert result is not None
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_middle_truncation_very_short_content(self) -> None:
|
||||||
|
"""Test middle truncation with content shorter than preserved portions."""
|
||||||
|
strategy = TruncationStrategy(preserve_ratio_start=0.7)
|
||||||
|
content = "ab" # Very short
|
||||||
|
|
||||||
|
result = await strategy.truncate_to_tokens(
|
||||||
|
content, max_tokens=1, strategy="middle"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Should handle gracefully without negative indices
|
||||||
|
assert result is not None
|
||||||
243
backend/tests/services/context/test_config.py
Normal file
243
backend/tests/services/context/test_config.py
Normal file
@@ -0,0 +1,243 @@
|
|||||||
|
"""Tests for context management configuration."""
|
||||||
|
|
||||||
|
import os
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from app.services.context.config import (
|
||||||
|
ContextSettings,
|
||||||
|
get_context_settings,
|
||||||
|
get_default_settings,
|
||||||
|
reset_context_settings,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestContextSettings:
|
||||||
|
"""Tests for ContextSettings."""
|
||||||
|
|
||||||
|
def test_default_values(self) -> None:
|
||||||
|
"""Test default settings values."""
|
||||||
|
settings = ContextSettings()
|
||||||
|
|
||||||
|
# Budget defaults should sum to 1.0
|
||||||
|
total = (
|
||||||
|
settings.budget_system
|
||||||
|
+ settings.budget_task
|
||||||
|
+ settings.budget_knowledge
|
||||||
|
+ settings.budget_conversation
|
||||||
|
+ settings.budget_tools
|
||||||
|
+ settings.budget_response
|
||||||
|
+ settings.budget_buffer
|
||||||
|
)
|
||||||
|
assert abs(total - 1.0) < 0.001
|
||||||
|
|
||||||
|
# Scoring weights should sum to 1.0
|
||||||
|
weights_total = (
|
||||||
|
settings.scoring_relevance_weight
|
||||||
|
+ settings.scoring_recency_weight
|
||||||
|
+ settings.scoring_priority_weight
|
||||||
|
)
|
||||||
|
assert abs(weights_total - 1.0) < 0.001
|
||||||
|
|
||||||
|
def test_budget_allocation_values(self) -> None:
|
||||||
|
"""Test specific budget allocation values."""
|
||||||
|
settings = ContextSettings()
|
||||||
|
|
||||||
|
assert settings.budget_system == 0.05
|
||||||
|
assert settings.budget_task == 0.10
|
||||||
|
assert settings.budget_knowledge == 0.40
|
||||||
|
assert settings.budget_conversation == 0.20
|
||||||
|
assert settings.budget_tools == 0.05
|
||||||
|
assert settings.budget_response == 0.15
|
||||||
|
assert settings.budget_buffer == 0.05
|
||||||
|
|
||||||
|
def test_scoring_weights(self) -> None:
|
||||||
|
"""Test scoring weights."""
|
||||||
|
settings = ContextSettings()
|
||||||
|
|
||||||
|
assert settings.scoring_relevance_weight == 0.5
|
||||||
|
assert settings.scoring_recency_weight == 0.3
|
||||||
|
assert settings.scoring_priority_weight == 0.2
|
||||||
|
|
||||||
|
def test_cache_settings(self) -> None:
|
||||||
|
"""Test cache settings."""
|
||||||
|
settings = ContextSettings()
|
||||||
|
|
||||||
|
assert settings.cache_enabled is True
|
||||||
|
assert settings.cache_ttl_seconds == 3600
|
||||||
|
assert settings.cache_prefix == "ctx"
|
||||||
|
|
||||||
|
def test_performance_settings(self) -> None:
|
||||||
|
"""Test performance settings."""
|
||||||
|
settings = ContextSettings()
|
||||||
|
|
||||||
|
assert settings.max_assembly_time_ms == 2000
|
||||||
|
assert settings.parallel_scoring is True
|
||||||
|
assert settings.max_parallel_scores == 10
|
||||||
|
|
||||||
|
def test_get_budget_allocation(self) -> None:
|
||||||
|
"""Test get_budget_allocation method."""
|
||||||
|
settings = ContextSettings()
|
||||||
|
allocation = settings.get_budget_allocation()
|
||||||
|
|
||||||
|
assert isinstance(allocation, dict)
|
||||||
|
assert "system" in allocation
|
||||||
|
assert "knowledge" in allocation
|
||||||
|
assert allocation["system"] == 0.05
|
||||||
|
assert allocation["knowledge"] == 0.40
|
||||||
|
|
||||||
|
def test_get_scoring_weights(self) -> None:
|
||||||
|
"""Test get_scoring_weights method."""
|
||||||
|
settings = ContextSettings()
|
||||||
|
weights = settings.get_scoring_weights()
|
||||||
|
|
||||||
|
assert isinstance(weights, dict)
|
||||||
|
assert "relevance" in weights
|
||||||
|
assert "recency" in weights
|
||||||
|
assert "priority" in weights
|
||||||
|
assert weights["relevance"] == 0.5
|
||||||
|
|
||||||
|
def test_to_dict(self) -> None:
|
||||||
|
"""Test to_dict method."""
|
||||||
|
settings = ContextSettings()
|
||||||
|
result = settings.to_dict()
|
||||||
|
|
||||||
|
assert "budget" in result
|
||||||
|
assert "scoring" in result
|
||||||
|
assert "compression" in result
|
||||||
|
assert "cache" in result
|
||||||
|
assert "performance" in result
|
||||||
|
assert "knowledge" in result
|
||||||
|
assert "conversation" in result
|
||||||
|
|
||||||
|
def test_budget_validation_fails_on_wrong_sum(self) -> None:
|
||||||
|
"""Test that budget validation fails when sum != 1.0."""
|
||||||
|
with pytest.raises(ValueError) as exc_info:
|
||||||
|
ContextSettings(
|
||||||
|
budget_system=0.5,
|
||||||
|
budget_task=0.5,
|
||||||
|
# Other budgets default to non-zero, so total > 1.0
|
||||||
|
)
|
||||||
|
|
||||||
|
assert "sum to 1.0" in str(exc_info.value)
|
||||||
|
|
||||||
|
def test_scoring_validation_fails_on_wrong_sum(self) -> None:
|
||||||
|
"""Test that scoring validation fails when sum != 1.0."""
|
||||||
|
with pytest.raises(ValueError) as exc_info:
|
||||||
|
ContextSettings(
|
||||||
|
scoring_relevance_weight=0.8,
|
||||||
|
scoring_recency_weight=0.8,
|
||||||
|
scoring_priority_weight=0.8,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert "sum to 1.0" in str(exc_info.value)
|
||||||
|
|
||||||
|
def test_search_type_validation(self) -> None:
|
||||||
|
"""Test search type validation."""
|
||||||
|
# Valid types should work
|
||||||
|
ContextSettings(knowledge_search_type="semantic")
|
||||||
|
ContextSettings(knowledge_search_type="keyword")
|
||||||
|
ContextSettings(knowledge_search_type="hybrid")
|
||||||
|
|
||||||
|
# Invalid type should fail
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
ContextSettings(knowledge_search_type="invalid")
|
||||||
|
|
||||||
|
def test_custom_budget_allocation(self) -> None:
|
||||||
|
"""Test custom budget allocation that sums to 1.0."""
|
||||||
|
settings = ContextSettings(
|
||||||
|
budget_system=0.10,
|
||||||
|
budget_task=0.10,
|
||||||
|
budget_knowledge=0.30,
|
||||||
|
budget_conversation=0.25,
|
||||||
|
budget_tools=0.05,
|
||||||
|
budget_response=0.15,
|
||||||
|
budget_buffer=0.05,
|
||||||
|
)
|
||||||
|
|
||||||
|
total = (
|
||||||
|
settings.budget_system
|
||||||
|
+ settings.budget_task
|
||||||
|
+ settings.budget_knowledge
|
||||||
|
+ settings.budget_conversation
|
||||||
|
+ settings.budget_tools
|
||||||
|
+ settings.budget_response
|
||||||
|
+ settings.budget_buffer
|
||||||
|
)
|
||||||
|
assert abs(total - 1.0) < 0.001
|
||||||
|
|
||||||
|
|
||||||
|
class TestSettingsSingleton:
|
||||||
|
"""Tests for settings singleton pattern."""
|
||||||
|
|
||||||
|
def setup_method(self) -> None:
|
||||||
|
"""Reset settings before each test."""
|
||||||
|
reset_context_settings()
|
||||||
|
|
||||||
|
def teardown_method(self) -> None:
|
||||||
|
"""Clean up after each test."""
|
||||||
|
reset_context_settings()
|
||||||
|
|
||||||
|
def test_get_context_settings_returns_instance(self) -> None:
|
||||||
|
"""Test that get_context_settings returns a settings instance."""
|
||||||
|
settings = get_context_settings()
|
||||||
|
assert isinstance(settings, ContextSettings)
|
||||||
|
|
||||||
|
def test_get_context_settings_returns_same_instance(self) -> None:
|
||||||
|
"""Test that get_context_settings returns the same instance."""
|
||||||
|
settings1 = get_context_settings()
|
||||||
|
settings2 = get_context_settings()
|
||||||
|
assert settings1 is settings2
|
||||||
|
|
||||||
|
def test_reset_creates_new_instance(self) -> None:
|
||||||
|
"""Test that reset creates a new instance."""
|
||||||
|
settings1 = get_context_settings()
|
||||||
|
reset_context_settings()
|
||||||
|
settings2 = get_context_settings()
|
||||||
|
|
||||||
|
# Should be different instances
|
||||||
|
assert settings1 is not settings2
|
||||||
|
|
||||||
|
def test_get_default_settings_cached(self) -> None:
|
||||||
|
"""Test that get_default_settings is cached."""
|
||||||
|
settings1 = get_default_settings()
|
||||||
|
settings2 = get_default_settings()
|
||||||
|
assert settings1 is settings2
|
||||||
|
|
||||||
|
|
||||||
|
class TestEnvironmentOverrides:
|
||||||
|
"""Tests for environment variable overrides."""
|
||||||
|
|
||||||
|
def setup_method(self) -> None:
|
||||||
|
"""Reset settings before each test."""
|
||||||
|
reset_context_settings()
|
||||||
|
|
||||||
|
def teardown_method(self) -> None:
|
||||||
|
"""Clean up after each test."""
|
||||||
|
reset_context_settings()
|
||||||
|
# Clean up any env vars we set
|
||||||
|
for key in list(os.environ.keys()):
|
||||||
|
if key.startswith("CTX_"):
|
||||||
|
del os.environ[key]
|
||||||
|
|
||||||
|
def test_env_override_cache_enabled(self) -> None:
|
||||||
|
"""Test that CTX_CACHE_ENABLED env var works."""
|
||||||
|
with patch.dict(os.environ, {"CTX_CACHE_ENABLED": "false"}):
|
||||||
|
reset_context_settings()
|
||||||
|
settings = ContextSettings()
|
||||||
|
assert settings.cache_enabled is False
|
||||||
|
|
||||||
|
def test_env_override_cache_ttl(self) -> None:
|
||||||
|
"""Test that CTX_CACHE_TTL_SECONDS env var works."""
|
||||||
|
with patch.dict(os.environ, {"CTX_CACHE_TTL_SECONDS": "7200"}):
|
||||||
|
reset_context_settings()
|
||||||
|
settings = ContextSettings()
|
||||||
|
assert settings.cache_ttl_seconds == 7200
|
||||||
|
|
||||||
|
def test_env_override_max_assembly_time(self) -> None:
|
||||||
|
"""Test that CTX_MAX_ASSEMBLY_TIME_MS env var works."""
|
||||||
|
with patch.dict(os.environ, {"CTX_MAX_ASSEMBLY_TIME_MS": "200"}):
|
||||||
|
reset_context_settings()
|
||||||
|
settings = ContextSettings()
|
||||||
|
assert settings.max_assembly_time_ms == 200
|
||||||
456
backend/tests/services/context/test_engine.py
Normal file
456
backend/tests/services/context/test_engine.py
Normal file
@@ -0,0 +1,456 @@
|
|||||||
|
"""Tests for ContextEngine."""
|
||||||
|
|
||||||
|
from unittest.mock import AsyncMock, MagicMock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from app.services.context.config import ContextSettings
|
||||||
|
from app.services.context.engine import ContextEngine, create_context_engine
|
||||||
|
from app.services.context.types import (
|
||||||
|
AssembledContext,
|
||||||
|
ConversationContext,
|
||||||
|
KnowledgeContext,
|
||||||
|
MessageRole,
|
||||||
|
ToolContext,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestContextEngineCreation:
|
||||||
|
"""Tests for ContextEngine creation."""
|
||||||
|
|
||||||
|
def test_creation_minimal(self) -> None:
|
||||||
|
"""Test creating engine with minimal config."""
|
||||||
|
engine = ContextEngine()
|
||||||
|
|
||||||
|
assert engine._mcp is None
|
||||||
|
assert engine._settings is not None
|
||||||
|
assert engine._calculator is not None
|
||||||
|
assert engine._scorer is not None
|
||||||
|
assert engine._ranker is not None
|
||||||
|
assert engine._compressor is not None
|
||||||
|
assert engine._cache is not None
|
||||||
|
assert engine._pipeline is not None
|
||||||
|
|
||||||
|
def test_creation_with_settings(self) -> None:
|
||||||
|
"""Test creating engine with custom settings."""
|
||||||
|
settings = ContextSettings(
|
||||||
|
compression_threshold=0.7,
|
||||||
|
cache_enabled=False,
|
||||||
|
)
|
||||||
|
engine = ContextEngine(settings=settings)
|
||||||
|
|
||||||
|
assert engine._settings.compression_threshold == 0.7
|
||||||
|
assert engine._settings.cache_enabled is False
|
||||||
|
|
||||||
|
def test_creation_with_redis(self) -> None:
|
||||||
|
"""Test creating engine with Redis."""
|
||||||
|
mock_redis = MagicMock()
|
||||||
|
settings = ContextSettings(cache_enabled=True)
|
||||||
|
engine = ContextEngine(redis=mock_redis, settings=settings)
|
||||||
|
|
||||||
|
assert engine._cache.is_enabled
|
||||||
|
|
||||||
|
def test_set_mcp_manager(self) -> None:
|
||||||
|
"""Test setting MCP manager."""
|
||||||
|
engine = ContextEngine()
|
||||||
|
mock_mcp = MagicMock()
|
||||||
|
|
||||||
|
engine.set_mcp_manager(mock_mcp)
|
||||||
|
|
||||||
|
assert engine._mcp is mock_mcp
|
||||||
|
|
||||||
|
def test_set_redis(self) -> None:
|
||||||
|
"""Test setting Redis connection."""
|
||||||
|
engine = ContextEngine()
|
||||||
|
mock_redis = MagicMock()
|
||||||
|
|
||||||
|
engine.set_redis(mock_redis)
|
||||||
|
|
||||||
|
assert engine._cache._redis is mock_redis
|
||||||
|
|
||||||
|
|
||||||
|
class TestContextEngineHelpers:
|
||||||
|
"""Tests for ContextEngine helper methods."""
|
||||||
|
|
||||||
|
def test_convert_conversation(self) -> None:
|
||||||
|
"""Test converting conversation history."""
|
||||||
|
engine = ContextEngine()
|
||||||
|
|
||||||
|
history = [
|
||||||
|
{"role": "user", "content": "Hello!"},
|
||||||
|
{"role": "assistant", "content": "Hi there!"},
|
||||||
|
{"role": "user", "content": "How are you?"},
|
||||||
|
]
|
||||||
|
|
||||||
|
contexts = engine._convert_conversation(history)
|
||||||
|
|
||||||
|
assert len(contexts) == 3
|
||||||
|
assert all(isinstance(c, ConversationContext) for c in contexts)
|
||||||
|
assert contexts[0].role == MessageRole.USER
|
||||||
|
assert contexts[1].role == MessageRole.ASSISTANT
|
||||||
|
assert contexts[0].content == "Hello!"
|
||||||
|
assert contexts[0].metadata["turn"] == 0
|
||||||
|
|
||||||
|
def test_convert_tool_results(self) -> None:
|
||||||
|
"""Test converting tool results."""
|
||||||
|
engine = ContextEngine()
|
||||||
|
|
||||||
|
results = [
|
||||||
|
{"tool_name": "search", "content": "Result 1", "status": "success"},
|
||||||
|
{"tool_name": "read", "result": {"file": "test.txt"}, "status": "success"},
|
||||||
|
]
|
||||||
|
|
||||||
|
contexts = engine._convert_tool_results(results)
|
||||||
|
|
||||||
|
assert len(contexts) == 2
|
||||||
|
assert all(isinstance(c, ToolContext) for c in contexts)
|
||||||
|
assert contexts[0].content == "Result 1"
|
||||||
|
assert contexts[0].metadata["tool_name"] == "search"
|
||||||
|
# Dict content should be JSON serialized
|
||||||
|
assert "file" in contexts[1].content
|
||||||
|
assert "test.txt" in contexts[1].content
|
||||||
|
|
||||||
|
|
||||||
|
class TestContextEngineAssembly:
|
||||||
|
"""Tests for context assembly."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_assemble_minimal(self) -> None:
|
||||||
|
"""Test assembling with minimal inputs."""
|
||||||
|
engine = ContextEngine()
|
||||||
|
|
||||||
|
result = await engine.assemble_context(
|
||||||
|
project_id="proj-123",
|
||||||
|
agent_id="agent-456",
|
||||||
|
query="test query",
|
||||||
|
model="claude-3-sonnet",
|
||||||
|
use_cache=False, # Disable cache for test
|
||||||
|
)
|
||||||
|
|
||||||
|
assert isinstance(result, AssembledContext)
|
||||||
|
assert result.context_count == 0 # No contexts provided
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_assemble_with_system_prompt(self) -> None:
|
||||||
|
"""Test assembling with system prompt."""
|
||||||
|
engine = ContextEngine()
|
||||||
|
|
||||||
|
result = await engine.assemble_context(
|
||||||
|
project_id="proj-123",
|
||||||
|
agent_id="agent-456",
|
||||||
|
query="test query",
|
||||||
|
model="claude-3-sonnet",
|
||||||
|
system_prompt="You are a helpful assistant.",
|
||||||
|
use_cache=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result.context_count == 1
|
||||||
|
assert "helpful assistant" in result.content
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_assemble_with_task(self) -> None:
|
||||||
|
"""Test assembling with task description."""
|
||||||
|
engine = ContextEngine()
|
||||||
|
|
||||||
|
result = await engine.assemble_context(
|
||||||
|
project_id="proj-123",
|
||||||
|
agent_id="agent-456",
|
||||||
|
query="implement feature",
|
||||||
|
model="claude-3-sonnet",
|
||||||
|
task_description="Implement user authentication",
|
||||||
|
use_cache=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result.context_count == 1
|
||||||
|
assert "authentication" in result.content
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_assemble_with_conversation(self) -> None:
|
||||||
|
"""Test assembling with conversation history."""
|
||||||
|
engine = ContextEngine()
|
||||||
|
|
||||||
|
result = await engine.assemble_context(
|
||||||
|
project_id="proj-123",
|
||||||
|
agent_id="agent-456",
|
||||||
|
query="continue",
|
||||||
|
model="claude-3-sonnet",
|
||||||
|
conversation_history=[
|
||||||
|
{"role": "user", "content": "Hello!"},
|
||||||
|
{"role": "assistant", "content": "Hi!"},
|
||||||
|
],
|
||||||
|
use_cache=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result.context_count == 2
|
||||||
|
assert "Hello" in result.content
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_assemble_with_tool_results(self) -> None:
|
||||||
|
"""Test assembling with tool results."""
|
||||||
|
engine = ContextEngine()
|
||||||
|
|
||||||
|
result = await engine.assemble_context(
|
||||||
|
project_id="proj-123",
|
||||||
|
agent_id="agent-456",
|
||||||
|
query="continue",
|
||||||
|
model="claude-3-sonnet",
|
||||||
|
tool_results=[
|
||||||
|
{"tool_name": "search", "content": "Found 5 results"},
|
||||||
|
],
|
||||||
|
use_cache=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result.context_count == 1
|
||||||
|
assert "Found 5 results" in result.content
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_assemble_with_custom_contexts(self) -> None:
|
||||||
|
"""Test assembling with custom contexts."""
|
||||||
|
engine = ContextEngine()
|
||||||
|
|
||||||
|
custom = [
|
||||||
|
KnowledgeContext(
|
||||||
|
content="Custom knowledge.",
|
||||||
|
source="custom",
|
||||||
|
relevance_score=0.9,
|
||||||
|
)
|
||||||
|
]
|
||||||
|
|
||||||
|
result = await engine.assemble_context(
|
||||||
|
project_id="proj-123",
|
||||||
|
agent_id="agent-456",
|
||||||
|
query="test",
|
||||||
|
model="claude-3-sonnet",
|
||||||
|
custom_contexts=custom,
|
||||||
|
use_cache=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result.context_count == 1
|
||||||
|
assert "Custom knowledge" in result.content
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_assemble_full_workflow(self) -> None:
|
||||||
|
"""Test full assembly workflow."""
|
||||||
|
engine = ContextEngine()
|
||||||
|
|
||||||
|
result = await engine.assemble_context(
|
||||||
|
project_id="proj-123",
|
||||||
|
agent_id="agent-456",
|
||||||
|
query="implement login",
|
||||||
|
model="claude-3-sonnet",
|
||||||
|
system_prompt="You are an expert Python developer.",
|
||||||
|
task_description="Implement user authentication.",
|
||||||
|
conversation_history=[
|
||||||
|
{"role": "user", "content": "Can you help me implement JWT auth?"},
|
||||||
|
],
|
||||||
|
tool_results=[
|
||||||
|
{"tool_name": "file_create", "content": "Created auth.py"},
|
||||||
|
],
|
||||||
|
use_cache=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result.context_count >= 4
|
||||||
|
assert result.total_tokens > 0
|
||||||
|
assert result.model == "claude-3-sonnet"
|
||||||
|
|
||||||
|
# Check for expected content
|
||||||
|
assert "expert Python developer" in result.content
|
||||||
|
assert "authentication" in result.content
|
||||||
|
|
||||||
|
|
||||||
|
class TestContextEngineKnowledge:
|
||||||
|
"""Tests for knowledge fetching."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_fetch_knowledge_no_mcp(self) -> None:
|
||||||
|
"""Test fetching knowledge without MCP returns empty."""
|
||||||
|
engine = ContextEngine()
|
||||||
|
|
||||||
|
result = await engine._fetch_knowledge(
|
||||||
|
project_id="proj-123",
|
||||||
|
agent_id="agent-456",
|
||||||
|
query="test",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result == []
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_fetch_knowledge_with_mcp(self) -> None:
|
||||||
|
"""Test fetching knowledge with MCP."""
|
||||||
|
mock_mcp = AsyncMock()
|
||||||
|
mock_mcp.call_tool.return_value.data = {
|
||||||
|
"results": [
|
||||||
|
{
|
||||||
|
"content": "Document content",
|
||||||
|
"source_path": "docs/api.md",
|
||||||
|
"score": 0.9,
|
||||||
|
"chunk_id": "chunk-1",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"content": "Another document",
|
||||||
|
"source_path": "docs/auth.md",
|
||||||
|
"score": 0.8,
|
||||||
|
},
|
||||||
|
]
|
||||||
|
}
|
||||||
|
|
||||||
|
engine = ContextEngine(mcp_manager=mock_mcp)
|
||||||
|
|
||||||
|
result = await engine._fetch_knowledge(
|
||||||
|
project_id="proj-123",
|
||||||
|
agent_id="agent-456",
|
||||||
|
query="authentication",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert len(result) == 2
|
||||||
|
assert all(isinstance(c, KnowledgeContext) for c in result)
|
||||||
|
assert result[0].content == "Document content"
|
||||||
|
assert result[0].source == "docs/api.md"
|
||||||
|
assert result[0].relevance_score == 0.9
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_fetch_knowledge_error_handling(self) -> None:
|
||||||
|
"""Test knowledge fetch error handling."""
|
||||||
|
mock_mcp = AsyncMock()
|
||||||
|
mock_mcp.call_tool.side_effect = Exception("MCP error")
|
||||||
|
|
||||||
|
engine = ContextEngine(mcp_manager=mock_mcp)
|
||||||
|
|
||||||
|
# Should not raise, returns empty
|
||||||
|
result = await engine._fetch_knowledge(
|
||||||
|
project_id="proj-123",
|
||||||
|
agent_id="agent-456",
|
||||||
|
query="test",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result == []
|
||||||
|
|
||||||
|
|
||||||
|
class TestContextEngineCaching:
|
||||||
|
"""Tests for caching behavior."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_cache_disabled(self) -> None:
|
||||||
|
"""Test assembly with cache disabled."""
|
||||||
|
engine = ContextEngine()
|
||||||
|
|
||||||
|
result = await engine.assemble_context(
|
||||||
|
project_id="proj-123",
|
||||||
|
agent_id="agent-456",
|
||||||
|
query="test",
|
||||||
|
model="claude-3-sonnet",
|
||||||
|
system_prompt="Test prompt",
|
||||||
|
use_cache=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert not result.cache_hit
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_cache_hit(self) -> None:
|
||||||
|
"""Test cache hit."""
|
||||||
|
mock_redis = AsyncMock()
|
||||||
|
settings = ContextSettings(cache_enabled=True)
|
||||||
|
engine = ContextEngine(redis=mock_redis, settings=settings)
|
||||||
|
|
||||||
|
# First call - cache miss
|
||||||
|
mock_redis.get.return_value = None
|
||||||
|
|
||||||
|
result1 = await engine.assemble_context(
|
||||||
|
project_id="proj-123",
|
||||||
|
agent_id="agent-456",
|
||||||
|
query="test",
|
||||||
|
model="claude-3-sonnet",
|
||||||
|
system_prompt="Test prompt",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Second call - mock cache hit
|
||||||
|
mock_redis.get.return_value = result1.to_json()
|
||||||
|
|
||||||
|
result2 = await engine.assemble_context(
|
||||||
|
project_id="proj-123",
|
||||||
|
agent_id="agent-456",
|
||||||
|
query="test",
|
||||||
|
model="claude-3-sonnet",
|
||||||
|
system_prompt="Test prompt",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result2.cache_hit
|
||||||
|
|
||||||
|
|
||||||
|
class TestContextEngineUtilities:
|
||||||
|
"""Tests for utility methods."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_budget_for_model(self) -> None:
|
||||||
|
"""Test getting budget for model."""
|
||||||
|
engine = ContextEngine()
|
||||||
|
|
||||||
|
budget = await engine.get_budget_for_model("claude-3-sonnet")
|
||||||
|
|
||||||
|
assert budget.total > 0
|
||||||
|
assert budget.system > 0
|
||||||
|
assert budget.knowledge > 0
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_budget_with_max_tokens(self) -> None:
|
||||||
|
"""Test getting budget with max tokens."""
|
||||||
|
engine = ContextEngine()
|
||||||
|
|
||||||
|
budget = await engine.get_budget_for_model("claude-3-sonnet", max_tokens=5000)
|
||||||
|
|
||||||
|
assert budget.total == 5000
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_count_tokens(self) -> None:
|
||||||
|
"""Test token counting."""
|
||||||
|
engine = ContextEngine()
|
||||||
|
|
||||||
|
count = await engine.count_tokens("Hello world")
|
||||||
|
|
||||||
|
assert count > 0
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_invalidate_cache(self) -> None:
|
||||||
|
"""Test cache invalidation."""
|
||||||
|
mock_redis = AsyncMock()
|
||||||
|
|
||||||
|
async def mock_scan_iter(match=None):
|
||||||
|
for key in ["ctx:1", "ctx:2"]:
|
||||||
|
yield key
|
||||||
|
|
||||||
|
mock_redis.scan_iter = mock_scan_iter
|
||||||
|
|
||||||
|
settings = ContextSettings(cache_enabled=True)
|
||||||
|
engine = ContextEngine(redis=mock_redis, settings=settings)
|
||||||
|
|
||||||
|
deleted = await engine.invalidate_cache(pattern="*test*")
|
||||||
|
|
||||||
|
assert deleted >= 0
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_stats(self) -> None:
|
||||||
|
"""Test getting engine stats."""
|
||||||
|
engine = ContextEngine()
|
||||||
|
|
||||||
|
stats = await engine.get_stats()
|
||||||
|
|
||||||
|
assert "cache" in stats
|
||||||
|
assert "settings" in stats
|
||||||
|
assert "compression_threshold" in stats["settings"]
|
||||||
|
|
||||||
|
|
||||||
|
class TestCreateContextEngine:
|
||||||
|
"""Tests for factory function."""
|
||||||
|
|
||||||
|
def test_create_context_engine(self) -> None:
|
||||||
|
"""Test factory function."""
|
||||||
|
engine = create_context_engine()
|
||||||
|
|
||||||
|
assert isinstance(engine, ContextEngine)
|
||||||
|
|
||||||
|
def test_create_context_engine_with_settings(self) -> None:
|
||||||
|
"""Test factory with settings."""
|
||||||
|
settings = ContextSettings(cache_enabled=False)
|
||||||
|
engine = create_context_engine(settings=settings)
|
||||||
|
|
||||||
|
assert engine._settings.cache_enabled is False
|
||||||
250
backend/tests/services/context/test_exceptions.py
Normal file
250
backend/tests/services/context/test_exceptions.py
Normal file
@@ -0,0 +1,250 @@
|
|||||||
|
"""Tests for context management exceptions."""
|
||||||
|
|
||||||
|
from app.services.context.exceptions import (
|
||||||
|
AssemblyTimeoutError,
|
||||||
|
BudgetExceededError,
|
||||||
|
CacheError,
|
||||||
|
CompressionError,
|
||||||
|
ContextError,
|
||||||
|
ContextNotFoundError,
|
||||||
|
FormattingError,
|
||||||
|
InvalidContextError,
|
||||||
|
ScoringError,
|
||||||
|
TokenCountError,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestContextError:
|
||||||
|
"""Tests for base ContextError."""
|
||||||
|
|
||||||
|
def test_basic_initialization(self) -> None:
|
||||||
|
"""Test basic error initialization."""
|
||||||
|
error = ContextError("Test error")
|
||||||
|
assert error.message == "Test error"
|
||||||
|
assert error.details == {}
|
||||||
|
assert str(error) == "Test error"
|
||||||
|
|
||||||
|
def test_with_details(self) -> None:
|
||||||
|
"""Test error with details."""
|
||||||
|
error = ContextError("Test error", {"key": "value", "count": 42})
|
||||||
|
assert error.details == {"key": "value", "count": 42}
|
||||||
|
|
||||||
|
def test_to_dict(self) -> None:
|
||||||
|
"""Test conversion to dictionary."""
|
||||||
|
error = ContextError("Test error", {"key": "value"})
|
||||||
|
result = error.to_dict()
|
||||||
|
|
||||||
|
assert result["error_type"] == "ContextError"
|
||||||
|
assert result["message"] == "Test error"
|
||||||
|
assert result["details"] == {"key": "value"}
|
||||||
|
|
||||||
|
def test_inheritance(self) -> None:
|
||||||
|
"""Test that ContextError inherits from Exception."""
|
||||||
|
error = ContextError("Test")
|
||||||
|
assert isinstance(error, Exception)
|
||||||
|
|
||||||
|
|
||||||
|
class TestBudgetExceededError:
|
||||||
|
"""Tests for BudgetExceededError."""
|
||||||
|
|
||||||
|
def test_default_message(self) -> None:
|
||||||
|
"""Test default error message."""
|
||||||
|
error = BudgetExceededError()
|
||||||
|
assert "exceeded" in error.message.lower()
|
||||||
|
|
||||||
|
def test_with_budget_info(self) -> None:
|
||||||
|
"""Test with budget information."""
|
||||||
|
error = BudgetExceededError(
|
||||||
|
allocated=1000,
|
||||||
|
requested=1500,
|
||||||
|
context_type="knowledge",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert error.allocated == 1000
|
||||||
|
assert error.requested == 1500
|
||||||
|
assert error.context_type == "knowledge"
|
||||||
|
assert error.details["overage"] == 500
|
||||||
|
|
||||||
|
def test_to_dict_includes_budget_info(self) -> None:
|
||||||
|
"""Test that to_dict includes budget info."""
|
||||||
|
error = BudgetExceededError(
|
||||||
|
allocated=1000,
|
||||||
|
requested=1500,
|
||||||
|
)
|
||||||
|
result = error.to_dict()
|
||||||
|
|
||||||
|
assert result["details"]["allocated"] == 1000
|
||||||
|
assert result["details"]["requested"] == 1500
|
||||||
|
assert result["details"]["overage"] == 500
|
||||||
|
|
||||||
|
|
||||||
|
class TestTokenCountError:
|
||||||
|
"""Tests for TokenCountError."""
|
||||||
|
|
||||||
|
def test_basic_error(self) -> None:
|
||||||
|
"""Test basic token count error."""
|
||||||
|
error = TokenCountError()
|
||||||
|
assert "token" in error.message.lower()
|
||||||
|
|
||||||
|
def test_with_model_info(self) -> None:
|
||||||
|
"""Test with model information."""
|
||||||
|
error = TokenCountError(
|
||||||
|
message="Failed to count",
|
||||||
|
model="claude-3-sonnet",
|
||||||
|
text_length=5000,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert error.model == "claude-3-sonnet"
|
||||||
|
assert error.text_length == 5000
|
||||||
|
assert error.details["model"] == "claude-3-sonnet"
|
||||||
|
|
||||||
|
|
||||||
|
class TestCompressionError:
|
||||||
|
"""Tests for CompressionError."""
|
||||||
|
|
||||||
|
def test_basic_error(self) -> None:
|
||||||
|
"""Test basic compression error."""
|
||||||
|
error = CompressionError()
|
||||||
|
assert "compress" in error.message.lower()
|
||||||
|
|
||||||
|
def test_with_token_info(self) -> None:
|
||||||
|
"""Test with token information."""
|
||||||
|
error = CompressionError(
|
||||||
|
original_tokens=2000,
|
||||||
|
target_tokens=1000,
|
||||||
|
achieved_tokens=1500,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert error.original_tokens == 2000
|
||||||
|
assert error.target_tokens == 1000
|
||||||
|
assert error.achieved_tokens == 1500
|
||||||
|
|
||||||
|
|
||||||
|
class TestAssemblyTimeoutError:
|
||||||
|
"""Tests for AssemblyTimeoutError."""
|
||||||
|
|
||||||
|
def test_basic_error(self) -> None:
|
||||||
|
"""Test basic timeout error."""
|
||||||
|
error = AssemblyTimeoutError()
|
||||||
|
assert "timed out" in error.message.lower()
|
||||||
|
|
||||||
|
def test_with_timing_info(self) -> None:
|
||||||
|
"""Test with timing information."""
|
||||||
|
error = AssemblyTimeoutError(
|
||||||
|
timeout_ms=100,
|
||||||
|
elapsed_ms=150.5,
|
||||||
|
stage="scoring",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert error.timeout_ms == 100
|
||||||
|
assert error.elapsed_ms == 150.5
|
||||||
|
assert error.stage == "scoring"
|
||||||
|
assert error.details["stage"] == "scoring"
|
||||||
|
|
||||||
|
|
||||||
|
class TestScoringError:
|
||||||
|
"""Tests for ScoringError."""
|
||||||
|
|
||||||
|
def test_basic_error(self) -> None:
|
||||||
|
"""Test basic scoring error."""
|
||||||
|
error = ScoringError()
|
||||||
|
assert "score" in error.message.lower()
|
||||||
|
|
||||||
|
def test_with_scorer_info(self) -> None:
|
||||||
|
"""Test with scorer information."""
|
||||||
|
error = ScoringError(
|
||||||
|
scorer_type="relevance",
|
||||||
|
context_id="ctx-123",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert error.scorer_type == "relevance"
|
||||||
|
assert error.context_id == "ctx-123"
|
||||||
|
|
||||||
|
|
||||||
|
class TestFormattingError:
|
||||||
|
"""Tests for FormattingError."""
|
||||||
|
|
||||||
|
def test_basic_error(self) -> None:
|
||||||
|
"""Test basic formatting error."""
|
||||||
|
error = FormattingError()
|
||||||
|
assert "format" in error.message.lower()
|
||||||
|
|
||||||
|
def test_with_model_info(self) -> None:
|
||||||
|
"""Test with model information."""
|
||||||
|
error = FormattingError(
|
||||||
|
model="claude-3-opus",
|
||||||
|
adapter="ClaudeAdapter",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert error.model == "claude-3-opus"
|
||||||
|
assert error.adapter == "ClaudeAdapter"
|
||||||
|
|
||||||
|
|
||||||
|
class TestCacheError:
|
||||||
|
"""Tests for CacheError."""
|
||||||
|
|
||||||
|
def test_basic_error(self) -> None:
|
||||||
|
"""Test basic cache error."""
|
||||||
|
error = CacheError()
|
||||||
|
assert "cache" in error.message.lower()
|
||||||
|
|
||||||
|
def test_with_operation_info(self) -> None:
|
||||||
|
"""Test with operation information."""
|
||||||
|
error = CacheError(
|
||||||
|
operation="get",
|
||||||
|
cache_key="ctx:abc123",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert error.operation == "get"
|
||||||
|
assert error.cache_key == "ctx:abc123"
|
||||||
|
|
||||||
|
|
||||||
|
class TestContextNotFoundError:
|
||||||
|
"""Tests for ContextNotFoundError."""
|
||||||
|
|
||||||
|
def test_basic_error(self) -> None:
|
||||||
|
"""Test basic not found error."""
|
||||||
|
error = ContextNotFoundError()
|
||||||
|
assert "not found" in error.message.lower()
|
||||||
|
|
||||||
|
def test_with_source_info(self) -> None:
|
||||||
|
"""Test with source information."""
|
||||||
|
error = ContextNotFoundError(
|
||||||
|
source="knowledge-base",
|
||||||
|
query="authentication flow",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert error.source == "knowledge-base"
|
||||||
|
assert error.query == "authentication flow"
|
||||||
|
|
||||||
|
|
||||||
|
class TestInvalidContextError:
|
||||||
|
"""Tests for InvalidContextError."""
|
||||||
|
|
||||||
|
def test_basic_error(self) -> None:
|
||||||
|
"""Test basic invalid error."""
|
||||||
|
error = InvalidContextError()
|
||||||
|
assert "invalid" in error.message.lower()
|
||||||
|
|
||||||
|
def test_with_field_info(self) -> None:
|
||||||
|
"""Test with field information."""
|
||||||
|
error = InvalidContextError(
|
||||||
|
field="content",
|
||||||
|
value="",
|
||||||
|
reason="Content cannot be empty",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert error.field == "content"
|
||||||
|
assert error.value == ""
|
||||||
|
assert error.reason == "Content cannot be empty"
|
||||||
|
|
||||||
|
def test_value_type_only_in_details(self) -> None:
|
||||||
|
"""Test that only value type is included in details (not actual value)."""
|
||||||
|
error = InvalidContextError(
|
||||||
|
field="api_key",
|
||||||
|
value="secret-key-here",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Actual value should not be in details
|
||||||
|
assert "secret-key-here" not in str(error.details)
|
||||||
|
assert error.details["value_type"] == "str"
|
||||||
499
backend/tests/services/context/test_ranker.py
Normal file
499
backend/tests/services/context/test_ranker.py
Normal file
@@ -0,0 +1,499 @@
|
|||||||
|
"""Tests for context ranking module."""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from app.services.context.budget import BudgetAllocator, TokenBudget
|
||||||
|
from app.services.context.prioritization import ContextRanker, RankingResult
|
||||||
|
from app.services.context.scoring import CompositeScorer, ScoredContext
|
||||||
|
from app.services.context.types import (
|
||||||
|
ContextPriority,
|
||||||
|
ContextType,
|
||||||
|
ConversationContext,
|
||||||
|
KnowledgeContext,
|
||||||
|
MessageRole,
|
||||||
|
SystemContext,
|
||||||
|
TaskContext,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestRankingResult:
|
||||||
|
"""Tests for RankingResult dataclass."""
|
||||||
|
|
||||||
|
def test_creation(self) -> None:
|
||||||
|
"""Test RankingResult creation."""
|
||||||
|
ctx = TaskContext(content="Test", source="task")
|
||||||
|
scored = ScoredContext(context=ctx, composite_score=0.8)
|
||||||
|
|
||||||
|
result = RankingResult(
|
||||||
|
selected=[scored],
|
||||||
|
excluded=[],
|
||||||
|
total_tokens=100,
|
||||||
|
selection_stats={"total": 1},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert len(result.selected) == 1
|
||||||
|
assert result.total_tokens == 100
|
||||||
|
|
||||||
|
def test_selected_contexts_property(self) -> None:
|
||||||
|
"""Test selected_contexts property extracts contexts."""
|
||||||
|
ctx1 = TaskContext(content="Test 1", source="task")
|
||||||
|
ctx2 = TaskContext(content="Test 2", source="task")
|
||||||
|
|
||||||
|
scored1 = ScoredContext(context=ctx1, composite_score=0.8)
|
||||||
|
scored2 = ScoredContext(context=ctx2, composite_score=0.6)
|
||||||
|
|
||||||
|
result = RankingResult(
|
||||||
|
selected=[scored1, scored2],
|
||||||
|
excluded=[],
|
||||||
|
total_tokens=200,
|
||||||
|
)
|
||||||
|
|
||||||
|
selected = result.selected_contexts
|
||||||
|
assert len(selected) == 2
|
||||||
|
assert ctx1 in selected
|
||||||
|
assert ctx2 in selected
|
||||||
|
|
||||||
|
|
||||||
|
class TestContextRanker:
|
||||||
|
"""Tests for ContextRanker."""
|
||||||
|
|
||||||
|
def test_creation(self) -> None:
|
||||||
|
"""Test ranker creation."""
|
||||||
|
ranker = ContextRanker()
|
||||||
|
assert ranker._scorer is not None
|
||||||
|
assert ranker._calculator is not None
|
||||||
|
|
||||||
|
def test_creation_with_scorer(self) -> None:
|
||||||
|
"""Test ranker creation with custom scorer."""
|
||||||
|
scorer = CompositeScorer(relevance_weight=0.8)
|
||||||
|
ranker = ContextRanker(scorer=scorer)
|
||||||
|
assert ranker._scorer is scorer
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_rank_empty_contexts(self) -> None:
|
||||||
|
"""Test ranking empty context list."""
|
||||||
|
ranker = ContextRanker()
|
||||||
|
allocator = BudgetAllocator()
|
||||||
|
budget = allocator.create_budget(10000)
|
||||||
|
|
||||||
|
result = await ranker.rank([], "query", budget)
|
||||||
|
|
||||||
|
assert len(result.selected) == 0
|
||||||
|
assert len(result.excluded) == 0
|
||||||
|
assert result.total_tokens == 0
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_rank_single_context_fits(self) -> None:
|
||||||
|
"""Test ranking single context that fits budget."""
|
||||||
|
ranker = ContextRanker()
|
||||||
|
allocator = BudgetAllocator()
|
||||||
|
budget = allocator.create_budget(10000)
|
||||||
|
|
||||||
|
context = KnowledgeContext(
|
||||||
|
content="Short content",
|
||||||
|
source="docs",
|
||||||
|
relevance_score=0.8,
|
||||||
|
)
|
||||||
|
|
||||||
|
result = await ranker.rank([context], "query", budget)
|
||||||
|
|
||||||
|
assert len(result.selected) == 1
|
||||||
|
assert len(result.excluded) == 0
|
||||||
|
assert result.selected[0].context is context
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_rank_respects_budget(self) -> None:
|
||||||
|
"""Test that ranking respects token budget."""
|
||||||
|
ranker = ContextRanker()
|
||||||
|
|
||||||
|
# Create a very small budget
|
||||||
|
budget = TokenBudget(
|
||||||
|
total=100,
|
||||||
|
knowledge=50, # Only 50 tokens for knowledge
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create contexts that exceed budget
|
||||||
|
contexts = [
|
||||||
|
KnowledgeContext(
|
||||||
|
content="A" * 200, # ~50 tokens
|
||||||
|
source="docs",
|
||||||
|
relevance_score=0.9,
|
||||||
|
),
|
||||||
|
KnowledgeContext(
|
||||||
|
content="B" * 200, # ~50 tokens
|
||||||
|
source="docs",
|
||||||
|
relevance_score=0.8,
|
||||||
|
),
|
||||||
|
KnowledgeContext(
|
||||||
|
content="C" * 200, # ~50 tokens
|
||||||
|
source="docs",
|
||||||
|
relevance_score=0.7,
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
result = await ranker.rank(contexts, "query", budget)
|
||||||
|
|
||||||
|
# Not all should fit
|
||||||
|
assert len(result.selected) < len(contexts)
|
||||||
|
assert len(result.excluded) > 0
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_rank_selects_highest_scores(self) -> None:
|
||||||
|
"""Test that ranking selects highest scored contexts."""
|
||||||
|
ranker = ContextRanker()
|
||||||
|
allocator = BudgetAllocator()
|
||||||
|
budget = allocator.create_budget(1000)
|
||||||
|
|
||||||
|
# Small budget for knowledge
|
||||||
|
budget.knowledge = 100
|
||||||
|
|
||||||
|
contexts = [
|
||||||
|
KnowledgeContext(
|
||||||
|
content="Low score",
|
||||||
|
source="docs",
|
||||||
|
relevance_score=0.2,
|
||||||
|
),
|
||||||
|
KnowledgeContext(
|
||||||
|
content="High score",
|
||||||
|
source="docs",
|
||||||
|
relevance_score=0.9,
|
||||||
|
),
|
||||||
|
KnowledgeContext(
|
||||||
|
content="Medium score",
|
||||||
|
source="docs",
|
||||||
|
relevance_score=0.5,
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
result = await ranker.rank(contexts, "query", budget)
|
||||||
|
|
||||||
|
# Should have selected some
|
||||||
|
if result.selected:
|
||||||
|
# The highest scored should be selected first
|
||||||
|
scores = [s.composite_score for s in result.selected]
|
||||||
|
assert scores == sorted(scores, reverse=True)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_rank_critical_priority_always_included(self) -> None:
|
||||||
|
"""Test that CRITICAL priority contexts are always included."""
|
||||||
|
ranker = ContextRanker()
|
||||||
|
|
||||||
|
# Very small budget
|
||||||
|
budget = TokenBudget(
|
||||||
|
total=100,
|
||||||
|
system=10, # Very small
|
||||||
|
knowledge=10,
|
||||||
|
)
|
||||||
|
|
||||||
|
contexts = [
|
||||||
|
SystemContext(
|
||||||
|
content="Critical system prompt that must be included",
|
||||||
|
source="system",
|
||||||
|
priority=ContextPriority.CRITICAL.value,
|
||||||
|
),
|
||||||
|
KnowledgeContext(
|
||||||
|
content="Optional knowledge",
|
||||||
|
source="docs",
|
||||||
|
relevance_score=0.9,
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
result = await ranker.rank(contexts, "query", budget, ensure_required=True)
|
||||||
|
|
||||||
|
# Critical context should be in selected
|
||||||
|
selected_priorities = [s.context.priority for s in result.selected]
|
||||||
|
assert ContextPriority.CRITICAL.value in selected_priorities
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_rank_without_ensure_required(self) -> None:
|
||||||
|
"""Test ranking without ensuring required contexts."""
|
||||||
|
ranker = ContextRanker()
|
||||||
|
|
||||||
|
budget = TokenBudget(
|
||||||
|
total=100,
|
||||||
|
system=50,
|
||||||
|
knowledge=50,
|
||||||
|
)
|
||||||
|
|
||||||
|
contexts = [
|
||||||
|
SystemContext(
|
||||||
|
content="A" * 500, # Large content
|
||||||
|
source="system",
|
||||||
|
priority=ContextPriority.CRITICAL.value,
|
||||||
|
),
|
||||||
|
KnowledgeContext(
|
||||||
|
content="Short",
|
||||||
|
source="docs",
|
||||||
|
relevance_score=0.9,
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
result = await ranker.rank(contexts, "query", budget, ensure_required=False)
|
||||||
|
|
||||||
|
# Without ensure_required, CRITICAL contexts can be excluded
|
||||||
|
# if budget doesn't allow
|
||||||
|
assert len(result.selected) + len(result.excluded) == len(contexts)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_rank_selection_stats(self) -> None:
|
||||||
|
"""Test that ranking provides useful statistics."""
|
||||||
|
ranker = ContextRanker()
|
||||||
|
allocator = BudgetAllocator()
|
||||||
|
budget = allocator.create_budget(10000)
|
||||||
|
|
||||||
|
contexts = [
|
||||||
|
KnowledgeContext(content="Knowledge 1", source="docs", relevance_score=0.8),
|
||||||
|
KnowledgeContext(content="Knowledge 2", source="docs", relevance_score=0.6),
|
||||||
|
TaskContext(content="Task", source="task"),
|
||||||
|
]
|
||||||
|
|
||||||
|
result = await ranker.rank(contexts, "query", budget)
|
||||||
|
|
||||||
|
stats = result.selection_stats
|
||||||
|
assert "total_contexts" in stats
|
||||||
|
assert "selected_count" in stats
|
||||||
|
assert "excluded_count" in stats
|
||||||
|
assert "total_tokens" in stats
|
||||||
|
assert "by_type" in stats
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_rank_simple(self) -> None:
|
||||||
|
"""Test simple ranking without budget per type."""
|
||||||
|
ranker = ContextRanker()
|
||||||
|
|
||||||
|
contexts = [
|
||||||
|
KnowledgeContext(
|
||||||
|
content="A",
|
||||||
|
source="docs",
|
||||||
|
relevance_score=0.9,
|
||||||
|
),
|
||||||
|
KnowledgeContext(
|
||||||
|
content="B",
|
||||||
|
source="docs",
|
||||||
|
relevance_score=0.7,
|
||||||
|
),
|
||||||
|
KnowledgeContext(
|
||||||
|
content="C",
|
||||||
|
source="docs",
|
||||||
|
relevance_score=0.5,
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
result = await ranker.rank_simple(contexts, "query", max_tokens=1000)
|
||||||
|
|
||||||
|
# Should return contexts sorted by score
|
||||||
|
assert len(result) > 0
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_rank_simple_respects_max_tokens(self) -> None:
|
||||||
|
"""Test that simple ranking respects max tokens."""
|
||||||
|
ranker = ContextRanker()
|
||||||
|
|
||||||
|
# Create contexts with known token counts
|
||||||
|
contexts = [
|
||||||
|
KnowledgeContext(
|
||||||
|
content="A" * 100, # ~25 tokens
|
||||||
|
source="docs",
|
||||||
|
relevance_score=0.9,
|
||||||
|
),
|
||||||
|
KnowledgeContext(
|
||||||
|
content="B" * 100,
|
||||||
|
source="docs",
|
||||||
|
relevance_score=0.8,
|
||||||
|
),
|
||||||
|
KnowledgeContext(
|
||||||
|
content="C" * 100,
|
||||||
|
source="docs",
|
||||||
|
relevance_score=0.7,
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
# Very small limit
|
||||||
|
result = await ranker.rank_simple(contexts, "query", max_tokens=30)
|
||||||
|
|
||||||
|
# Should only fit a limited number
|
||||||
|
assert len(result) <= len(contexts)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_rank_simple_empty(self) -> None:
|
||||||
|
"""Test simple ranking with empty list."""
|
||||||
|
ranker = ContextRanker()
|
||||||
|
|
||||||
|
result = await ranker.rank_simple([], "query", max_tokens=1000)
|
||||||
|
assert result == []
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_rerank_for_diversity(self) -> None:
|
||||||
|
"""Test diversity reranking."""
|
||||||
|
ranker = ContextRanker()
|
||||||
|
|
||||||
|
# Create scored contexts from same source
|
||||||
|
contexts = [
|
||||||
|
ScoredContext(
|
||||||
|
context=KnowledgeContext(
|
||||||
|
content=f"Content {i}",
|
||||||
|
source="same-source",
|
||||||
|
relevance_score=0.9 - i * 0.1,
|
||||||
|
),
|
||||||
|
composite_score=0.9 - i * 0.1,
|
||||||
|
)
|
||||||
|
for i in range(5)
|
||||||
|
]
|
||||||
|
|
||||||
|
# Limit to 2 per source
|
||||||
|
result = await ranker.rerank_for_diversity(contexts, max_per_source=2)
|
||||||
|
|
||||||
|
assert len(result) == 5
|
||||||
|
# First 2 should be from same source, rest deferred
|
||||||
|
first_two_sources = [r.context.source for r in result[:2]]
|
||||||
|
assert all(s == "same-source" for s in first_two_sources)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_rerank_for_diversity_multiple_sources(self) -> None:
|
||||||
|
"""Test diversity reranking with multiple sources."""
|
||||||
|
ranker = ContextRanker()
|
||||||
|
|
||||||
|
contexts = [
|
||||||
|
ScoredContext(
|
||||||
|
context=KnowledgeContext(
|
||||||
|
content="Source A - 1",
|
||||||
|
source="source-a",
|
||||||
|
relevance_score=0.9,
|
||||||
|
),
|
||||||
|
composite_score=0.9,
|
||||||
|
),
|
||||||
|
ScoredContext(
|
||||||
|
context=KnowledgeContext(
|
||||||
|
content="Source A - 2",
|
||||||
|
source="source-a",
|
||||||
|
relevance_score=0.8,
|
||||||
|
),
|
||||||
|
composite_score=0.8,
|
||||||
|
),
|
||||||
|
ScoredContext(
|
||||||
|
context=KnowledgeContext(
|
||||||
|
content="Source B - 1",
|
||||||
|
source="source-b",
|
||||||
|
relevance_score=0.7,
|
||||||
|
),
|
||||||
|
composite_score=0.7,
|
||||||
|
),
|
||||||
|
ScoredContext(
|
||||||
|
context=KnowledgeContext(
|
||||||
|
content="Source A - 3",
|
||||||
|
source="source-a",
|
||||||
|
relevance_score=0.6,
|
||||||
|
),
|
||||||
|
composite_score=0.6,
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
result = await ranker.rerank_for_diversity(contexts, max_per_source=2)
|
||||||
|
|
||||||
|
# Should not have more than 2 from source-a in first 3
|
||||||
|
source_a_in_first_3 = sum(
|
||||||
|
1 for r in result[:3] if r.context.source == "source-a"
|
||||||
|
)
|
||||||
|
assert source_a_in_first_3 <= 2
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_token_counts_set(self) -> None:
|
||||||
|
"""Test that token counts are set during ranking."""
|
||||||
|
ranker = ContextRanker()
|
||||||
|
allocator = BudgetAllocator()
|
||||||
|
budget = allocator.create_budget(10000)
|
||||||
|
|
||||||
|
context = KnowledgeContext(
|
||||||
|
content="Test content",
|
||||||
|
source="docs",
|
||||||
|
relevance_score=0.8,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Token count should be None initially
|
||||||
|
assert context.token_count is None
|
||||||
|
|
||||||
|
await ranker.rank([context], "query", budget)
|
||||||
|
|
||||||
|
# Token count should be set after ranking
|
||||||
|
assert context.token_count is not None
|
||||||
|
assert context.token_count > 0
|
||||||
|
|
||||||
|
|
||||||
|
class TestContextRankerIntegration:
|
||||||
|
"""Integration tests for context ranking."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_full_ranking_workflow(self) -> None:
|
||||||
|
"""Test complete ranking workflow."""
|
||||||
|
ranker = ContextRanker()
|
||||||
|
allocator = BudgetAllocator()
|
||||||
|
budget = allocator.create_budget(10000)
|
||||||
|
|
||||||
|
# Create diverse context types
|
||||||
|
contexts = [
|
||||||
|
SystemContext(
|
||||||
|
content="You are a helpful assistant.",
|
||||||
|
source="system",
|
||||||
|
priority=ContextPriority.CRITICAL.value,
|
||||||
|
),
|
||||||
|
TaskContext(
|
||||||
|
content="Help the user with their coding question.",
|
||||||
|
source="task",
|
||||||
|
priority=ContextPriority.HIGH.value,
|
||||||
|
),
|
||||||
|
KnowledgeContext(
|
||||||
|
content="Python is a programming language.",
|
||||||
|
source="docs/python.md",
|
||||||
|
relevance_score=0.9,
|
||||||
|
),
|
||||||
|
KnowledgeContext(
|
||||||
|
content="Java is also a programming language.",
|
||||||
|
source="docs/java.md",
|
||||||
|
relevance_score=0.4,
|
||||||
|
),
|
||||||
|
ConversationContext(
|
||||||
|
content="Hello, can you help me?",
|
||||||
|
source="chat",
|
||||||
|
role=MessageRole.USER,
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
result = await ranker.rank(contexts, "Python help", budget)
|
||||||
|
|
||||||
|
# System (CRITICAL) should be included
|
||||||
|
selected_types = [s.context.get_type() for s in result.selected]
|
||||||
|
assert ContextType.SYSTEM in selected_types
|
||||||
|
|
||||||
|
# Stats should be populated
|
||||||
|
assert result.selection_stats["total_contexts"] == 5
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_ranking_preserves_context_order_by_score(self) -> None:
|
||||||
|
"""Test that ranking orders by score correctly."""
|
||||||
|
ranker = ContextRanker()
|
||||||
|
allocator = BudgetAllocator()
|
||||||
|
budget = allocator.create_budget(100000)
|
||||||
|
|
||||||
|
contexts = [
|
||||||
|
KnowledgeContext(
|
||||||
|
content="Low",
|
||||||
|
source="docs",
|
||||||
|
relevance_score=0.2,
|
||||||
|
),
|
||||||
|
KnowledgeContext(
|
||||||
|
content="High",
|
||||||
|
source="docs",
|
||||||
|
relevance_score=0.9,
|
||||||
|
),
|
||||||
|
KnowledgeContext(
|
||||||
|
content="Medium",
|
||||||
|
source="docs",
|
||||||
|
relevance_score=0.5,
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
result = await ranker.rank(contexts, "query", budget)
|
||||||
|
|
||||||
|
# Verify ordering is by score
|
||||||
|
scores = [s.composite_score for s in result.selected]
|
||||||
|
assert scores == sorted(scores, reverse=True)
|
||||||
893
backend/tests/services/context/test_scoring.py
Normal file
893
backend/tests/services/context/test_scoring.py
Normal file
@@ -0,0 +1,893 @@
|
|||||||
|
"""Tests for context scoring module."""
|
||||||
|
|
||||||
|
from datetime import UTC, datetime, timedelta
|
||||||
|
from unittest.mock import AsyncMock, MagicMock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from app.services.context.scoring import (
|
||||||
|
CompositeScorer,
|
||||||
|
PriorityScorer,
|
||||||
|
RecencyScorer,
|
||||||
|
RelevanceScorer,
|
||||||
|
ScoredContext,
|
||||||
|
)
|
||||||
|
from app.services.context.types import (
|
||||||
|
ContextPriority,
|
||||||
|
ContextType,
|
||||||
|
ConversationContext,
|
||||||
|
KnowledgeContext,
|
||||||
|
MessageRole,
|
||||||
|
SystemContext,
|
||||||
|
TaskContext,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestRelevanceScorer:
|
||||||
|
"""Tests for RelevanceScorer."""
|
||||||
|
|
||||||
|
def test_creation(self) -> None:
|
||||||
|
"""Test scorer creation."""
|
||||||
|
scorer = RelevanceScorer()
|
||||||
|
assert scorer.weight == 1.0
|
||||||
|
|
||||||
|
def test_creation_with_weight(self) -> None:
|
||||||
|
"""Test scorer creation with custom weight."""
|
||||||
|
scorer = RelevanceScorer(weight=0.5)
|
||||||
|
assert scorer.weight == 0.5
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_score_with_precomputed_relevance(self) -> None:
|
||||||
|
"""Test scoring with pre-computed relevance score."""
|
||||||
|
scorer = RelevanceScorer()
|
||||||
|
|
||||||
|
# KnowledgeContext with pre-computed score
|
||||||
|
context = KnowledgeContext(
|
||||||
|
content="Test content about Python",
|
||||||
|
source="docs/python.md",
|
||||||
|
relevance_score=0.85,
|
||||||
|
)
|
||||||
|
|
||||||
|
score = await scorer.score(context, "Python programming")
|
||||||
|
assert score == 0.85
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_score_with_metadata_score(self) -> None:
|
||||||
|
"""Test scoring with metadata-provided score."""
|
||||||
|
scorer = RelevanceScorer()
|
||||||
|
|
||||||
|
context = SystemContext(
|
||||||
|
content="System prompt",
|
||||||
|
source="system",
|
||||||
|
metadata={"relevance_score": 0.9},
|
||||||
|
)
|
||||||
|
|
||||||
|
score = await scorer.score(context, "anything")
|
||||||
|
assert score == 0.9
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_score_fallback_to_keyword_matching(self) -> None:
|
||||||
|
"""Test fallback to keyword matching when no score available."""
|
||||||
|
scorer = RelevanceScorer(keyword_fallback_weight=0.5)
|
||||||
|
|
||||||
|
context = TaskContext(
|
||||||
|
content="Implement authentication with JWT tokens",
|
||||||
|
source="task",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Query has matching keywords
|
||||||
|
score = await scorer.score(context, "JWT authentication")
|
||||||
|
assert score > 0
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_keyword_matching_no_overlap(self) -> None:
|
||||||
|
"""Test keyword matching with no query overlap."""
|
||||||
|
scorer = RelevanceScorer()
|
||||||
|
|
||||||
|
context = TaskContext(
|
||||||
|
content="Implement database migration",
|
||||||
|
source="task",
|
||||||
|
)
|
||||||
|
|
||||||
|
score = await scorer.score(context, "xyz abc 123")
|
||||||
|
assert score == 0.0
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_keyword_matching_full_overlap(self) -> None:
|
||||||
|
"""Test keyword matching with high overlap."""
|
||||||
|
scorer = RelevanceScorer(keyword_fallback_weight=1.0)
|
||||||
|
|
||||||
|
context = TaskContext(
|
||||||
|
content="python programming language",
|
||||||
|
source="task",
|
||||||
|
)
|
||||||
|
|
||||||
|
score = await scorer.score(context, "python programming")
|
||||||
|
# Should have high score due to keyword overlap
|
||||||
|
assert score > 0.5
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_score_with_mcp_success(self) -> None:
|
||||||
|
"""Test scoring with MCP semantic similarity."""
|
||||||
|
mock_mcp = MagicMock()
|
||||||
|
mock_result = MagicMock()
|
||||||
|
mock_result.success = True
|
||||||
|
mock_result.data = {"similarity": 0.75}
|
||||||
|
mock_mcp.call_tool = AsyncMock(return_value=mock_result)
|
||||||
|
|
||||||
|
scorer = RelevanceScorer(mcp_manager=mock_mcp)
|
||||||
|
|
||||||
|
context = TaskContext(
|
||||||
|
content="Test content",
|
||||||
|
source="task",
|
||||||
|
)
|
||||||
|
|
||||||
|
score = await scorer.score(context, "test query")
|
||||||
|
assert score == 0.75
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_score_with_mcp_failure_fallback(self) -> None:
|
||||||
|
"""Test fallback when MCP fails."""
|
||||||
|
mock_mcp = MagicMock()
|
||||||
|
mock_mcp.call_tool = AsyncMock(side_effect=Exception("Connection failed"))
|
||||||
|
|
||||||
|
scorer = RelevanceScorer(mcp_manager=mock_mcp, keyword_fallback_weight=0.5)
|
||||||
|
|
||||||
|
context = TaskContext(
|
||||||
|
content="Python programming code",
|
||||||
|
source="task",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Should fall back to keyword matching
|
||||||
|
score = await scorer.score(context, "Python code")
|
||||||
|
assert score > 0
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_score_batch(self) -> None:
|
||||||
|
"""Test batch scoring."""
|
||||||
|
scorer = RelevanceScorer()
|
||||||
|
|
||||||
|
contexts = [
|
||||||
|
KnowledgeContext(content="Python", source="1", relevance_score=0.8),
|
||||||
|
KnowledgeContext(content="Java", source="2", relevance_score=0.6),
|
||||||
|
KnowledgeContext(content="Go", source="3", relevance_score=0.9),
|
||||||
|
]
|
||||||
|
|
||||||
|
scores = await scorer.score_batch(contexts, "test")
|
||||||
|
assert len(scores) == 3
|
||||||
|
assert scores[0] == 0.8
|
||||||
|
assert scores[1] == 0.6
|
||||||
|
assert scores[2] == 0.9
|
||||||
|
|
||||||
|
def test_set_mcp_manager(self) -> None:
|
||||||
|
"""Test setting MCP manager."""
|
||||||
|
scorer = RelevanceScorer()
|
||||||
|
assert scorer._mcp is None
|
||||||
|
|
||||||
|
mock_mcp = MagicMock()
|
||||||
|
scorer.set_mcp_manager(mock_mcp)
|
||||||
|
assert scorer._mcp is mock_mcp
|
||||||
|
|
||||||
|
|
||||||
|
class TestRecencyScorer:
|
||||||
|
"""Tests for RecencyScorer."""
|
||||||
|
|
||||||
|
def test_creation(self) -> None:
|
||||||
|
"""Test scorer creation."""
|
||||||
|
scorer = RecencyScorer()
|
||||||
|
assert scorer.weight == 1.0
|
||||||
|
assert scorer._half_life_hours == 24.0
|
||||||
|
|
||||||
|
def test_creation_with_custom_half_life(self) -> None:
|
||||||
|
"""Test scorer creation with custom half-life."""
|
||||||
|
scorer = RecencyScorer(half_life_hours=12.0)
|
||||||
|
assert scorer._half_life_hours == 12.0
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_score_recent_context(self) -> None:
|
||||||
|
"""Test scoring a very recent context."""
|
||||||
|
scorer = RecencyScorer(half_life_hours=24.0)
|
||||||
|
now = datetime.now(UTC)
|
||||||
|
|
||||||
|
context = TaskContext(
|
||||||
|
content="Recent task",
|
||||||
|
source="task",
|
||||||
|
timestamp=now,
|
||||||
|
)
|
||||||
|
|
||||||
|
score = await scorer.score(context, "query", reference_time=now)
|
||||||
|
# Very recent should have score near 1.0
|
||||||
|
assert score > 0.99
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_score_at_half_life(self) -> None:
|
||||||
|
"""Test scoring at exactly half-life age."""
|
||||||
|
scorer = RecencyScorer(half_life_hours=24.0)
|
||||||
|
now = datetime.now(UTC)
|
||||||
|
half_life_ago = now - timedelta(hours=24)
|
||||||
|
|
||||||
|
context = TaskContext(
|
||||||
|
content="Day old task",
|
||||||
|
source="task",
|
||||||
|
timestamp=half_life_ago,
|
||||||
|
)
|
||||||
|
|
||||||
|
score = await scorer.score(context, "query", reference_time=now)
|
||||||
|
# At half-life, score should be ~0.5
|
||||||
|
assert 0.49 <= score <= 0.51
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_score_old_context(self) -> None:
|
||||||
|
"""Test scoring a very old context."""
|
||||||
|
scorer = RecencyScorer(half_life_hours=24.0)
|
||||||
|
now = datetime.now(UTC)
|
||||||
|
week_ago = now - timedelta(days=7)
|
||||||
|
|
||||||
|
context = TaskContext(
|
||||||
|
content="Week old task",
|
||||||
|
source="task",
|
||||||
|
timestamp=week_ago,
|
||||||
|
)
|
||||||
|
|
||||||
|
score = await scorer.score(context, "query", reference_time=now)
|
||||||
|
# 7 days with 24h half-life = very low score
|
||||||
|
assert score < 0.01
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_type_specific_half_lives(self) -> None:
|
||||||
|
"""Test that different context types have different half-lives."""
|
||||||
|
scorer = RecencyScorer()
|
||||||
|
now = datetime.now(UTC)
|
||||||
|
one_hour_ago = now - timedelta(hours=1)
|
||||||
|
|
||||||
|
# Conversation has 1 hour half-life by default
|
||||||
|
conv_context = ConversationContext(
|
||||||
|
content="Hello",
|
||||||
|
source="chat",
|
||||||
|
role=MessageRole.USER,
|
||||||
|
timestamp=one_hour_ago,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Knowledge has 168 hour (1 week) half-life by default
|
||||||
|
knowledge_context = KnowledgeContext(
|
||||||
|
content="Documentation",
|
||||||
|
source="docs",
|
||||||
|
timestamp=one_hour_ago,
|
||||||
|
)
|
||||||
|
|
||||||
|
conv_score = await scorer.score(conv_context, "query", reference_time=now)
|
||||||
|
knowledge_score = await scorer.score(
|
||||||
|
knowledge_context, "query", reference_time=now
|
||||||
|
)
|
||||||
|
|
||||||
|
# Conversation should decay much faster
|
||||||
|
assert conv_score < knowledge_score
|
||||||
|
|
||||||
|
def test_get_half_life(self) -> None:
|
||||||
|
"""Test getting half-life for context type."""
|
||||||
|
scorer = RecencyScorer()
|
||||||
|
|
||||||
|
assert scorer.get_half_life(ContextType.CONVERSATION) == 1.0
|
||||||
|
assert scorer.get_half_life(ContextType.KNOWLEDGE) == 168.0
|
||||||
|
assert scorer.get_half_life(ContextType.SYSTEM) == 720.0
|
||||||
|
|
||||||
|
def test_set_half_life(self) -> None:
|
||||||
|
"""Test setting custom half-life."""
|
||||||
|
scorer = RecencyScorer()
|
||||||
|
|
||||||
|
scorer.set_half_life(ContextType.TASK, 48.0)
|
||||||
|
assert scorer.get_half_life(ContextType.TASK) == 48.0
|
||||||
|
|
||||||
|
def test_set_half_life_invalid(self) -> None:
|
||||||
|
"""Test setting invalid half-life."""
|
||||||
|
scorer = RecencyScorer()
|
||||||
|
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
scorer.set_half_life(ContextType.TASK, 0)
|
||||||
|
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
scorer.set_half_life(ContextType.TASK, -1)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_score_batch(self) -> None:
|
||||||
|
"""Test batch scoring."""
|
||||||
|
scorer = RecencyScorer()
|
||||||
|
now = datetime.now(UTC)
|
||||||
|
|
||||||
|
contexts = [
|
||||||
|
TaskContext(content="1", source="t", timestamp=now),
|
||||||
|
TaskContext(content="2", source="t", timestamp=now - timedelta(hours=24)),
|
||||||
|
TaskContext(content="3", source="t", timestamp=now - timedelta(hours=48)),
|
||||||
|
]
|
||||||
|
|
||||||
|
scores = await scorer.score_batch(contexts, "query", reference_time=now)
|
||||||
|
assert len(scores) == 3
|
||||||
|
# Scores should be in descending order (more recent = higher)
|
||||||
|
assert scores[0] > scores[1] > scores[2]
|
||||||
|
|
||||||
|
|
||||||
|
class TestPriorityScorer:
|
||||||
|
"""Tests for PriorityScorer."""
|
||||||
|
|
||||||
|
def test_creation(self) -> None:
|
||||||
|
"""Test scorer creation."""
|
||||||
|
scorer = PriorityScorer()
|
||||||
|
assert scorer.weight == 1.0
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_score_critical_priority(self) -> None:
|
||||||
|
"""Test scoring CRITICAL priority context."""
|
||||||
|
scorer = PriorityScorer()
|
||||||
|
|
||||||
|
context = SystemContext(
|
||||||
|
content="Critical system prompt",
|
||||||
|
source="system",
|
||||||
|
priority=ContextPriority.CRITICAL.value,
|
||||||
|
)
|
||||||
|
|
||||||
|
score = await scorer.score(context, "query")
|
||||||
|
# CRITICAL (100) + type bonus should be > 1.0, normalized to 1.0
|
||||||
|
assert score == 1.0
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_score_normal_priority(self) -> None:
|
||||||
|
"""Test scoring NORMAL priority context."""
|
||||||
|
scorer = PriorityScorer()
|
||||||
|
|
||||||
|
context = TaskContext(
|
||||||
|
content="Normal task",
|
||||||
|
source="task",
|
||||||
|
priority=ContextPriority.NORMAL.value,
|
||||||
|
)
|
||||||
|
|
||||||
|
score = await scorer.score(context, "query")
|
||||||
|
# NORMAL (50) = 0.5, plus TASK bonus (0.15) = 0.65
|
||||||
|
assert 0.6 <= score <= 0.7
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_score_low_priority(self) -> None:
|
||||||
|
"""Test scoring LOW priority context."""
|
||||||
|
scorer = PriorityScorer()
|
||||||
|
|
||||||
|
context = KnowledgeContext(
|
||||||
|
content="Low priority knowledge",
|
||||||
|
source="docs",
|
||||||
|
priority=ContextPriority.LOW.value,
|
||||||
|
)
|
||||||
|
|
||||||
|
score = await scorer.score(context, "query")
|
||||||
|
# LOW (20) = 0.2, no bonus for KNOWLEDGE
|
||||||
|
assert 0.15 <= score <= 0.25
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_type_bonuses(self) -> None:
|
||||||
|
"""Test type-specific priority bonuses."""
|
||||||
|
scorer = PriorityScorer()
|
||||||
|
|
||||||
|
# All with same base priority
|
||||||
|
system_ctx = SystemContext(
|
||||||
|
content="System",
|
||||||
|
source="system",
|
||||||
|
priority=50,
|
||||||
|
)
|
||||||
|
task_ctx = TaskContext(
|
||||||
|
content="Task",
|
||||||
|
source="task",
|
||||||
|
priority=50,
|
||||||
|
)
|
||||||
|
knowledge_ctx = KnowledgeContext(
|
||||||
|
content="Knowledge",
|
||||||
|
source="docs",
|
||||||
|
priority=50,
|
||||||
|
)
|
||||||
|
|
||||||
|
system_score = await scorer.score(system_ctx, "query")
|
||||||
|
task_score = await scorer.score(task_ctx, "query")
|
||||||
|
knowledge_score = await scorer.score(knowledge_ctx, "query")
|
||||||
|
|
||||||
|
# System has highest bonus (0.2), task next (0.15), knowledge has none
|
||||||
|
assert system_score > task_score > knowledge_score
|
||||||
|
|
||||||
|
def test_get_type_bonus(self) -> None:
|
||||||
|
"""Test getting type bonus."""
|
||||||
|
scorer = PriorityScorer()
|
||||||
|
|
||||||
|
assert scorer.get_type_bonus(ContextType.SYSTEM) == 0.2
|
||||||
|
assert scorer.get_type_bonus(ContextType.TASK) == 0.15
|
||||||
|
assert scorer.get_type_bonus(ContextType.KNOWLEDGE) == 0.0
|
||||||
|
|
||||||
|
def test_set_type_bonus(self) -> None:
|
||||||
|
"""Test setting custom type bonus."""
|
||||||
|
scorer = PriorityScorer()
|
||||||
|
|
||||||
|
scorer.set_type_bonus(ContextType.KNOWLEDGE, 0.1)
|
||||||
|
assert scorer.get_type_bonus(ContextType.KNOWLEDGE) == 0.1
|
||||||
|
|
||||||
|
def test_set_type_bonus_invalid(self) -> None:
|
||||||
|
"""Test setting invalid type bonus."""
|
||||||
|
scorer = PriorityScorer()
|
||||||
|
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
scorer.set_type_bonus(ContextType.KNOWLEDGE, 1.5)
|
||||||
|
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
scorer.set_type_bonus(ContextType.KNOWLEDGE, -0.1)
|
||||||
|
|
||||||
|
|
||||||
|
class TestCompositeScorer:
|
||||||
|
"""Tests for CompositeScorer."""
|
||||||
|
|
||||||
|
def test_creation(self) -> None:
|
||||||
|
"""Test scorer creation with default weights."""
|
||||||
|
scorer = CompositeScorer()
|
||||||
|
|
||||||
|
weights = scorer.weights
|
||||||
|
assert weights["relevance"] == 0.5
|
||||||
|
assert weights["recency"] == 0.3
|
||||||
|
assert weights["priority"] == 0.2
|
||||||
|
|
||||||
|
def test_creation_with_custom_weights(self) -> None:
|
||||||
|
"""Test scorer creation with custom weights."""
|
||||||
|
scorer = CompositeScorer(
|
||||||
|
relevance_weight=0.6,
|
||||||
|
recency_weight=0.2,
|
||||||
|
priority_weight=0.2,
|
||||||
|
)
|
||||||
|
|
||||||
|
weights = scorer.weights
|
||||||
|
assert weights["relevance"] == 0.6
|
||||||
|
assert weights["recency"] == 0.2
|
||||||
|
assert weights["priority"] == 0.2
|
||||||
|
|
||||||
|
def test_update_weights(self) -> None:
|
||||||
|
"""Test updating weights."""
|
||||||
|
scorer = CompositeScorer()
|
||||||
|
|
||||||
|
scorer.update_weights(relevance=0.7, recency=0.2, priority=0.1)
|
||||||
|
|
||||||
|
weights = scorer.weights
|
||||||
|
assert weights["relevance"] == 0.7
|
||||||
|
assert weights["recency"] == 0.2
|
||||||
|
assert weights["priority"] == 0.1
|
||||||
|
|
||||||
|
def test_update_weights_partial(self) -> None:
|
||||||
|
"""Test partially updating weights."""
|
||||||
|
scorer = CompositeScorer()
|
||||||
|
original_recency = scorer.weights["recency"]
|
||||||
|
|
||||||
|
scorer.update_weights(relevance=0.8)
|
||||||
|
|
||||||
|
assert scorer.weights["relevance"] == 0.8
|
||||||
|
assert scorer.weights["recency"] == original_recency
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_score_basic(self) -> None:
|
||||||
|
"""Test basic composite scoring."""
|
||||||
|
scorer = CompositeScorer()
|
||||||
|
|
||||||
|
context = KnowledgeContext(
|
||||||
|
content="Test content",
|
||||||
|
source="docs",
|
||||||
|
relevance_score=0.8,
|
||||||
|
timestamp=datetime.now(UTC),
|
||||||
|
priority=ContextPriority.NORMAL.value,
|
||||||
|
)
|
||||||
|
|
||||||
|
score = await scorer.score(context, "test query")
|
||||||
|
assert 0.0 <= score <= 1.0
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_score_with_details(self) -> None:
|
||||||
|
"""Test scoring with detailed breakdown."""
|
||||||
|
scorer = CompositeScorer()
|
||||||
|
|
||||||
|
context = KnowledgeContext(
|
||||||
|
content="Test content",
|
||||||
|
source="docs",
|
||||||
|
relevance_score=0.8,
|
||||||
|
timestamp=datetime.now(UTC),
|
||||||
|
priority=ContextPriority.HIGH.value,
|
||||||
|
)
|
||||||
|
|
||||||
|
scored = await scorer.score_with_details(context, "test query")
|
||||||
|
|
||||||
|
assert isinstance(scored, ScoredContext)
|
||||||
|
assert scored.context is context
|
||||||
|
assert 0.0 <= scored.composite_score <= 1.0
|
||||||
|
assert scored.relevance_score == 0.8
|
||||||
|
assert scored.recency_score > 0.9 # Very recent
|
||||||
|
assert scored.priority_score > 0.5 # HIGH priority
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_score_not_cached_on_context(self) -> None:
|
||||||
|
"""Test that scores are NOT cached on the context.
|
||||||
|
|
||||||
|
Scores should not be cached on the context because they are query-dependent.
|
||||||
|
Different queries would get incorrect cached scores if we cached on the context.
|
||||||
|
"""
|
||||||
|
scorer = CompositeScorer()
|
||||||
|
|
||||||
|
context = KnowledgeContext(
|
||||||
|
content="Test",
|
||||||
|
source="docs",
|
||||||
|
relevance_score=0.5,
|
||||||
|
)
|
||||||
|
|
||||||
|
# After scoring, context._score should remain None
|
||||||
|
# (we don't cache on context because scores are query-dependent)
|
||||||
|
await scorer.score(context, "query")
|
||||||
|
# The scorer should compute fresh scores each time
|
||||||
|
# rather than caching on the context object
|
||||||
|
|
||||||
|
# Score again with different query - should compute fresh score
|
||||||
|
score1 = await scorer.score(context, "query 1")
|
||||||
|
score2 = await scorer.score(context, "query 2")
|
||||||
|
# Both should be valid scores (not necessarily equal since queries differ)
|
||||||
|
assert 0.0 <= score1 <= 1.0
|
||||||
|
assert 0.0 <= score2 <= 1.0
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_score_batch(self) -> None:
|
||||||
|
"""Test batch scoring."""
|
||||||
|
scorer = CompositeScorer()
|
||||||
|
|
||||||
|
contexts = [
|
||||||
|
KnowledgeContext(
|
||||||
|
content="High relevance",
|
||||||
|
source="docs",
|
||||||
|
relevance_score=0.9,
|
||||||
|
),
|
||||||
|
KnowledgeContext(
|
||||||
|
content="Low relevance",
|
||||||
|
source="docs",
|
||||||
|
relevance_score=0.2,
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
scored = await scorer.score_batch(contexts, "query")
|
||||||
|
assert len(scored) == 2
|
||||||
|
assert scored[0].relevance_score > scored[1].relevance_score
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_rank(self) -> None:
|
||||||
|
"""Test ranking contexts."""
|
||||||
|
scorer = CompositeScorer()
|
||||||
|
|
||||||
|
contexts = [
|
||||||
|
KnowledgeContext(content="Low", source="docs", relevance_score=0.2),
|
||||||
|
KnowledgeContext(content="High", source="docs", relevance_score=0.9),
|
||||||
|
KnowledgeContext(content="Medium", source="docs", relevance_score=0.5),
|
||||||
|
]
|
||||||
|
|
||||||
|
ranked = await scorer.rank(contexts, "query")
|
||||||
|
|
||||||
|
# Should be sorted by score (highest first)
|
||||||
|
assert len(ranked) == 3
|
||||||
|
assert ranked[0].relevance_score == 0.9
|
||||||
|
assert ranked[1].relevance_score == 0.5
|
||||||
|
assert ranked[2].relevance_score == 0.2
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_rank_with_limit(self) -> None:
|
||||||
|
"""Test ranking with limit."""
|
||||||
|
scorer = CompositeScorer()
|
||||||
|
|
||||||
|
contexts = [
|
||||||
|
KnowledgeContext(content=str(i), source="docs", relevance_score=i / 10)
|
||||||
|
for i in range(10)
|
||||||
|
]
|
||||||
|
|
||||||
|
ranked = await scorer.rank(contexts, "query", limit=3)
|
||||||
|
assert len(ranked) == 3
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_rank_with_min_score(self) -> None:
|
||||||
|
"""Test ranking with minimum score threshold."""
|
||||||
|
scorer = CompositeScorer()
|
||||||
|
|
||||||
|
contexts = [
|
||||||
|
KnowledgeContext(content="Low", source="docs", relevance_score=0.1),
|
||||||
|
KnowledgeContext(content="High", source="docs", relevance_score=0.9),
|
||||||
|
]
|
||||||
|
|
||||||
|
ranked = await scorer.rank(contexts, "query", min_score=0.5)
|
||||||
|
|
||||||
|
# Only the high relevance context should pass the threshold
|
||||||
|
assert len(ranked) <= 2 # Could be 1 if min_score filters
|
||||||
|
|
||||||
|
def test_set_mcp_manager(self) -> None:
|
||||||
|
"""Test setting MCP manager."""
|
||||||
|
scorer = CompositeScorer()
|
||||||
|
mock_mcp = MagicMock()
|
||||||
|
|
||||||
|
scorer.set_mcp_manager(mock_mcp)
|
||||||
|
assert scorer._relevance_scorer._mcp is mock_mcp
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_concurrent_scoring_same_context_no_race(self) -> None:
|
||||||
|
"""Test that concurrent scoring of the same context doesn't cause race conditions.
|
||||||
|
|
||||||
|
This verifies that the per-context locking mechanism prevents the same context
|
||||||
|
from being scored multiple times when scored concurrently.
|
||||||
|
"""
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
# Use scorer with recency_weight=0 to eliminate time-dependent variation
|
||||||
|
# (recency scores change as time passes between calls)
|
||||||
|
scorer = CompositeScorer(
|
||||||
|
relevance_weight=0.5,
|
||||||
|
recency_weight=0.0, # Disable recency to get deterministic results
|
||||||
|
priority_weight=0.5,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create a single context that will be scored multiple times concurrently
|
||||||
|
context = KnowledgeContext(
|
||||||
|
content="Test content for race condition test",
|
||||||
|
source="docs",
|
||||||
|
relevance_score=0.75,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Score the same context many times in parallel
|
||||||
|
num_concurrent = 50
|
||||||
|
tasks = [scorer.score(context, "test query") for _ in range(num_concurrent)]
|
||||||
|
scores = await asyncio.gather(*tasks)
|
||||||
|
|
||||||
|
# All scores should be identical (deterministic scoring without recency)
|
||||||
|
assert all(s == scores[0] for s in scores)
|
||||||
|
# Note: We don't cache _score on context because scores are query-dependent
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_concurrent_scoring_different_contexts(self) -> None:
|
||||||
|
"""Test that concurrent scoring of different contexts works correctly.
|
||||||
|
|
||||||
|
Different contexts should not interfere with each other during parallel scoring.
|
||||||
|
"""
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
scorer = CompositeScorer()
|
||||||
|
|
||||||
|
# Create many different contexts
|
||||||
|
contexts = [
|
||||||
|
KnowledgeContext(
|
||||||
|
content=f"Test content {i}",
|
||||||
|
source="docs",
|
||||||
|
relevance_score=i / 10,
|
||||||
|
)
|
||||||
|
for i in range(10)
|
||||||
|
]
|
||||||
|
|
||||||
|
# Score all contexts concurrently
|
||||||
|
tasks = [scorer.score(ctx, "test query") for ctx in contexts]
|
||||||
|
scores = await asyncio.gather(*tasks)
|
||||||
|
|
||||||
|
# Each context should have a different score based on its relevance
|
||||||
|
assert len(set(scores)) > 1 # Not all the same
|
||||||
|
# Note: We don't cache _score on context because scores are query-dependent
|
||||||
|
|
||||||
|
|
||||||
|
class TestScoredContext:
|
||||||
|
"""Tests for ScoredContext dataclass."""
|
||||||
|
|
||||||
|
def test_creation(self) -> None:
|
||||||
|
"""Test ScoredContext creation."""
|
||||||
|
context = TaskContext(content="Test", source="task")
|
||||||
|
scored = ScoredContext(
|
||||||
|
context=context,
|
||||||
|
composite_score=0.75,
|
||||||
|
relevance_score=0.8,
|
||||||
|
recency_score=0.7,
|
||||||
|
priority_score=0.5,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert scored.context is context
|
||||||
|
assert scored.composite_score == 0.75
|
||||||
|
|
||||||
|
def test_comparison_operators(self) -> None:
|
||||||
|
"""Test comparison operators for sorting."""
|
||||||
|
ctx1 = TaskContext(content="1", source="task")
|
||||||
|
ctx2 = TaskContext(content="2", source="task")
|
||||||
|
|
||||||
|
scored1 = ScoredContext(context=ctx1, composite_score=0.5)
|
||||||
|
scored2 = ScoredContext(context=ctx2, composite_score=0.8)
|
||||||
|
|
||||||
|
assert scored1 < scored2
|
||||||
|
assert scored2 > scored1
|
||||||
|
|
||||||
|
def test_sorting(self) -> None:
|
||||||
|
"""Test sorting scored contexts."""
|
||||||
|
contexts = [
|
||||||
|
ScoredContext(
|
||||||
|
context=TaskContext(content="Low", source="task"),
|
||||||
|
composite_score=0.3,
|
||||||
|
),
|
||||||
|
ScoredContext(
|
||||||
|
context=TaskContext(content="High", source="task"),
|
||||||
|
composite_score=0.9,
|
||||||
|
),
|
||||||
|
ScoredContext(
|
||||||
|
context=TaskContext(content="Medium", source="task"),
|
||||||
|
composite_score=0.6,
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
sorted_contexts = sorted(contexts, reverse=True)
|
||||||
|
|
||||||
|
assert sorted_contexts[0].composite_score == 0.9
|
||||||
|
assert sorted_contexts[1].composite_score == 0.6
|
||||||
|
assert sorted_contexts[2].composite_score == 0.3
|
||||||
|
|
||||||
|
|
||||||
|
class TestBaseScorer:
|
||||||
|
"""Tests for BaseScorer abstract class."""
|
||||||
|
|
||||||
|
def test_weight_property(self) -> None:
|
||||||
|
"""Test weight property."""
|
||||||
|
# Use a concrete implementation
|
||||||
|
scorer = RelevanceScorer(weight=0.7)
|
||||||
|
assert scorer.weight == 0.7
|
||||||
|
|
||||||
|
def test_weight_setter_valid(self) -> None:
|
||||||
|
"""Test weight setter with valid values."""
|
||||||
|
scorer = RelevanceScorer()
|
||||||
|
scorer.weight = 0.5
|
||||||
|
assert scorer.weight == 0.5
|
||||||
|
|
||||||
|
def test_weight_setter_invalid(self) -> None:
|
||||||
|
"""Test weight setter with invalid values."""
|
||||||
|
scorer = RelevanceScorer()
|
||||||
|
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
scorer.weight = -0.1
|
||||||
|
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
scorer.weight = 1.5
|
||||||
|
|
||||||
|
def test_normalize_score(self) -> None:
|
||||||
|
"""Test score normalization."""
|
||||||
|
scorer = RelevanceScorer()
|
||||||
|
|
||||||
|
# Normal range
|
||||||
|
assert scorer.normalize_score(0.5) == 0.5
|
||||||
|
|
||||||
|
# Below 0
|
||||||
|
assert scorer.normalize_score(-0.5) == 0.0
|
||||||
|
|
||||||
|
# Above 1
|
||||||
|
assert scorer.normalize_score(1.5) == 1.0
|
||||||
|
|
||||||
|
# Boundaries
|
||||||
|
assert scorer.normalize_score(0.0) == 0.0
|
||||||
|
assert scorer.normalize_score(1.0) == 1.0
|
||||||
|
|
||||||
|
|
||||||
|
class TestCompositeScorerEdgeCases:
|
||||||
|
"""Tests for CompositeScorer edge cases and lock management."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_score_with_zero_weights(self) -> None:
|
||||||
|
"""Test scoring when all weights are zero."""
|
||||||
|
scorer = CompositeScorer(
|
||||||
|
relevance_weight=0.0,
|
||||||
|
recency_weight=0.0,
|
||||||
|
priority_weight=0.0,
|
||||||
|
)
|
||||||
|
|
||||||
|
context = KnowledgeContext(
|
||||||
|
content="Test content",
|
||||||
|
source="docs",
|
||||||
|
relevance_score=0.8,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Should return 0.0 when total weight is 0
|
||||||
|
score = await scorer.score(context, "test query")
|
||||||
|
assert score == 0.0
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_score_batch_sequential(self) -> None:
|
||||||
|
"""Test batch scoring in sequential mode (parallel=False)."""
|
||||||
|
scorer = CompositeScorer()
|
||||||
|
|
||||||
|
contexts = [
|
||||||
|
KnowledgeContext(
|
||||||
|
content="Content 1",
|
||||||
|
source="docs",
|
||||||
|
relevance_score=0.8,
|
||||||
|
),
|
||||||
|
KnowledgeContext(
|
||||||
|
content="Content 2",
|
||||||
|
source="docs",
|
||||||
|
relevance_score=0.5,
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
# Use parallel=False to cover the sequential path
|
||||||
|
scored = await scorer.score_batch(contexts, "query", parallel=False)
|
||||||
|
|
||||||
|
assert len(scored) == 2
|
||||||
|
assert scored[0].relevance_score == 0.8
|
||||||
|
assert scored[1].relevance_score == 0.5
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_lock_fast_path_reuse(self) -> None:
|
||||||
|
"""Test that existing locks are reused via fast path."""
|
||||||
|
scorer = CompositeScorer()
|
||||||
|
|
||||||
|
context = KnowledgeContext(
|
||||||
|
content="Test",
|
||||||
|
source="docs",
|
||||||
|
relevance_score=0.5,
|
||||||
|
)
|
||||||
|
|
||||||
|
# First access creates the lock
|
||||||
|
lock1 = await scorer._get_context_lock(context.id)
|
||||||
|
|
||||||
|
# Second access should hit the fast path (lock exists in dict)
|
||||||
|
lock2 = await scorer._get_context_lock(context.id)
|
||||||
|
|
||||||
|
assert lock2 is lock1 # Same lock object returned
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_lock_cleanup_when_limit_reached(self) -> None:
|
||||||
|
"""Test that old locks are cleaned up when limit is reached."""
|
||||||
|
import time
|
||||||
|
|
||||||
|
# Create scorer with very low max_locks to trigger cleanup
|
||||||
|
scorer = CompositeScorer()
|
||||||
|
scorer._max_locks = 3
|
||||||
|
scorer._lock_ttl = 0.1 # 100ms TTL
|
||||||
|
|
||||||
|
# Create locks for several context IDs
|
||||||
|
context_ids = [f"ctx-{i}" for i in range(5)]
|
||||||
|
|
||||||
|
# Get locks for first 3 contexts (fill up to limit)
|
||||||
|
for ctx_id in context_ids[:3]:
|
||||||
|
await scorer._get_context_lock(ctx_id)
|
||||||
|
|
||||||
|
# Wait for TTL to expire
|
||||||
|
time.sleep(0.15)
|
||||||
|
|
||||||
|
# Getting a lock for a new context should trigger cleanup
|
||||||
|
await scorer._get_context_lock(context_ids[3])
|
||||||
|
|
||||||
|
# Some old locks should have been cleaned up
|
||||||
|
# The exact number depends on cleanup logic
|
||||||
|
assert len(scorer._context_locks) <= scorer._max_locks + 1
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_lock_cleanup_preserves_held_locks(self) -> None:
|
||||||
|
"""Test that cleanup doesn't remove locks that are currently held."""
|
||||||
|
import time
|
||||||
|
|
||||||
|
scorer = CompositeScorer()
|
||||||
|
scorer._max_locks = 2
|
||||||
|
scorer._lock_ttl = 0.05 # 50ms TTL
|
||||||
|
|
||||||
|
# Get and hold lock1
|
||||||
|
lock1 = await scorer._get_context_lock("ctx-1")
|
||||||
|
async with lock1:
|
||||||
|
# While holding lock1, add more locks
|
||||||
|
await scorer._get_context_lock("ctx-2")
|
||||||
|
time.sleep(0.1) # Let TTL expire
|
||||||
|
# Adding another should trigger cleanup
|
||||||
|
await scorer._get_context_lock("ctx-3")
|
||||||
|
|
||||||
|
# lock1 should still exist (it's held)
|
||||||
|
assert any(lock is lock1 for lock, _ in scorer._context_locks.values())
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_concurrent_lock_acquisition_double_check(self) -> None:
|
||||||
|
"""Test that concurrent lock acquisition uses double-check pattern."""
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
scorer = CompositeScorer()
|
||||||
|
|
||||||
|
context_id = "test-context-id"
|
||||||
|
|
||||||
|
# Simulate concurrent lock acquisition
|
||||||
|
async def get_lock():
|
||||||
|
return await scorer._get_context_lock(context_id)
|
||||||
|
|
||||||
|
locks = await asyncio.gather(*[get_lock() for _ in range(10)])
|
||||||
|
|
||||||
|
# All should get the same lock (double-check pattern ensures this)
|
||||||
|
assert all(lock is locks[0] for lock in locks)
|
||||||
570
backend/tests/services/context/test_types.py
Normal file
570
backend/tests/services/context/test_types.py
Normal file
@@ -0,0 +1,570 @@
|
|||||||
|
"""Tests for context types."""
|
||||||
|
|
||||||
|
from datetime import UTC, datetime, timedelta
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from app.services.context.types import (
|
||||||
|
AssembledContext,
|
||||||
|
ContextPriority,
|
||||||
|
ContextType,
|
||||||
|
ConversationContext,
|
||||||
|
KnowledgeContext,
|
||||||
|
MessageRole,
|
||||||
|
SystemContext,
|
||||||
|
TaskContext,
|
||||||
|
TaskStatus,
|
||||||
|
ToolContext,
|
||||||
|
ToolResultStatus,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestContextType:
|
||||||
|
"""Tests for ContextType enum."""
|
||||||
|
|
||||||
|
def test_all_types_exist(self) -> None:
|
||||||
|
"""Test that all expected context types exist."""
|
||||||
|
assert ContextType.SYSTEM
|
||||||
|
assert ContextType.TASK
|
||||||
|
assert ContextType.KNOWLEDGE
|
||||||
|
assert ContextType.CONVERSATION
|
||||||
|
assert ContextType.TOOL
|
||||||
|
|
||||||
|
def test_from_string_valid(self) -> None:
|
||||||
|
"""Test from_string with valid values."""
|
||||||
|
assert ContextType.from_string("system") == ContextType.SYSTEM
|
||||||
|
assert ContextType.from_string("KNOWLEDGE") == ContextType.KNOWLEDGE
|
||||||
|
assert ContextType.from_string("Task") == ContextType.TASK
|
||||||
|
|
||||||
|
def test_from_string_invalid(self) -> None:
|
||||||
|
"""Test from_string with invalid value."""
|
||||||
|
with pytest.raises(ValueError) as exc_info:
|
||||||
|
ContextType.from_string("invalid")
|
||||||
|
assert "Invalid context type" in str(exc_info.value)
|
||||||
|
|
||||||
|
|
||||||
|
class TestContextPriority:
|
||||||
|
"""Tests for ContextPriority enum."""
|
||||||
|
|
||||||
|
def test_priority_ordering(self) -> None:
|
||||||
|
"""Test that priorities are ordered correctly."""
|
||||||
|
assert ContextPriority.LOWEST.value < ContextPriority.LOW.value
|
||||||
|
assert ContextPriority.LOW.value < ContextPriority.NORMAL.value
|
||||||
|
assert ContextPriority.NORMAL.value < ContextPriority.HIGH.value
|
||||||
|
assert ContextPriority.HIGH.value < ContextPriority.HIGHEST.value
|
||||||
|
assert ContextPriority.HIGHEST.value < ContextPriority.CRITICAL.value
|
||||||
|
|
||||||
|
def test_from_int(self) -> None:
|
||||||
|
"""Test from_int conversion."""
|
||||||
|
assert ContextPriority.from_int(0) == ContextPriority.LOWEST
|
||||||
|
assert ContextPriority.from_int(50) == ContextPriority.NORMAL
|
||||||
|
assert ContextPriority.from_int(100) == ContextPriority.HIGHEST
|
||||||
|
assert ContextPriority.from_int(200) == ContextPriority.CRITICAL
|
||||||
|
|
||||||
|
def test_from_int_intermediate(self) -> None:
|
||||||
|
"""Test from_int with intermediate values."""
|
||||||
|
# Should return closest lower priority
|
||||||
|
assert ContextPriority.from_int(30) == ContextPriority.LOW
|
||||||
|
assert ContextPriority.from_int(60) == ContextPriority.NORMAL
|
||||||
|
|
||||||
|
|
||||||
|
class TestSystemContext:
|
||||||
|
"""Tests for SystemContext."""
|
||||||
|
|
||||||
|
def test_creation(self) -> None:
|
||||||
|
"""Test basic creation."""
|
||||||
|
ctx = SystemContext(
|
||||||
|
content="You are a helpful assistant.",
|
||||||
|
source="system_prompt",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert ctx.content == "You are a helpful assistant."
|
||||||
|
assert ctx.source == "system_prompt"
|
||||||
|
assert ctx.get_type() == ContextType.SYSTEM
|
||||||
|
|
||||||
|
def test_default_high_priority(self) -> None:
|
||||||
|
"""Test that system context defaults to high priority."""
|
||||||
|
ctx = SystemContext(content="Test", source="test")
|
||||||
|
assert ctx.priority == ContextPriority.HIGH.value
|
||||||
|
|
||||||
|
def test_create_persona(self) -> None:
|
||||||
|
"""Test create_persona factory method."""
|
||||||
|
ctx = SystemContext.create_persona(
|
||||||
|
name="Code Assistant",
|
||||||
|
description="A helpful coding assistant.",
|
||||||
|
capabilities=["Write code", "Debug"],
|
||||||
|
constraints=["Never expose secrets"],
|
||||||
|
)
|
||||||
|
|
||||||
|
assert "Code Assistant" in ctx.content
|
||||||
|
assert "helpful coding assistant" in ctx.content
|
||||||
|
assert "Write code" in ctx.content
|
||||||
|
assert "Never expose secrets" in ctx.content
|
||||||
|
assert ctx.priority == ContextPriority.HIGHEST.value
|
||||||
|
|
||||||
|
def test_create_instructions(self) -> None:
|
||||||
|
"""Test create_instructions factory method."""
|
||||||
|
ctx = SystemContext.create_instructions(
|
||||||
|
["Always be helpful", "Be concise"],
|
||||||
|
source="rules",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert "Always be helpful" in ctx.content
|
||||||
|
assert "Be concise" in ctx.content
|
||||||
|
|
||||||
|
def test_to_dict(self) -> None:
|
||||||
|
"""Test serialization to dict."""
|
||||||
|
ctx = SystemContext(
|
||||||
|
content="Test",
|
||||||
|
source="test",
|
||||||
|
role="assistant",
|
||||||
|
instructions_type="general",
|
||||||
|
)
|
||||||
|
|
||||||
|
data = ctx.to_dict()
|
||||||
|
assert data["role"] == "assistant"
|
||||||
|
assert data["instructions_type"] == "general"
|
||||||
|
assert data["type"] == "system"
|
||||||
|
|
||||||
|
|
||||||
|
class TestKnowledgeContext:
|
||||||
|
"""Tests for KnowledgeContext."""
|
||||||
|
|
||||||
|
def test_creation(self) -> None:
|
||||||
|
"""Test basic creation."""
|
||||||
|
ctx = KnowledgeContext(
|
||||||
|
content="def authenticate(user): ...",
|
||||||
|
source="/src/auth.py",
|
||||||
|
collection="code",
|
||||||
|
file_type="python",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert ctx.content == "def authenticate(user): ..."
|
||||||
|
assert ctx.source == "/src/auth.py"
|
||||||
|
assert ctx.collection == "code"
|
||||||
|
assert ctx.get_type() == ContextType.KNOWLEDGE
|
||||||
|
|
||||||
|
def test_from_search_result(self) -> None:
|
||||||
|
"""Test from_search_result factory method."""
|
||||||
|
result = {
|
||||||
|
"content": "Test content",
|
||||||
|
"source_path": "/test/file.py",
|
||||||
|
"collection": "code",
|
||||||
|
"file_type": "python",
|
||||||
|
"chunk_index": 2,
|
||||||
|
"score": 0.85,
|
||||||
|
"id": "chunk-123",
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx = KnowledgeContext.from_search_result(result, "test query")
|
||||||
|
|
||||||
|
assert ctx.content == "Test content"
|
||||||
|
assert ctx.source == "/test/file.py"
|
||||||
|
assert ctx.relevance_score == 0.85
|
||||||
|
assert ctx.search_query == "test query"
|
||||||
|
|
||||||
|
def test_from_search_results(self) -> None:
|
||||||
|
"""Test from_search_results factory method."""
|
||||||
|
results = [
|
||||||
|
{"content": "Content 1", "source_path": "/a.py", "score": 0.9},
|
||||||
|
{"content": "Content 2", "source_path": "/b.py", "score": 0.8},
|
||||||
|
]
|
||||||
|
|
||||||
|
contexts = KnowledgeContext.from_search_results(results, "query")
|
||||||
|
|
||||||
|
assert len(contexts) == 2
|
||||||
|
assert contexts[0].relevance_score == 0.9
|
||||||
|
assert contexts[1].source == "/b.py"
|
||||||
|
|
||||||
|
def test_is_code(self) -> None:
|
||||||
|
"""Test is_code method."""
|
||||||
|
code_ctx = KnowledgeContext(content="code", source="test", file_type="python")
|
||||||
|
doc_ctx = KnowledgeContext(content="docs", source="test", file_type="markdown")
|
||||||
|
|
||||||
|
assert code_ctx.is_code() is True
|
||||||
|
assert doc_ctx.is_code() is False
|
||||||
|
|
||||||
|
def test_is_documentation(self) -> None:
|
||||||
|
"""Test is_documentation method."""
|
||||||
|
doc_ctx = KnowledgeContext(content="docs", source="test", file_type="markdown")
|
||||||
|
code_ctx = KnowledgeContext(content="code", source="test", file_type="python")
|
||||||
|
|
||||||
|
assert doc_ctx.is_documentation() is True
|
||||||
|
assert code_ctx.is_documentation() is False
|
||||||
|
|
||||||
|
|
||||||
|
class TestConversationContext:
|
||||||
|
"""Tests for ConversationContext."""
|
||||||
|
|
||||||
|
def test_creation(self) -> None:
|
||||||
|
"""Test basic creation."""
|
||||||
|
ctx = ConversationContext(
|
||||||
|
content="Hello, how can I help?",
|
||||||
|
source="conversation",
|
||||||
|
role=MessageRole.ASSISTANT,
|
||||||
|
turn_index=1,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert ctx.content == "Hello, how can I help?"
|
||||||
|
assert ctx.role == MessageRole.ASSISTANT
|
||||||
|
assert ctx.get_type() == ContextType.CONVERSATION
|
||||||
|
|
||||||
|
def test_from_message(self) -> None:
|
||||||
|
"""Test from_message factory method."""
|
||||||
|
ctx = ConversationContext.from_message(
|
||||||
|
content="What is Python?",
|
||||||
|
role="user",
|
||||||
|
turn_index=0,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert ctx.content == "What is Python?"
|
||||||
|
assert ctx.role == MessageRole.USER
|
||||||
|
assert ctx.turn_index == 0
|
||||||
|
|
||||||
|
def test_from_history(self) -> None:
|
||||||
|
"""Test from_history factory method."""
|
||||||
|
messages = [
|
||||||
|
{"role": "user", "content": "Hello"},
|
||||||
|
{"role": "assistant", "content": "Hi there!"},
|
||||||
|
{"role": "user", "content": "Help me"},
|
||||||
|
]
|
||||||
|
|
||||||
|
contexts = ConversationContext.from_history(messages)
|
||||||
|
|
||||||
|
assert len(contexts) == 3
|
||||||
|
assert contexts[0].role == MessageRole.USER
|
||||||
|
assert contexts[1].role == MessageRole.ASSISTANT
|
||||||
|
assert contexts[2].turn_index == 2
|
||||||
|
|
||||||
|
def test_is_user_message(self) -> None:
|
||||||
|
"""Test is_user_message method."""
|
||||||
|
user_ctx = ConversationContext(
|
||||||
|
content="test", source="test", role=MessageRole.USER
|
||||||
|
)
|
||||||
|
assistant_ctx = ConversationContext(
|
||||||
|
content="test", source="test", role=MessageRole.ASSISTANT
|
||||||
|
)
|
||||||
|
|
||||||
|
assert user_ctx.is_user_message() is True
|
||||||
|
assert assistant_ctx.is_user_message() is False
|
||||||
|
|
||||||
|
def test_format_for_prompt(self) -> None:
|
||||||
|
"""Test format_for_prompt method."""
|
||||||
|
ctx = ConversationContext.from_message(
|
||||||
|
content="What is 2+2?",
|
||||||
|
role="user",
|
||||||
|
)
|
||||||
|
|
||||||
|
formatted = ctx.format_for_prompt()
|
||||||
|
assert "User:" in formatted
|
||||||
|
assert "What is 2+2?" in formatted
|
||||||
|
|
||||||
|
|
||||||
|
class TestTaskContext:
|
||||||
|
"""Tests for TaskContext."""
|
||||||
|
|
||||||
|
def test_creation(self) -> None:
|
||||||
|
"""Test basic creation."""
|
||||||
|
ctx = TaskContext(
|
||||||
|
content="Implement login feature",
|
||||||
|
source="task",
|
||||||
|
title="Login Feature",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert ctx.content == "Implement login feature"
|
||||||
|
assert ctx.title == "Login Feature"
|
||||||
|
assert ctx.get_type() == ContextType.TASK
|
||||||
|
|
||||||
|
def test_default_normal_priority(self) -> None:
|
||||||
|
"""Test that task context uses NORMAL priority from base class."""
|
||||||
|
ctx = TaskContext(content="Test", source="test")
|
||||||
|
# TaskContext inherits NORMAL priority from BaseContext
|
||||||
|
# Use TaskContext.create() for default HIGH priority behavior
|
||||||
|
assert ctx.priority == ContextPriority.NORMAL.value
|
||||||
|
|
||||||
|
def test_explicit_high_priority(self) -> None:
|
||||||
|
"""Test setting explicit HIGH priority."""
|
||||||
|
ctx = TaskContext(
|
||||||
|
content="Test",
|
||||||
|
source="test",
|
||||||
|
priority=ContextPriority.HIGH.value,
|
||||||
|
)
|
||||||
|
assert ctx.priority == ContextPriority.HIGH.value
|
||||||
|
|
||||||
|
def test_create_factory(self) -> None:
|
||||||
|
"""Test create factory method."""
|
||||||
|
ctx = TaskContext.create(
|
||||||
|
title="Add Auth",
|
||||||
|
description="Implement authentication",
|
||||||
|
acceptance_criteria=["Tests pass", "Code reviewed"],
|
||||||
|
constraints=["Use JWT"],
|
||||||
|
issue_id="123",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert ctx.title == "Add Auth"
|
||||||
|
assert ctx.content == "Implement authentication"
|
||||||
|
assert len(ctx.acceptance_criteria) == 2
|
||||||
|
assert "Use JWT" in ctx.constraints
|
||||||
|
assert ctx.status == TaskStatus.IN_PROGRESS
|
||||||
|
|
||||||
|
def test_format_for_prompt(self) -> None:
|
||||||
|
"""Test format_for_prompt method."""
|
||||||
|
ctx = TaskContext.create(
|
||||||
|
title="Test Task",
|
||||||
|
description="Do something",
|
||||||
|
acceptance_criteria=["Works correctly"],
|
||||||
|
)
|
||||||
|
|
||||||
|
formatted = ctx.format_for_prompt()
|
||||||
|
assert "Task: Test Task" in formatted
|
||||||
|
assert "Do something" in formatted
|
||||||
|
assert "Works correctly" in formatted
|
||||||
|
|
||||||
|
def test_status_checks(self) -> None:
|
||||||
|
"""Test status check methods."""
|
||||||
|
pending = TaskContext(content="test", source="test", status=TaskStatus.PENDING)
|
||||||
|
completed = TaskContext(
|
||||||
|
content="test", source="test", status=TaskStatus.COMPLETED
|
||||||
|
)
|
||||||
|
blocked = TaskContext(content="test", source="test", status=TaskStatus.BLOCKED)
|
||||||
|
|
||||||
|
assert pending.is_active() is True
|
||||||
|
assert completed.is_complete() is True
|
||||||
|
assert blocked.is_blocked() is True
|
||||||
|
|
||||||
|
|
||||||
|
class TestToolContext:
|
||||||
|
"""Tests for ToolContext."""
|
||||||
|
|
||||||
|
def test_creation(self) -> None:
|
||||||
|
"""Test basic creation."""
|
||||||
|
ctx = ToolContext(
|
||||||
|
content="Tool result here",
|
||||||
|
source="tool:search",
|
||||||
|
tool_name="search",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert ctx.tool_name == "search"
|
||||||
|
assert ctx.get_type() == ContextType.TOOL
|
||||||
|
|
||||||
|
def test_from_tool_definition(self) -> None:
|
||||||
|
"""Test from_tool_definition factory method."""
|
||||||
|
ctx = ToolContext.from_tool_definition(
|
||||||
|
name="search_knowledge",
|
||||||
|
description="Search the knowledge base",
|
||||||
|
parameters={
|
||||||
|
"query": {"type": "string", "required": True},
|
||||||
|
"limit": {"type": "integer", "required": False},
|
||||||
|
},
|
||||||
|
server_name="knowledge-base",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert ctx.tool_name == "search_knowledge"
|
||||||
|
assert "Search the knowledge base" in ctx.content
|
||||||
|
assert ctx.is_result is False
|
||||||
|
assert ctx.server_name == "knowledge-base"
|
||||||
|
|
||||||
|
def test_from_tool_result(self) -> None:
|
||||||
|
"""Test from_tool_result factory method."""
|
||||||
|
ctx = ToolContext.from_tool_result(
|
||||||
|
tool_name="search",
|
||||||
|
result={"found": 5, "items": ["a", "b"]},
|
||||||
|
status=ToolResultStatus.SUCCESS,
|
||||||
|
execution_time_ms=150.5,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert ctx.tool_name == "search"
|
||||||
|
assert ctx.is_result is True
|
||||||
|
assert ctx.result_status == ToolResultStatus.SUCCESS
|
||||||
|
assert "found" in ctx.content
|
||||||
|
|
||||||
|
def test_is_successful(self) -> None:
|
||||||
|
"""Test is_successful method."""
|
||||||
|
success = ToolContext.from_tool_result("test", "ok", ToolResultStatus.SUCCESS)
|
||||||
|
error = ToolContext.from_tool_result("test", "error", ToolResultStatus.ERROR)
|
||||||
|
|
||||||
|
assert success.is_successful() is True
|
||||||
|
assert error.is_successful() is False
|
||||||
|
|
||||||
|
def test_format_for_prompt(self) -> None:
|
||||||
|
"""Test format_for_prompt method."""
|
||||||
|
ctx = ToolContext.from_tool_result(
|
||||||
|
"search",
|
||||||
|
"Found 3 results",
|
||||||
|
ToolResultStatus.SUCCESS,
|
||||||
|
)
|
||||||
|
|
||||||
|
formatted = ctx.format_for_prompt()
|
||||||
|
assert "Tool Result" in formatted
|
||||||
|
assert "search" in formatted
|
||||||
|
assert "success" in formatted
|
||||||
|
|
||||||
|
|
||||||
|
class TestAssembledContext:
|
||||||
|
"""Tests for AssembledContext."""
|
||||||
|
|
||||||
|
def test_creation(self) -> None:
|
||||||
|
"""Test basic creation."""
|
||||||
|
ctx = AssembledContext(
|
||||||
|
content="Assembled content here",
|
||||||
|
total_tokens=500,
|
||||||
|
context_count=5,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert ctx.content == "Assembled content here"
|
||||||
|
assert ctx.total_tokens == 500
|
||||||
|
assert ctx.context_count == 5
|
||||||
|
# Test backward compatibility aliases
|
||||||
|
assert ctx.token_count == 500
|
||||||
|
assert ctx.contexts_included == 5
|
||||||
|
|
||||||
|
def test_budget_utilization(self) -> None:
|
||||||
|
"""Test budget_utilization property."""
|
||||||
|
ctx = AssembledContext(
|
||||||
|
content="test",
|
||||||
|
total_tokens=800,
|
||||||
|
context_count=5,
|
||||||
|
budget_total=1000,
|
||||||
|
budget_used=800,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert ctx.budget_utilization == 0.8
|
||||||
|
|
||||||
|
def test_budget_utilization_zero_budget(self) -> None:
|
||||||
|
"""Test budget_utilization with zero budget."""
|
||||||
|
ctx = AssembledContext(
|
||||||
|
content="test",
|
||||||
|
total_tokens=0,
|
||||||
|
context_count=0,
|
||||||
|
budget_total=0,
|
||||||
|
budget_used=0,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert ctx.budget_utilization == 0.0
|
||||||
|
|
||||||
|
def test_to_dict(self) -> None:
|
||||||
|
"""Test to_dict method."""
|
||||||
|
ctx = AssembledContext(
|
||||||
|
content="test",
|
||||||
|
total_tokens=100,
|
||||||
|
context_count=2,
|
||||||
|
assembly_time_ms=50.123,
|
||||||
|
)
|
||||||
|
|
||||||
|
data = ctx.to_dict()
|
||||||
|
assert data["content"] == "test"
|
||||||
|
assert data["total_tokens"] == 100
|
||||||
|
assert data["context_count"] == 2
|
||||||
|
assert data["assembly_time_ms"] == 50.12 # Rounded
|
||||||
|
|
||||||
|
def test_to_json_and_from_json(self) -> None:
|
||||||
|
"""Test JSON serialization round-trip."""
|
||||||
|
original = AssembledContext(
|
||||||
|
content="Test content",
|
||||||
|
total_tokens=100,
|
||||||
|
context_count=3,
|
||||||
|
excluded_count=2,
|
||||||
|
assembly_time_ms=45.5,
|
||||||
|
model="claude-3-sonnet",
|
||||||
|
budget_total=1000,
|
||||||
|
budget_used=100,
|
||||||
|
by_type={"system": 20, "knowledge": 80},
|
||||||
|
cache_hit=True,
|
||||||
|
cache_key="abc123",
|
||||||
|
)
|
||||||
|
|
||||||
|
json_str = original.to_json()
|
||||||
|
restored = AssembledContext.from_json(json_str)
|
||||||
|
|
||||||
|
assert restored.content == original.content
|
||||||
|
assert restored.total_tokens == original.total_tokens
|
||||||
|
assert restored.context_count == original.context_count
|
||||||
|
assert restored.excluded_count == original.excluded_count
|
||||||
|
assert restored.model == original.model
|
||||||
|
assert restored.cache_hit == original.cache_hit
|
||||||
|
assert restored.cache_key == original.cache_key
|
||||||
|
|
||||||
|
|
||||||
|
class TestBaseContextMethods:
|
||||||
|
"""Tests for BaseContext methods."""
|
||||||
|
|
||||||
|
def test_get_age_seconds(self) -> None:
|
||||||
|
"""Test get_age_seconds method."""
|
||||||
|
old_time = datetime.now(UTC) - timedelta(hours=2)
|
||||||
|
ctx = SystemContext(content="test", source="test", timestamp=old_time)
|
||||||
|
|
||||||
|
age = ctx.get_age_seconds()
|
||||||
|
# Should be approximately 2 hours in seconds
|
||||||
|
assert 7100 < age < 7300
|
||||||
|
|
||||||
|
def test_get_age_hours(self) -> None:
|
||||||
|
"""Test get_age_hours method."""
|
||||||
|
old_time = datetime.now(UTC) - timedelta(hours=5)
|
||||||
|
ctx = SystemContext(content="test", source="test", timestamp=old_time)
|
||||||
|
|
||||||
|
age = ctx.get_age_hours()
|
||||||
|
assert 4.9 < age < 5.1
|
||||||
|
|
||||||
|
def test_is_stale(self) -> None:
|
||||||
|
"""Test is_stale method."""
|
||||||
|
old_time = datetime.now(UTC) - timedelta(days=10)
|
||||||
|
new_time = datetime.now(UTC) - timedelta(hours=1)
|
||||||
|
|
||||||
|
old_ctx = SystemContext(content="test", source="test", timestamp=old_time)
|
||||||
|
new_ctx = SystemContext(content="test", source="test", timestamp=new_time)
|
||||||
|
|
||||||
|
# Default max_age is 168 hours (7 days)
|
||||||
|
assert old_ctx.is_stale() is True
|
||||||
|
assert new_ctx.is_stale() is False
|
||||||
|
|
||||||
|
def test_token_count_property(self) -> None:
|
||||||
|
"""Test token_count property."""
|
||||||
|
ctx = SystemContext(content="test", source="test")
|
||||||
|
|
||||||
|
# Initially None
|
||||||
|
assert ctx.token_count is None
|
||||||
|
|
||||||
|
# Can be set
|
||||||
|
ctx.token_count = 100
|
||||||
|
assert ctx.token_count == 100
|
||||||
|
|
||||||
|
def test_score_property_clamping(self) -> None:
|
||||||
|
"""Test that score is clamped to 0.0-1.0."""
|
||||||
|
ctx = SystemContext(content="test", source="test")
|
||||||
|
|
||||||
|
ctx.score = 1.5
|
||||||
|
assert ctx.score == 1.0
|
||||||
|
|
||||||
|
ctx.score = -0.5
|
||||||
|
assert ctx.score == 0.0
|
||||||
|
|
||||||
|
ctx.score = 0.75
|
||||||
|
assert ctx.score == 0.75
|
||||||
|
|
||||||
|
def test_hash_and_equality(self) -> None:
|
||||||
|
"""Test hash and equality based on ID."""
|
||||||
|
ctx1 = SystemContext(content="test", source="test")
|
||||||
|
ctx2 = SystemContext(content="test", source="test")
|
||||||
|
ctx3 = SystemContext(content="test", source="test")
|
||||||
|
ctx3.id = ctx1.id # Same ID as ctx1
|
||||||
|
|
||||||
|
# Different IDs = not equal
|
||||||
|
assert ctx1 != ctx2
|
||||||
|
|
||||||
|
# Same ID = equal
|
||||||
|
assert ctx1 == ctx3
|
||||||
|
|
||||||
|
# Can be used in sets
|
||||||
|
context_set = {ctx1, ctx2, ctx3}
|
||||||
|
assert len(context_set) == 2 # ctx1 and ctx3 are same
|
||||||
|
|
||||||
|
def test_truncate(self) -> None:
|
||||||
|
"""Test truncate method."""
|
||||||
|
long_content = "word " * 1000 # Long content
|
||||||
|
ctx = SystemContext(content=long_content, source="test")
|
||||||
|
ctx.token_count = 1000
|
||||||
|
|
||||||
|
truncated = ctx.truncate(100)
|
||||||
|
|
||||||
|
assert len(truncated) < len(long_content)
|
||||||
|
assert "[truncated]" in truncated
|
||||||
989
backend/tests/services/safety/test_audit.py
Normal file
989
backend/tests/services/safety/test_audit.py
Normal file
@@ -0,0 +1,989 @@
|
|||||||
|
"""
|
||||||
|
Tests for Audit Logger.
|
||||||
|
|
||||||
|
Tests cover:
|
||||||
|
- AuditLogger initialization and lifecycle
|
||||||
|
- Event logging and hash chain
|
||||||
|
- Query and filtering
|
||||||
|
- Retention policy enforcement
|
||||||
|
- Handler management
|
||||||
|
- Singleton pattern
|
||||||
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
from datetime import datetime, timedelta
|
||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import pytest_asyncio
|
||||||
|
|
||||||
|
from app.services.safety.audit.logger import (
|
||||||
|
AuditLogger,
|
||||||
|
get_audit_logger,
|
||||||
|
reset_audit_logger,
|
||||||
|
shutdown_audit_logger,
|
||||||
|
)
|
||||||
|
from app.services.safety.models import (
|
||||||
|
ActionMetadata,
|
||||||
|
ActionRequest,
|
||||||
|
ActionType,
|
||||||
|
AuditEventType,
|
||||||
|
AutonomyLevel,
|
||||||
|
SafetyDecision,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestAuditLoggerInit:
|
||||||
|
"""Tests for AuditLogger initialization."""
|
||||||
|
|
||||||
|
def test_init_default_values(self):
|
||||||
|
"""Test initialization with default values."""
|
||||||
|
logger = AuditLogger()
|
||||||
|
|
||||||
|
assert logger._flush_interval == 10.0
|
||||||
|
assert logger._enable_hash_chain is True
|
||||||
|
assert logger._last_hash is None
|
||||||
|
assert logger._running is False
|
||||||
|
|
||||||
|
def test_init_custom_values(self):
|
||||||
|
"""Test initialization with custom values."""
|
||||||
|
logger = AuditLogger(
|
||||||
|
max_buffer_size=500,
|
||||||
|
flush_interval_seconds=5.0,
|
||||||
|
enable_hash_chain=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert logger._flush_interval == 5.0
|
||||||
|
assert logger._enable_hash_chain is False
|
||||||
|
|
||||||
|
|
||||||
|
class TestAuditLoggerLifecycle:
|
||||||
|
"""Tests for AuditLogger start/stop."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_start_creates_flush_task(self):
|
||||||
|
"""Test that start creates the periodic flush task."""
|
||||||
|
logger = AuditLogger(flush_interval_seconds=1.0)
|
||||||
|
|
||||||
|
await logger.start()
|
||||||
|
|
||||||
|
assert logger._running is True
|
||||||
|
assert logger._flush_task is not None
|
||||||
|
|
||||||
|
await logger.stop()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_start_idempotent(self):
|
||||||
|
"""Test that multiple starts don't create multiple tasks."""
|
||||||
|
logger = AuditLogger()
|
||||||
|
|
||||||
|
await logger.start()
|
||||||
|
task1 = logger._flush_task
|
||||||
|
|
||||||
|
await logger.start() # Second start
|
||||||
|
task2 = logger._flush_task
|
||||||
|
|
||||||
|
assert task1 is task2
|
||||||
|
|
||||||
|
await logger.stop()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_stop_cancels_task_and_flushes(self):
|
||||||
|
"""Test that stop cancels the task and flushes events."""
|
||||||
|
logger = AuditLogger()
|
||||||
|
|
||||||
|
await logger.start()
|
||||||
|
|
||||||
|
# Add an event
|
||||||
|
await logger.log(AuditEventType.ACTION_REQUESTED, agent_id="agent-1")
|
||||||
|
|
||||||
|
await logger.stop()
|
||||||
|
|
||||||
|
assert logger._running is False
|
||||||
|
# Event should be flushed
|
||||||
|
assert len(logger._persisted) == 1
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_stop_without_start(self):
|
||||||
|
"""Test stopping without starting doesn't error."""
|
||||||
|
logger = AuditLogger()
|
||||||
|
await logger.stop() # Should not raise
|
||||||
|
|
||||||
|
|
||||||
|
class TestAuditLoggerLog:
|
||||||
|
"""Tests for the log method."""
|
||||||
|
|
||||||
|
@pytest_asyncio.fixture
|
||||||
|
async def logger(self):
|
||||||
|
"""Create a logger instance."""
|
||||||
|
return AuditLogger(enable_hash_chain=True)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_log_creates_event(self, logger):
|
||||||
|
"""Test logging creates an event."""
|
||||||
|
event = await logger.log(
|
||||||
|
AuditEventType.ACTION_REQUESTED,
|
||||||
|
agent_id="agent-1",
|
||||||
|
project_id="proj-1",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert event.event_type == AuditEventType.ACTION_REQUESTED
|
||||||
|
assert event.agent_id == "agent-1"
|
||||||
|
assert event.project_id == "proj-1"
|
||||||
|
assert event.id is not None
|
||||||
|
assert event.timestamp is not None
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_log_adds_hash_chain(self, logger):
|
||||||
|
"""Test logging adds hash chain."""
|
||||||
|
event = await logger.log(
|
||||||
|
AuditEventType.ACTION_REQUESTED,
|
||||||
|
agent_id="agent-1",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert "_hash" in event.details
|
||||||
|
assert "_prev_hash" in event.details
|
||||||
|
assert event.details["_prev_hash"] is None # First event
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_log_chain_links_events(self, logger):
|
||||||
|
"""Test hash chain links events."""
|
||||||
|
event1 = await logger.log(AuditEventType.ACTION_REQUESTED)
|
||||||
|
event2 = await logger.log(AuditEventType.ACTION_EXECUTED)
|
||||||
|
|
||||||
|
assert event2.details["_prev_hash"] == event1.details["_hash"]
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_log_without_hash_chain(self):
|
||||||
|
"""Test logging without hash chain."""
|
||||||
|
logger = AuditLogger(enable_hash_chain=False)
|
||||||
|
|
||||||
|
event = await logger.log(AuditEventType.ACTION_REQUESTED)
|
||||||
|
|
||||||
|
assert "_hash" not in event.details
|
||||||
|
assert "_prev_hash" not in event.details
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_log_with_all_fields(self, logger):
|
||||||
|
"""Test logging with all optional fields."""
|
||||||
|
event = await logger.log(
|
||||||
|
AuditEventType.ACTION_EXECUTED,
|
||||||
|
agent_id="agent-1",
|
||||||
|
action_id="action-1",
|
||||||
|
project_id="proj-1",
|
||||||
|
session_id="sess-1",
|
||||||
|
user_id="user-1",
|
||||||
|
decision=SafetyDecision.ALLOW,
|
||||||
|
details={"custom": "data"},
|
||||||
|
correlation_id="corr-1",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert event.agent_id == "agent-1"
|
||||||
|
assert event.action_id == "action-1"
|
||||||
|
assert event.project_id == "proj-1"
|
||||||
|
assert event.session_id == "sess-1"
|
||||||
|
assert event.user_id == "user-1"
|
||||||
|
assert event.decision == SafetyDecision.ALLOW
|
||||||
|
assert event.details["custom"] == "data"
|
||||||
|
assert event.correlation_id == "corr-1"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_log_buffers_event(self, logger):
|
||||||
|
"""Test logging adds event to buffer."""
|
||||||
|
await logger.log(AuditEventType.ACTION_REQUESTED)
|
||||||
|
|
||||||
|
assert len(logger._buffer) == 1
|
||||||
|
|
||||||
|
|
||||||
|
class TestAuditLoggerConvenienceMethods:
|
||||||
|
"""Tests for convenience logging methods."""
|
||||||
|
|
||||||
|
@pytest_asyncio.fixture
|
||||||
|
async def logger(self):
|
||||||
|
"""Create a logger instance."""
|
||||||
|
return AuditLogger(enable_hash_chain=False)
|
||||||
|
|
||||||
|
@pytest_asyncio.fixture
|
||||||
|
def action(self):
|
||||||
|
"""Create a test action request."""
|
||||||
|
metadata = ActionMetadata(
|
||||||
|
agent_id="agent-1",
|
||||||
|
session_id="sess-1",
|
||||||
|
project_id="proj-1",
|
||||||
|
autonomy_level=AutonomyLevel.MILESTONE,
|
||||||
|
user_id="user-1",
|
||||||
|
correlation_id="corr-1",
|
||||||
|
)
|
||||||
|
|
||||||
|
return ActionRequest(
|
||||||
|
action_type=ActionType.FILE_WRITE,
|
||||||
|
tool_name="file_write",
|
||||||
|
arguments={"path": "/test.txt"},
|
||||||
|
resource="/test.txt",
|
||||||
|
metadata=metadata,
|
||||||
|
)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_log_action_request_allowed(self, logger, action):
|
||||||
|
"""Test logging allowed action request."""
|
||||||
|
event = await logger.log_action_request(
|
||||||
|
action,
|
||||||
|
SafetyDecision.ALLOW,
|
||||||
|
reasons=["Within budget"],
|
||||||
|
)
|
||||||
|
|
||||||
|
assert event.event_type == AuditEventType.ACTION_VALIDATED
|
||||||
|
assert event.decision == SafetyDecision.ALLOW
|
||||||
|
assert event.details["reasons"] == ["Within budget"]
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_log_action_request_denied(self, logger, action):
|
||||||
|
"""Test logging denied action request."""
|
||||||
|
event = await logger.log_action_request(
|
||||||
|
action,
|
||||||
|
SafetyDecision.DENY,
|
||||||
|
reasons=["Rate limit exceeded"],
|
||||||
|
)
|
||||||
|
|
||||||
|
assert event.event_type == AuditEventType.ACTION_DENIED
|
||||||
|
assert event.decision == SafetyDecision.DENY
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_log_action_executed_success(self, logger, action):
|
||||||
|
"""Test logging successful action execution."""
|
||||||
|
event = await logger.log_action_executed(
|
||||||
|
action,
|
||||||
|
success=True,
|
||||||
|
execution_time_ms=50.0,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert event.event_type == AuditEventType.ACTION_EXECUTED
|
||||||
|
assert event.details["success"] is True
|
||||||
|
assert event.details["execution_time_ms"] == 50.0
|
||||||
|
assert event.details["error"] is None
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_log_action_executed_failure(self, logger, action):
|
||||||
|
"""Test logging failed action execution."""
|
||||||
|
event = await logger.log_action_executed(
|
||||||
|
action,
|
||||||
|
success=False,
|
||||||
|
execution_time_ms=100.0,
|
||||||
|
error="File not found",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert event.event_type == AuditEventType.ACTION_FAILED
|
||||||
|
assert event.details["success"] is False
|
||||||
|
assert event.details["error"] == "File not found"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_log_approval_event(self, logger, action):
|
||||||
|
"""Test logging approval event."""
|
||||||
|
event = await logger.log_approval_event(
|
||||||
|
AuditEventType.APPROVAL_GRANTED,
|
||||||
|
approval_id="approval-1",
|
||||||
|
action=action,
|
||||||
|
decided_by="admin",
|
||||||
|
reason="Approved by admin",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert event.event_type == AuditEventType.APPROVAL_GRANTED
|
||||||
|
assert event.details["approval_id"] == "approval-1"
|
||||||
|
assert event.details["decided_by"] == "admin"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_log_budget_event(self, logger):
|
||||||
|
"""Test logging budget event."""
|
||||||
|
event = await logger.log_budget_event(
|
||||||
|
AuditEventType.BUDGET_WARNING,
|
||||||
|
agent_id="agent-1",
|
||||||
|
scope="daily",
|
||||||
|
current_usage=8000.0,
|
||||||
|
limit=10000.0,
|
||||||
|
unit="tokens",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert event.event_type == AuditEventType.BUDGET_WARNING
|
||||||
|
assert event.details["scope"] == "daily"
|
||||||
|
assert event.details["usage_percent"] == 80.0
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_log_budget_event_zero_limit(self, logger):
|
||||||
|
"""Test logging budget event with zero limit."""
|
||||||
|
event = await logger.log_budget_event(
|
||||||
|
AuditEventType.BUDGET_WARNING,
|
||||||
|
agent_id="agent-1",
|
||||||
|
scope="daily",
|
||||||
|
current_usage=100.0,
|
||||||
|
limit=0.0, # Zero limit
|
||||||
|
)
|
||||||
|
|
||||||
|
assert event.details["usage_percent"] == 0
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_log_emergency_stop(self, logger):
|
||||||
|
"""Test logging emergency stop."""
|
||||||
|
event = await logger.log_emergency_stop(
|
||||||
|
stop_type="global",
|
||||||
|
triggered_by="admin",
|
||||||
|
reason="Security incident",
|
||||||
|
affected_agents=["agent-1", "agent-2"],
|
||||||
|
)
|
||||||
|
|
||||||
|
assert event.event_type == AuditEventType.EMERGENCY_STOP
|
||||||
|
assert event.details["stop_type"] == "global"
|
||||||
|
assert event.details["affected_agents"] == ["agent-1", "agent-2"]
|
||||||
|
|
||||||
|
|
||||||
|
class TestAuditLoggerFlush:
|
||||||
|
"""Tests for flush functionality."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_flush_persists_events(self):
|
||||||
|
"""Test flush moves events to persisted storage."""
|
||||||
|
logger = AuditLogger(enable_hash_chain=False)
|
||||||
|
|
||||||
|
await logger.log(AuditEventType.ACTION_REQUESTED)
|
||||||
|
await logger.log(AuditEventType.ACTION_EXECUTED)
|
||||||
|
|
||||||
|
assert len(logger._buffer) == 2
|
||||||
|
assert len(logger._persisted) == 0
|
||||||
|
|
||||||
|
count = await logger.flush()
|
||||||
|
|
||||||
|
assert count == 2
|
||||||
|
assert len(logger._buffer) == 0
|
||||||
|
assert len(logger._persisted) == 2
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_flush_empty_buffer(self):
|
||||||
|
"""Test flush with empty buffer."""
|
||||||
|
logger = AuditLogger()
|
||||||
|
|
||||||
|
count = await logger.flush()
|
||||||
|
|
||||||
|
assert count == 0
|
||||||
|
|
||||||
|
|
||||||
|
class TestAuditLoggerQuery:
|
||||||
|
"""Tests for query functionality."""
|
||||||
|
|
||||||
|
@pytest_asyncio.fixture
|
||||||
|
async def logger_with_events(self):
|
||||||
|
"""Create a logger with some test events."""
|
||||||
|
logger = AuditLogger(enable_hash_chain=False)
|
||||||
|
|
||||||
|
# Add various events
|
||||||
|
await logger.log(
|
||||||
|
AuditEventType.ACTION_REQUESTED,
|
||||||
|
agent_id="agent-1",
|
||||||
|
project_id="proj-1",
|
||||||
|
)
|
||||||
|
await logger.log(
|
||||||
|
AuditEventType.ACTION_EXECUTED,
|
||||||
|
agent_id="agent-1",
|
||||||
|
project_id="proj-1",
|
||||||
|
)
|
||||||
|
await logger.log(
|
||||||
|
AuditEventType.ACTION_DENIED,
|
||||||
|
agent_id="agent-2",
|
||||||
|
project_id="proj-2",
|
||||||
|
)
|
||||||
|
await logger.log(
|
||||||
|
AuditEventType.BUDGET_WARNING,
|
||||||
|
agent_id="agent-1",
|
||||||
|
project_id="proj-1",
|
||||||
|
user_id="user-1",
|
||||||
|
)
|
||||||
|
|
||||||
|
return logger
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_query_all(self, logger_with_events):
|
||||||
|
"""Test querying all events."""
|
||||||
|
events = await logger_with_events.query()
|
||||||
|
|
||||||
|
assert len(events) == 4
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_query_by_event_type(self, logger_with_events):
|
||||||
|
"""Test filtering by event type."""
|
||||||
|
events = await logger_with_events.query(
|
||||||
|
event_types=[AuditEventType.ACTION_REQUESTED]
|
||||||
|
)
|
||||||
|
|
||||||
|
assert len(events) == 1
|
||||||
|
assert events[0].event_type == AuditEventType.ACTION_REQUESTED
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_query_by_agent_id(self, logger_with_events):
|
||||||
|
"""Test filtering by agent ID."""
|
||||||
|
events = await logger_with_events.query(agent_id="agent-1")
|
||||||
|
|
||||||
|
assert len(events) == 3
|
||||||
|
assert all(e.agent_id == "agent-1" for e in events)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_query_by_project_id(self, logger_with_events):
|
||||||
|
"""Test filtering by project ID."""
|
||||||
|
events = await logger_with_events.query(project_id="proj-2")
|
||||||
|
|
||||||
|
assert len(events) == 1
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_query_by_user_id(self, logger_with_events):
|
||||||
|
"""Test filtering by user ID."""
|
||||||
|
events = await logger_with_events.query(user_id="user-1")
|
||||||
|
|
||||||
|
assert len(events) == 1
|
||||||
|
assert events[0].event_type == AuditEventType.BUDGET_WARNING
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_query_with_limit(self, logger_with_events):
|
||||||
|
"""Test query with limit."""
|
||||||
|
events = await logger_with_events.query(limit=2)
|
||||||
|
|
||||||
|
assert len(events) == 2
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_query_with_offset(self, logger_with_events):
|
||||||
|
"""Test query with offset."""
|
||||||
|
all_events = await logger_with_events.query()
|
||||||
|
offset_events = await logger_with_events.query(offset=2)
|
||||||
|
|
||||||
|
assert len(offset_events) == 2
|
||||||
|
assert offset_events[0] == all_events[2]
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_query_by_time_range(self):
|
||||||
|
"""Test filtering by time range."""
|
||||||
|
logger = AuditLogger(enable_hash_chain=False)
|
||||||
|
|
||||||
|
now = datetime.utcnow()
|
||||||
|
await logger.log(AuditEventType.ACTION_REQUESTED)
|
||||||
|
|
||||||
|
# Query with start time
|
||||||
|
events = await logger.query(
|
||||||
|
start_time=now - timedelta(seconds=1),
|
||||||
|
end_time=now + timedelta(seconds=1),
|
||||||
|
)
|
||||||
|
|
||||||
|
assert len(events) == 1
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_query_by_correlation_id(self):
|
||||||
|
"""Test filtering by correlation ID."""
|
||||||
|
logger = AuditLogger(enable_hash_chain=False)
|
||||||
|
|
||||||
|
await logger.log(
|
||||||
|
AuditEventType.ACTION_REQUESTED,
|
||||||
|
correlation_id="corr-123",
|
||||||
|
)
|
||||||
|
await logger.log(
|
||||||
|
AuditEventType.ACTION_EXECUTED,
|
||||||
|
correlation_id="corr-456",
|
||||||
|
)
|
||||||
|
|
||||||
|
events = await logger.query(correlation_id="corr-123")
|
||||||
|
|
||||||
|
assert len(events) == 1
|
||||||
|
assert events[0].correlation_id == "corr-123"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_query_combined_filters(self, logger_with_events):
|
||||||
|
"""Test combined filters."""
|
||||||
|
events = await logger_with_events.query(
|
||||||
|
agent_id="agent-1",
|
||||||
|
project_id="proj-1",
|
||||||
|
event_types=[
|
||||||
|
AuditEventType.ACTION_REQUESTED,
|
||||||
|
AuditEventType.ACTION_EXECUTED,
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
assert len(events) == 2
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_action_history(self, logger_with_events):
|
||||||
|
"""Test get_action_history method."""
|
||||||
|
events = await logger_with_events.get_action_history("agent-1")
|
||||||
|
|
||||||
|
# Should only return action-related events
|
||||||
|
assert len(events) == 2
|
||||||
|
assert all(
|
||||||
|
e.event_type
|
||||||
|
in {AuditEventType.ACTION_REQUESTED, AuditEventType.ACTION_EXECUTED}
|
||||||
|
for e in events
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestAuditLoggerIntegrity:
|
||||||
|
"""Tests for hash chain integrity verification."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_verify_integrity_valid(self):
|
||||||
|
"""Test integrity verification with valid chain."""
|
||||||
|
logger = AuditLogger(enable_hash_chain=True)
|
||||||
|
|
||||||
|
await logger.log(AuditEventType.ACTION_REQUESTED)
|
||||||
|
await logger.log(AuditEventType.ACTION_EXECUTED)
|
||||||
|
|
||||||
|
is_valid, issues = await logger.verify_integrity()
|
||||||
|
|
||||||
|
assert is_valid is True
|
||||||
|
assert len(issues) == 0
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_verify_integrity_disabled(self):
|
||||||
|
"""Test integrity verification when hash chain disabled."""
|
||||||
|
logger = AuditLogger(enable_hash_chain=False)
|
||||||
|
|
||||||
|
await logger.log(AuditEventType.ACTION_REQUESTED)
|
||||||
|
|
||||||
|
is_valid, issues = await logger.verify_integrity()
|
||||||
|
|
||||||
|
assert is_valid is True
|
||||||
|
assert len(issues) == 0
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_verify_integrity_broken_chain(self):
|
||||||
|
"""Test integrity verification with broken chain."""
|
||||||
|
logger = AuditLogger(enable_hash_chain=True)
|
||||||
|
|
||||||
|
event1 = await logger.log(AuditEventType.ACTION_REQUESTED)
|
||||||
|
await logger.log(AuditEventType.ACTION_EXECUTED)
|
||||||
|
|
||||||
|
# Tamper with first event's hash
|
||||||
|
event1.details["_hash"] = "tampered_hash"
|
||||||
|
|
||||||
|
is_valid, issues = await logger.verify_integrity()
|
||||||
|
|
||||||
|
assert is_valid is False
|
||||||
|
assert len(issues) > 0
|
||||||
|
|
||||||
|
|
||||||
|
class TestAuditLoggerHandlers:
|
||||||
|
"""Tests for event handler management."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_add_sync_handler(self):
|
||||||
|
"""Test adding synchronous handler."""
|
||||||
|
logger = AuditLogger(enable_hash_chain=False)
|
||||||
|
events_received = []
|
||||||
|
|
||||||
|
def handler(event):
|
||||||
|
events_received.append(event)
|
||||||
|
|
||||||
|
logger.add_handler(handler)
|
||||||
|
await logger.log(AuditEventType.ACTION_REQUESTED)
|
||||||
|
|
||||||
|
assert len(events_received) == 1
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_add_async_handler(self):
|
||||||
|
"""Test adding async handler."""
|
||||||
|
logger = AuditLogger(enable_hash_chain=False)
|
||||||
|
events_received = []
|
||||||
|
|
||||||
|
async def handler(event):
|
||||||
|
events_received.append(event)
|
||||||
|
|
||||||
|
logger.add_handler(handler)
|
||||||
|
await logger.log(AuditEventType.ACTION_REQUESTED)
|
||||||
|
|
||||||
|
assert len(events_received) == 1
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_remove_handler(self):
|
||||||
|
"""Test removing handler."""
|
||||||
|
logger = AuditLogger(enable_hash_chain=False)
|
||||||
|
events_received = []
|
||||||
|
|
||||||
|
def handler(event):
|
||||||
|
events_received.append(event)
|
||||||
|
|
||||||
|
logger.add_handler(handler)
|
||||||
|
await logger.log(AuditEventType.ACTION_REQUESTED)
|
||||||
|
|
||||||
|
logger.remove_handler(handler)
|
||||||
|
await logger.log(AuditEventType.ACTION_EXECUTED)
|
||||||
|
|
||||||
|
assert len(events_received) == 1
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_handler_error_caught(self):
|
||||||
|
"""Test that handler errors are caught."""
|
||||||
|
logger = AuditLogger(enable_hash_chain=False)
|
||||||
|
|
||||||
|
def failing_handler(event):
|
||||||
|
raise ValueError("Handler error")
|
||||||
|
|
||||||
|
logger.add_handler(failing_handler)
|
||||||
|
|
||||||
|
# Should not raise
|
||||||
|
event = await logger.log(AuditEventType.ACTION_REQUESTED)
|
||||||
|
assert event is not None
|
||||||
|
|
||||||
|
|
||||||
|
class TestAuditLoggerSanitization:
|
||||||
|
"""Tests for sensitive data sanitization."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_sanitize_sensitive_keys(self):
|
||||||
|
"""Test sanitization of sensitive keys."""
|
||||||
|
with patch("app.services.safety.audit.logger.get_safety_config") as mock_config:
|
||||||
|
mock_cfg = MagicMock()
|
||||||
|
mock_cfg.audit_retention_days = 30
|
||||||
|
mock_cfg.audit_include_sensitive = False
|
||||||
|
mock_config.return_value = mock_cfg
|
||||||
|
|
||||||
|
logger = AuditLogger(enable_hash_chain=False)
|
||||||
|
|
||||||
|
event = await logger.log(
|
||||||
|
AuditEventType.ACTION_EXECUTED,
|
||||||
|
details={
|
||||||
|
"password": "secret123",
|
||||||
|
"api_key": "key123",
|
||||||
|
"token": "token123",
|
||||||
|
"normal_field": "visible",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert event.details["password"] == "[REDACTED]"
|
||||||
|
assert event.details["api_key"] == "[REDACTED]"
|
||||||
|
assert event.details["token"] == "[REDACTED]"
|
||||||
|
assert event.details["normal_field"] == "visible"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_sanitize_nested_dict(self):
|
||||||
|
"""Test sanitization of nested dictionaries."""
|
||||||
|
with patch("app.services.safety.audit.logger.get_safety_config") as mock_config:
|
||||||
|
mock_cfg = MagicMock()
|
||||||
|
mock_cfg.audit_retention_days = 30
|
||||||
|
mock_cfg.audit_include_sensitive = False
|
||||||
|
mock_config.return_value = mock_cfg
|
||||||
|
|
||||||
|
logger = AuditLogger(enable_hash_chain=False)
|
||||||
|
|
||||||
|
event = await logger.log(
|
||||||
|
AuditEventType.ACTION_EXECUTED,
|
||||||
|
details={
|
||||||
|
"config": {
|
||||||
|
"api_secret": "secret",
|
||||||
|
"name": "test",
|
||||||
|
}
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert event.details["config"]["api_secret"] == "[REDACTED]"
|
||||||
|
assert event.details["config"]["name"] == "test"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_include_sensitive_when_enabled(self):
|
||||||
|
"""Test sensitive data included when enabled."""
|
||||||
|
with patch("app.services.safety.audit.logger.get_safety_config") as mock_config:
|
||||||
|
mock_cfg = MagicMock()
|
||||||
|
mock_cfg.audit_retention_days = 30
|
||||||
|
mock_cfg.audit_include_sensitive = True
|
||||||
|
mock_config.return_value = mock_cfg
|
||||||
|
|
||||||
|
logger = AuditLogger(enable_hash_chain=False)
|
||||||
|
|
||||||
|
event = await logger.log(
|
||||||
|
AuditEventType.ACTION_EXECUTED,
|
||||||
|
details={"password": "secret123"},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert event.details["password"] == "secret123"
|
||||||
|
|
||||||
|
|
||||||
|
class TestAuditLoggerRetention:
|
||||||
|
"""Tests for retention policy enforcement."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_retention_removes_old_events(self):
|
||||||
|
"""Test that retention removes old events."""
|
||||||
|
with patch("app.services.safety.audit.logger.get_safety_config") as mock_config:
|
||||||
|
mock_cfg = MagicMock()
|
||||||
|
mock_cfg.audit_retention_days = 7
|
||||||
|
mock_cfg.audit_include_sensitive = False
|
||||||
|
mock_config.return_value = mock_cfg
|
||||||
|
|
||||||
|
logger = AuditLogger(enable_hash_chain=False)
|
||||||
|
|
||||||
|
# Add an old event directly to persisted
|
||||||
|
from app.services.safety.models import AuditEvent
|
||||||
|
|
||||||
|
old_event = AuditEvent(
|
||||||
|
id="old-event",
|
||||||
|
event_type=AuditEventType.ACTION_REQUESTED,
|
||||||
|
timestamp=datetime.utcnow() - timedelta(days=10),
|
||||||
|
details={},
|
||||||
|
)
|
||||||
|
logger._persisted.append(old_event)
|
||||||
|
|
||||||
|
# Add a recent event
|
||||||
|
await logger.log(AuditEventType.ACTION_EXECUTED)
|
||||||
|
|
||||||
|
# Flush will trigger retention enforcement
|
||||||
|
await logger.flush()
|
||||||
|
|
||||||
|
# Old event should be removed
|
||||||
|
assert len(logger._persisted) == 1
|
||||||
|
assert logger._persisted[0].id != "old-event"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_retention_keeps_recent_events(self):
|
||||||
|
"""Test that retention keeps recent events."""
|
||||||
|
with patch("app.services.safety.audit.logger.get_safety_config") as mock_config:
|
||||||
|
mock_cfg = MagicMock()
|
||||||
|
mock_cfg.audit_retention_days = 7
|
||||||
|
mock_cfg.audit_include_sensitive = False
|
||||||
|
mock_config.return_value = mock_cfg
|
||||||
|
|
||||||
|
logger = AuditLogger(enable_hash_chain=False)
|
||||||
|
|
||||||
|
await logger.log(AuditEventType.ACTION_REQUESTED)
|
||||||
|
await logger.log(AuditEventType.ACTION_EXECUTED)
|
||||||
|
|
||||||
|
await logger.flush()
|
||||||
|
|
||||||
|
assert len(logger._persisted) == 2
|
||||||
|
|
||||||
|
|
||||||
|
class TestAuditLoggerSingleton:
|
||||||
|
"""Tests for singleton pattern."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_audit_logger_creates_instance(self):
|
||||||
|
"""Test get_audit_logger creates singleton."""
|
||||||
|
|
||||||
|
reset_audit_logger()
|
||||||
|
|
||||||
|
logger1 = await get_audit_logger()
|
||||||
|
logger2 = await get_audit_logger()
|
||||||
|
|
||||||
|
assert logger1 is logger2
|
||||||
|
|
||||||
|
await shutdown_audit_logger()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_shutdown_audit_logger(self):
|
||||||
|
"""Test shutdown_audit_logger stops and clears singleton."""
|
||||||
|
import app.services.safety.audit.logger as audit_module
|
||||||
|
|
||||||
|
reset_audit_logger()
|
||||||
|
|
||||||
|
_logger = await get_audit_logger()
|
||||||
|
await shutdown_audit_logger()
|
||||||
|
|
||||||
|
assert audit_module._audit_logger is None
|
||||||
|
|
||||||
|
def test_reset_audit_logger(self):
|
||||||
|
"""Test reset_audit_logger clears singleton."""
|
||||||
|
import app.services.safety.audit.logger as audit_module
|
||||||
|
|
||||||
|
audit_module._audit_logger = AuditLogger()
|
||||||
|
reset_audit_logger()
|
||||||
|
|
||||||
|
assert audit_module._audit_logger is None
|
||||||
|
|
||||||
|
|
||||||
|
class TestAuditLoggerPeriodicFlush:
|
||||||
|
"""Tests for periodic flush background task."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_periodic_flush_runs(self):
|
||||||
|
"""Test periodic flush runs and flushes events."""
|
||||||
|
logger = AuditLogger(flush_interval_seconds=0.1, enable_hash_chain=False)
|
||||||
|
|
||||||
|
await logger.start()
|
||||||
|
|
||||||
|
# Log an event
|
||||||
|
await logger.log(AuditEventType.ACTION_REQUESTED)
|
||||||
|
assert len(logger._buffer) == 1
|
||||||
|
|
||||||
|
# Wait for periodic flush
|
||||||
|
await asyncio.sleep(0.15)
|
||||||
|
|
||||||
|
# Event should be flushed
|
||||||
|
assert len(logger._buffer) == 0
|
||||||
|
assert len(logger._persisted) == 1
|
||||||
|
|
||||||
|
await logger.stop()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_periodic_flush_handles_errors(self):
|
||||||
|
"""Test periodic flush handles errors gracefully."""
|
||||||
|
logger = AuditLogger(flush_interval_seconds=0.1)
|
||||||
|
|
||||||
|
await logger.start()
|
||||||
|
|
||||||
|
# Mock flush to raise an error
|
||||||
|
original_flush = logger.flush
|
||||||
|
|
||||||
|
async def failing_flush():
|
||||||
|
raise Exception("Flush error")
|
||||||
|
|
||||||
|
logger.flush = failing_flush
|
||||||
|
|
||||||
|
# Wait for flush attempt
|
||||||
|
await asyncio.sleep(0.15)
|
||||||
|
|
||||||
|
# Should still be running
|
||||||
|
assert logger._running is True
|
||||||
|
|
||||||
|
logger.flush = original_flush
|
||||||
|
await logger.stop()
|
||||||
|
|
||||||
|
|
||||||
|
class TestAuditLoggerLogging:
|
||||||
|
"""Tests for standard logger output."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_log_warning_for_denied(self):
|
||||||
|
"""Test warning level for denied events."""
|
||||||
|
with patch("app.services.safety.audit.logger.logger") as mock_logger:
|
||||||
|
audit_logger = AuditLogger(enable_hash_chain=False)
|
||||||
|
|
||||||
|
await audit_logger.log(
|
||||||
|
AuditEventType.ACTION_DENIED,
|
||||||
|
agent_id="agent-1",
|
||||||
|
)
|
||||||
|
|
||||||
|
mock_logger.warning.assert_called()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_log_error_for_failed(self):
|
||||||
|
"""Test error level for failed events."""
|
||||||
|
with patch("app.services.safety.audit.logger.logger") as mock_logger:
|
||||||
|
audit_logger = AuditLogger(enable_hash_chain=False)
|
||||||
|
|
||||||
|
await audit_logger.log(
|
||||||
|
AuditEventType.ACTION_FAILED,
|
||||||
|
agent_id="agent-1",
|
||||||
|
)
|
||||||
|
|
||||||
|
mock_logger.error.assert_called()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_log_info_for_normal(self):
|
||||||
|
"""Test info level for normal events."""
|
||||||
|
with patch("app.services.safety.audit.logger.logger") as mock_logger:
|
||||||
|
audit_logger = AuditLogger(enable_hash_chain=False)
|
||||||
|
|
||||||
|
await audit_logger.log(
|
||||||
|
AuditEventType.ACTION_EXECUTED,
|
||||||
|
agent_id="agent-1",
|
||||||
|
)
|
||||||
|
|
||||||
|
mock_logger.info.assert_called()
|
||||||
|
|
||||||
|
|
||||||
|
class TestAuditLoggerEdgeCases:
|
||||||
|
"""Tests for edge cases."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_log_with_none_details(self):
|
||||||
|
"""Test logging with None details."""
|
||||||
|
logger = AuditLogger(enable_hash_chain=False)
|
||||||
|
|
||||||
|
event = await logger.log(
|
||||||
|
AuditEventType.ACTION_REQUESTED,
|
||||||
|
details=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert event.details == {}
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_query_with_action_id(self):
|
||||||
|
"""Test querying by action ID."""
|
||||||
|
logger = AuditLogger(enable_hash_chain=False)
|
||||||
|
|
||||||
|
await logger.log(
|
||||||
|
AuditEventType.ACTION_REQUESTED,
|
||||||
|
action_id="action-1",
|
||||||
|
)
|
||||||
|
await logger.log(
|
||||||
|
AuditEventType.ACTION_EXECUTED,
|
||||||
|
action_id="action-2",
|
||||||
|
)
|
||||||
|
|
||||||
|
events = await logger.query(action_id="action-1")
|
||||||
|
|
||||||
|
assert len(events) == 1
|
||||||
|
assert events[0].action_id == "action-1"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_query_with_session_id(self):
|
||||||
|
"""Test querying by session ID."""
|
||||||
|
logger = AuditLogger(enable_hash_chain=False)
|
||||||
|
|
||||||
|
await logger.log(
|
||||||
|
AuditEventType.ACTION_REQUESTED,
|
||||||
|
session_id="sess-1",
|
||||||
|
)
|
||||||
|
await logger.log(
|
||||||
|
AuditEventType.ACTION_EXECUTED,
|
||||||
|
session_id="sess-2",
|
||||||
|
)
|
||||||
|
|
||||||
|
events = await logger.query(session_id="sess-1")
|
||||||
|
|
||||||
|
assert len(events) == 1
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_query_includes_buffer_and_persisted(self):
|
||||||
|
"""Test query includes both buffer and persisted events."""
|
||||||
|
logger = AuditLogger(enable_hash_chain=False)
|
||||||
|
|
||||||
|
# Add event to buffer
|
||||||
|
await logger.log(AuditEventType.ACTION_REQUESTED)
|
||||||
|
|
||||||
|
# Flush to persisted
|
||||||
|
await logger.flush()
|
||||||
|
|
||||||
|
# Add another to buffer
|
||||||
|
await logger.log(AuditEventType.ACTION_EXECUTED)
|
||||||
|
|
||||||
|
# Query should return both
|
||||||
|
events = await logger.query()
|
||||||
|
|
||||||
|
assert len(events) == 2
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_remove_nonexistent_handler(self):
|
||||||
|
"""Test removing handler that doesn't exist."""
|
||||||
|
logger = AuditLogger()
|
||||||
|
|
||||||
|
def handler(event):
|
||||||
|
pass
|
||||||
|
|
||||||
|
# Should not raise
|
||||||
|
logger.remove_handler(handler)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_query_time_filter_excludes_events(self):
|
||||||
|
"""Test time filters exclude events correctly."""
|
||||||
|
logger = AuditLogger(enable_hash_chain=False)
|
||||||
|
|
||||||
|
await logger.log(AuditEventType.ACTION_REQUESTED)
|
||||||
|
|
||||||
|
# Query with future start time
|
||||||
|
future = datetime.utcnow() + timedelta(hours=1)
|
||||||
|
events = await logger.query(start_time=future)
|
||||||
|
|
||||||
|
assert len(events) == 0
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_query_end_time_filter(self):
|
||||||
|
"""Test end time filter."""
|
||||||
|
logger = AuditLogger(enable_hash_chain=False)
|
||||||
|
|
||||||
|
await logger.log(AuditEventType.ACTION_REQUESTED)
|
||||||
|
|
||||||
|
# Query with past end time
|
||||||
|
past = datetime.utcnow() - timedelta(hours=1)
|
||||||
|
events = await logger.query(end_time=past)
|
||||||
|
|
||||||
|
assert len(events) == 0
|
||||||
1136
backend/tests/services/safety/test_hitl.py
Normal file
1136
backend/tests/services/safety/test_hitl.py
Normal file
File diff suppressed because it is too large
Load Diff
874
backend/tests/services/safety/test_mcp_integration.py
Normal file
874
backend/tests/services/safety/test_mcp_integration.py
Normal file
@@ -0,0 +1,874 @@
|
|||||||
|
"""
|
||||||
|
Tests for MCP Safety Integration.
|
||||||
|
|
||||||
|
Tests cover:
|
||||||
|
- MCPToolCall and MCPToolResult data structures
|
||||||
|
- MCPSafetyWrapper: tool registration, execution, safety checks
|
||||||
|
- Tool classification and action type mapping
|
||||||
|
- SafeToolExecutor context manager
|
||||||
|
- Factory function create_mcp_wrapper
|
||||||
|
"""
|
||||||
|
|
||||||
|
from datetime import datetime, timedelta
|
||||||
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import pytest_asyncio
|
||||||
|
|
||||||
|
from app.services.safety.exceptions import EmergencyStopError
|
||||||
|
from app.services.safety.mcp.integration import (
|
||||||
|
MCPSafetyWrapper,
|
||||||
|
MCPToolCall,
|
||||||
|
MCPToolResult,
|
||||||
|
SafeToolExecutor,
|
||||||
|
create_mcp_wrapper,
|
||||||
|
)
|
||||||
|
from app.services.safety.models import (
|
||||||
|
ActionType,
|
||||||
|
AutonomyLevel,
|
||||||
|
SafetyDecision,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestMCPToolCall:
|
||||||
|
"""Tests for MCPToolCall dataclass."""
|
||||||
|
|
||||||
|
def test_tool_call_creation(self):
|
||||||
|
"""Test creating a tool call."""
|
||||||
|
call = MCPToolCall(
|
||||||
|
tool_name="file_read",
|
||||||
|
arguments={"path": "/tmp/test.txt"}, # noqa: S108
|
||||||
|
server_name="file-server",
|
||||||
|
project_id="proj-1",
|
||||||
|
context={"session_id": "sess-1"},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert call.tool_name == "file_read"
|
||||||
|
assert call.arguments == {"path": "/tmp/test.txt"} # noqa: S108
|
||||||
|
assert call.server_name == "file-server"
|
||||||
|
assert call.project_id == "proj-1"
|
||||||
|
assert call.context == {"session_id": "sess-1"}
|
||||||
|
|
||||||
|
def test_tool_call_defaults(self):
|
||||||
|
"""Test tool call default values."""
|
||||||
|
call = MCPToolCall(
|
||||||
|
tool_name="test",
|
||||||
|
arguments={},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert call.server_name is None
|
||||||
|
assert call.project_id is None
|
||||||
|
assert call.context == {}
|
||||||
|
|
||||||
|
|
||||||
|
class TestMCPToolResult:
|
||||||
|
"""Tests for MCPToolResult dataclass."""
|
||||||
|
|
||||||
|
def test_tool_result_success(self):
|
||||||
|
"""Test creating a successful result."""
|
||||||
|
result = MCPToolResult(
|
||||||
|
success=True,
|
||||||
|
result={"data": "test"},
|
||||||
|
safety_decision=SafetyDecision.ALLOW,
|
||||||
|
execution_time_ms=50.0,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result.success is True
|
||||||
|
assert result.result == {"data": "test"}
|
||||||
|
assert result.error is None
|
||||||
|
assert result.safety_decision == SafetyDecision.ALLOW
|
||||||
|
assert result.execution_time_ms == 50.0
|
||||||
|
|
||||||
|
def test_tool_result_failure(self):
|
||||||
|
"""Test creating a failed result."""
|
||||||
|
result = MCPToolResult(
|
||||||
|
success=False,
|
||||||
|
error="Permission denied",
|
||||||
|
safety_decision=SafetyDecision.DENY,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result.success is False
|
||||||
|
assert result.error == "Permission denied"
|
||||||
|
assert result.result is None
|
||||||
|
|
||||||
|
def test_tool_result_with_ids(self):
|
||||||
|
"""Test result with approval and checkpoint IDs."""
|
||||||
|
result = MCPToolResult(
|
||||||
|
success=True,
|
||||||
|
approval_id="approval-123",
|
||||||
|
checkpoint_id="checkpoint-456",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result.approval_id == "approval-123"
|
||||||
|
assert result.checkpoint_id == "checkpoint-456"
|
||||||
|
|
||||||
|
def test_tool_result_defaults(self):
|
||||||
|
"""Test result default values."""
|
||||||
|
result = MCPToolResult(success=True)
|
||||||
|
|
||||||
|
assert result.result is None
|
||||||
|
assert result.error is None
|
||||||
|
assert result.safety_decision == SafetyDecision.ALLOW
|
||||||
|
assert result.execution_time_ms == 0.0
|
||||||
|
assert result.approval_id is None
|
||||||
|
assert result.checkpoint_id is None
|
||||||
|
assert result.metadata == {}
|
||||||
|
|
||||||
|
|
||||||
|
class TestMCPSafetyWrapperClassification:
|
||||||
|
"""Tests for tool classification."""
|
||||||
|
|
||||||
|
def test_classify_file_read(self):
|
||||||
|
"""Test classifying file read tools."""
|
||||||
|
wrapper = MCPSafetyWrapper()
|
||||||
|
|
||||||
|
assert wrapper._classify_tool("file_read") == ActionType.FILE_READ
|
||||||
|
assert wrapper._classify_tool("get_file") == ActionType.FILE_READ
|
||||||
|
assert wrapper._classify_tool("list_files") == ActionType.FILE_READ
|
||||||
|
assert wrapper._classify_tool("search_file") == ActionType.FILE_READ
|
||||||
|
|
||||||
|
def test_classify_file_write(self):
|
||||||
|
"""Test classifying file write tools."""
|
||||||
|
wrapper = MCPSafetyWrapper()
|
||||||
|
|
||||||
|
assert wrapper._classify_tool("file_write") == ActionType.FILE_WRITE
|
||||||
|
assert wrapper._classify_tool("create_file") == ActionType.FILE_WRITE
|
||||||
|
assert wrapper._classify_tool("update_file") == ActionType.FILE_WRITE
|
||||||
|
|
||||||
|
def test_classify_file_delete(self):
|
||||||
|
"""Test classifying file delete tools."""
|
||||||
|
wrapper = MCPSafetyWrapper()
|
||||||
|
|
||||||
|
assert wrapper._classify_tool("file_delete") == ActionType.FILE_DELETE
|
||||||
|
assert wrapper._classify_tool("remove_file") == ActionType.FILE_DELETE
|
||||||
|
|
||||||
|
def test_classify_database_read(self):
|
||||||
|
"""Test classifying database read tools."""
|
||||||
|
wrapper = MCPSafetyWrapper()
|
||||||
|
|
||||||
|
assert wrapper._classify_tool("database_query") == ActionType.DATABASE_QUERY
|
||||||
|
assert wrapper._classify_tool("db_read") == ActionType.DATABASE_QUERY
|
||||||
|
assert wrapper._classify_tool("query_database") == ActionType.DATABASE_QUERY
|
||||||
|
|
||||||
|
def test_classify_database_mutate(self):
|
||||||
|
"""Test classifying database mutate tools."""
|
||||||
|
wrapper = MCPSafetyWrapper()
|
||||||
|
|
||||||
|
assert wrapper._classify_tool("database_write") == ActionType.DATABASE_MUTATE
|
||||||
|
assert wrapper._classify_tool("db_update") == ActionType.DATABASE_MUTATE
|
||||||
|
assert wrapper._classify_tool("database_delete") == ActionType.DATABASE_MUTATE
|
||||||
|
|
||||||
|
def test_classify_shell_command(self):
|
||||||
|
"""Test classifying shell command tools."""
|
||||||
|
wrapper = MCPSafetyWrapper()
|
||||||
|
|
||||||
|
assert wrapper._classify_tool("shell_execute") == ActionType.SHELL_COMMAND
|
||||||
|
assert wrapper._classify_tool("exec_command") == ActionType.SHELL_COMMAND
|
||||||
|
assert wrapper._classify_tool("bash_run") == ActionType.SHELL_COMMAND
|
||||||
|
|
||||||
|
def test_classify_git_operation(self):
|
||||||
|
"""Test classifying git tools."""
|
||||||
|
wrapper = MCPSafetyWrapper()
|
||||||
|
|
||||||
|
assert wrapper._classify_tool("git_commit") == ActionType.GIT_OPERATION
|
||||||
|
assert wrapper._classify_tool("git_push") == ActionType.GIT_OPERATION
|
||||||
|
assert wrapper._classify_tool("git_status") == ActionType.GIT_OPERATION
|
||||||
|
|
||||||
|
def test_classify_network_request(self):
|
||||||
|
"""Test classifying network tools."""
|
||||||
|
wrapper = MCPSafetyWrapper()
|
||||||
|
|
||||||
|
assert wrapper._classify_tool("http_get") == ActionType.NETWORK_REQUEST
|
||||||
|
assert wrapper._classify_tool("fetch_url") == ActionType.NETWORK_REQUEST
|
||||||
|
assert wrapper._classify_tool("api_request") == ActionType.NETWORK_REQUEST
|
||||||
|
|
||||||
|
def test_classify_llm_call(self):
|
||||||
|
"""Test classifying LLM tools."""
|
||||||
|
wrapper = MCPSafetyWrapper()
|
||||||
|
|
||||||
|
assert wrapper._classify_tool("llm_generate") == ActionType.LLM_CALL
|
||||||
|
assert wrapper._classify_tool("ai_complete") == ActionType.LLM_CALL
|
||||||
|
assert wrapper._classify_tool("claude_chat") == ActionType.LLM_CALL
|
||||||
|
|
||||||
|
def test_classify_default(self):
|
||||||
|
"""Test default classification for unknown tools."""
|
||||||
|
wrapper = MCPSafetyWrapper()
|
||||||
|
|
||||||
|
assert wrapper._classify_tool("unknown_tool") == ActionType.TOOL_CALL
|
||||||
|
assert wrapper._classify_tool("custom_action") == ActionType.TOOL_CALL
|
||||||
|
|
||||||
|
|
||||||
|
class TestMCPSafetyWrapperToolHandlers:
|
||||||
|
"""Tests for tool handler registration."""
|
||||||
|
|
||||||
|
def test_register_tool_handler(self):
|
||||||
|
"""Test registering a tool handler."""
|
||||||
|
wrapper = MCPSafetyWrapper()
|
||||||
|
|
||||||
|
def handler(path: str) -> str:
|
||||||
|
return f"Read: {path}"
|
||||||
|
|
||||||
|
wrapper.register_tool_handler("file_read", handler)
|
||||||
|
|
||||||
|
assert "file_read" in wrapper._tool_handlers
|
||||||
|
assert wrapper._tool_handlers["file_read"] is handler
|
||||||
|
|
||||||
|
def test_register_multiple_handlers(self):
|
||||||
|
"""Test registering multiple handlers."""
|
||||||
|
wrapper = MCPSafetyWrapper()
|
||||||
|
|
||||||
|
wrapper.register_tool_handler("tool1", lambda: None)
|
||||||
|
wrapper.register_tool_handler("tool2", lambda: None)
|
||||||
|
wrapper.register_tool_handler("tool3", lambda: None)
|
||||||
|
|
||||||
|
assert len(wrapper._tool_handlers) == 3
|
||||||
|
|
||||||
|
def test_overwrite_handler(self):
|
||||||
|
"""Test overwriting a handler."""
|
||||||
|
wrapper = MCPSafetyWrapper()
|
||||||
|
|
||||||
|
handler1 = lambda: "first" # noqa: E731
|
||||||
|
handler2 = lambda: "second" # noqa: E731
|
||||||
|
|
||||||
|
wrapper.register_tool_handler("tool", handler1)
|
||||||
|
wrapper.register_tool_handler("tool", handler2)
|
||||||
|
|
||||||
|
assert wrapper._tool_handlers["tool"] is handler2
|
||||||
|
|
||||||
|
|
||||||
|
class TestMCPSafetyWrapperExecution:
|
||||||
|
"""Tests for tool execution."""
|
||||||
|
|
||||||
|
@pytest_asyncio.fixture
|
||||||
|
async def mock_guardian(self):
|
||||||
|
"""Create a mock SafetyGuardian."""
|
||||||
|
guardian = AsyncMock()
|
||||||
|
guardian.validate = AsyncMock()
|
||||||
|
return guardian
|
||||||
|
|
||||||
|
@pytest_asyncio.fixture
|
||||||
|
async def mock_emergency(self):
|
||||||
|
"""Create a mock EmergencyControls."""
|
||||||
|
emergency = AsyncMock()
|
||||||
|
emergency.check_allowed = AsyncMock()
|
||||||
|
return emergency
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_execute_allowed(self, mock_guardian, mock_emergency):
|
||||||
|
"""Test executing an allowed tool call."""
|
||||||
|
mock_guardian.validate.return_value = MagicMock(
|
||||||
|
decision=SafetyDecision.ALLOW,
|
||||||
|
reasons=[],
|
||||||
|
approval_id=None,
|
||||||
|
checkpoint_id=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
wrapper = MCPSafetyWrapper(
|
||||||
|
guardian=mock_guardian,
|
||||||
|
emergency_controls=mock_emergency,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def handler(path: str) -> dict:
|
||||||
|
return {"content": f"Data from {path}"}
|
||||||
|
|
||||||
|
wrapper.register_tool_handler("file_read", handler)
|
||||||
|
|
||||||
|
call = MCPToolCall(
|
||||||
|
tool_name="file_read",
|
||||||
|
arguments={"path": "/test.txt"},
|
||||||
|
project_id="proj-1",
|
||||||
|
)
|
||||||
|
|
||||||
|
result = await wrapper.execute(call, "agent-1")
|
||||||
|
|
||||||
|
assert result.success is True
|
||||||
|
assert result.result == {"content": "Data from /test.txt"}
|
||||||
|
assert result.safety_decision == SafetyDecision.ALLOW
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_execute_denied(self, mock_guardian, mock_emergency):
|
||||||
|
"""Test executing a denied tool call."""
|
||||||
|
mock_guardian.validate.return_value = MagicMock(
|
||||||
|
decision=SafetyDecision.DENY,
|
||||||
|
reasons=["Permission denied", "Rate limit exceeded"],
|
||||||
|
)
|
||||||
|
|
||||||
|
wrapper = MCPSafetyWrapper(
|
||||||
|
guardian=mock_guardian,
|
||||||
|
emergency_controls=mock_emergency,
|
||||||
|
)
|
||||||
|
|
||||||
|
call = MCPToolCall(
|
||||||
|
tool_name="file_write",
|
||||||
|
arguments={"path": "/etc/passwd"},
|
||||||
|
)
|
||||||
|
|
||||||
|
result = await wrapper.execute(call, "agent-1")
|
||||||
|
|
||||||
|
assert result.success is False
|
||||||
|
assert "Permission denied" in result.error
|
||||||
|
assert "Rate limit exceeded" in result.error
|
||||||
|
assert result.safety_decision == SafetyDecision.DENY
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_execute_requires_approval(self, mock_guardian, mock_emergency):
|
||||||
|
"""Test executing a tool that requires approval."""
|
||||||
|
mock_guardian.validate.return_value = MagicMock(
|
||||||
|
decision=SafetyDecision.REQUIRE_APPROVAL,
|
||||||
|
reasons=["Destructive operation requires approval"],
|
||||||
|
approval_id="approval-123",
|
||||||
|
)
|
||||||
|
|
||||||
|
wrapper = MCPSafetyWrapper(
|
||||||
|
guardian=mock_guardian,
|
||||||
|
emergency_controls=mock_emergency,
|
||||||
|
)
|
||||||
|
|
||||||
|
call = MCPToolCall(
|
||||||
|
tool_name="file_delete",
|
||||||
|
arguments={"path": "/important.txt"},
|
||||||
|
)
|
||||||
|
|
||||||
|
result = await wrapper.execute(call, "agent-1")
|
||||||
|
|
||||||
|
assert result.success is False
|
||||||
|
assert result.safety_decision == SafetyDecision.REQUIRE_APPROVAL
|
||||||
|
assert result.approval_id == "approval-123"
|
||||||
|
assert "requires human approval" in result.error
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_execute_emergency_stop(self, mock_guardian, mock_emergency):
|
||||||
|
"""Test execution blocked by emergency stop."""
|
||||||
|
mock_emergency.check_allowed.side_effect = EmergencyStopError(
|
||||||
|
"Emergency stop active"
|
||||||
|
)
|
||||||
|
|
||||||
|
wrapper = MCPSafetyWrapper(
|
||||||
|
guardian=mock_guardian,
|
||||||
|
emergency_controls=mock_emergency,
|
||||||
|
)
|
||||||
|
|
||||||
|
call = MCPToolCall(
|
||||||
|
tool_name="file_write",
|
||||||
|
arguments={"path": "/test.txt"},
|
||||||
|
project_id="proj-1",
|
||||||
|
)
|
||||||
|
|
||||||
|
result = await wrapper.execute(call, "agent-1")
|
||||||
|
|
||||||
|
assert result.success is False
|
||||||
|
assert result.safety_decision == SafetyDecision.DENY
|
||||||
|
assert result.metadata.get("emergency_stop") is True
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_execute_bypass_safety(self, mock_guardian, mock_emergency):
|
||||||
|
"""Test executing with safety bypass."""
|
||||||
|
wrapper = MCPSafetyWrapper(
|
||||||
|
guardian=mock_guardian,
|
||||||
|
emergency_controls=mock_emergency,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def handler(data: str) -> str:
|
||||||
|
return f"Processed: {data}"
|
||||||
|
|
||||||
|
wrapper.register_tool_handler("custom_tool", handler)
|
||||||
|
|
||||||
|
call = MCPToolCall(
|
||||||
|
tool_name="custom_tool",
|
||||||
|
arguments={"data": "test"},
|
||||||
|
)
|
||||||
|
|
||||||
|
result = await wrapper.execute(call, "agent-1", bypass_safety=True)
|
||||||
|
|
||||||
|
assert result.success is True
|
||||||
|
assert result.result == "Processed: test"
|
||||||
|
# Guardian should not be called when bypassing
|
||||||
|
mock_guardian.validate.assert_not_called()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_execute_no_handler(self, mock_guardian, mock_emergency):
|
||||||
|
"""Test executing a tool with no registered handler."""
|
||||||
|
mock_guardian.validate.return_value = MagicMock(
|
||||||
|
decision=SafetyDecision.ALLOW,
|
||||||
|
reasons=[],
|
||||||
|
approval_id=None,
|
||||||
|
checkpoint_id=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
wrapper = MCPSafetyWrapper(
|
||||||
|
guardian=mock_guardian,
|
||||||
|
emergency_controls=mock_emergency,
|
||||||
|
)
|
||||||
|
|
||||||
|
call = MCPToolCall(
|
||||||
|
tool_name="unregistered_tool",
|
||||||
|
arguments={},
|
||||||
|
)
|
||||||
|
|
||||||
|
result = await wrapper.execute(call, "agent-1")
|
||||||
|
|
||||||
|
assert result.success is False
|
||||||
|
assert "No handler registered" in result.error
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_execute_handler_exception(self, mock_guardian, mock_emergency):
|
||||||
|
"""Test handling exceptions from tool handler."""
|
||||||
|
mock_guardian.validate.return_value = MagicMock(
|
||||||
|
decision=SafetyDecision.ALLOW,
|
||||||
|
reasons=[],
|
||||||
|
approval_id=None,
|
||||||
|
checkpoint_id=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
wrapper = MCPSafetyWrapper(
|
||||||
|
guardian=mock_guardian,
|
||||||
|
emergency_controls=mock_emergency,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def failing_handler() -> None:
|
||||||
|
raise ValueError("Handler failed!")
|
||||||
|
|
||||||
|
wrapper.register_tool_handler("failing_tool", failing_handler)
|
||||||
|
|
||||||
|
call = MCPToolCall(
|
||||||
|
tool_name="failing_tool",
|
||||||
|
arguments={},
|
||||||
|
)
|
||||||
|
|
||||||
|
result = await wrapper.execute(call, "agent-1")
|
||||||
|
|
||||||
|
assert result.success is False
|
||||||
|
assert "Handler failed!" in result.error
|
||||||
|
# Decision is still ALLOW because the safety check passed
|
||||||
|
assert result.safety_decision == SafetyDecision.ALLOW
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_execute_sync_handler(self, mock_guardian, mock_emergency):
|
||||||
|
"""Test executing a synchronous handler."""
|
||||||
|
mock_guardian.validate.return_value = MagicMock(
|
||||||
|
decision=SafetyDecision.ALLOW,
|
||||||
|
reasons=[],
|
||||||
|
approval_id=None,
|
||||||
|
checkpoint_id=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
wrapper = MCPSafetyWrapper(
|
||||||
|
guardian=mock_guardian,
|
||||||
|
emergency_controls=mock_emergency,
|
||||||
|
)
|
||||||
|
|
||||||
|
def sync_handler(value: int) -> int:
|
||||||
|
return value * 2
|
||||||
|
|
||||||
|
wrapper.register_tool_handler("sync_tool", sync_handler)
|
||||||
|
|
||||||
|
call = MCPToolCall(
|
||||||
|
tool_name="sync_tool",
|
||||||
|
arguments={"value": 21},
|
||||||
|
)
|
||||||
|
|
||||||
|
result = await wrapper.execute(call, "agent-1")
|
||||||
|
|
||||||
|
assert result.success is True
|
||||||
|
assert result.result == 42
|
||||||
|
|
||||||
|
|
||||||
|
class TestBuildActionRequest:
|
||||||
|
"""Tests for _build_action_request."""
|
||||||
|
|
||||||
|
def test_build_action_request_basic(self):
|
||||||
|
"""Test building a basic action request."""
|
||||||
|
wrapper = MCPSafetyWrapper()
|
||||||
|
|
||||||
|
call = MCPToolCall(
|
||||||
|
tool_name="file_read",
|
||||||
|
arguments={"path": "/test.txt"},
|
||||||
|
project_id="proj-1",
|
||||||
|
)
|
||||||
|
|
||||||
|
action = wrapper._build_action_request(call, "agent-1", AutonomyLevel.MILESTONE)
|
||||||
|
|
||||||
|
assert action.action_type == ActionType.FILE_READ
|
||||||
|
assert action.tool_name == "file_read"
|
||||||
|
assert action.arguments == {"path": "/test.txt"}
|
||||||
|
assert action.resource == "/test.txt"
|
||||||
|
assert action.metadata.agent_id == "agent-1"
|
||||||
|
assert action.metadata.project_id == "proj-1"
|
||||||
|
assert action.metadata.autonomy_level == AutonomyLevel.MILESTONE
|
||||||
|
|
||||||
|
def test_build_action_request_with_context(self):
|
||||||
|
"""Test building action request with session context."""
|
||||||
|
wrapper = MCPSafetyWrapper()
|
||||||
|
|
||||||
|
call = MCPToolCall(
|
||||||
|
tool_name="database_query",
|
||||||
|
arguments={"resource": "users", "query": "SELECT *"},
|
||||||
|
context={"session_id": "sess-123"},
|
||||||
|
project_id="proj-2",
|
||||||
|
)
|
||||||
|
|
||||||
|
action = wrapper._build_action_request(
|
||||||
|
call, "agent-2", AutonomyLevel.AUTONOMOUS
|
||||||
|
)
|
||||||
|
|
||||||
|
assert action.resource == "users"
|
||||||
|
assert action.metadata.session_id == "sess-123"
|
||||||
|
assert action.metadata.autonomy_level == AutonomyLevel.AUTONOMOUS
|
||||||
|
|
||||||
|
def test_build_action_request_no_resource(self):
|
||||||
|
"""Test building action request without resource."""
|
||||||
|
wrapper = MCPSafetyWrapper()
|
||||||
|
|
||||||
|
call = MCPToolCall(
|
||||||
|
tool_name="llm_generate",
|
||||||
|
arguments={"prompt": "Hello"},
|
||||||
|
)
|
||||||
|
|
||||||
|
action = wrapper._build_action_request(
|
||||||
|
call, "agent-1", AutonomyLevel.FULL_CONTROL
|
||||||
|
)
|
||||||
|
|
||||||
|
assert action.resource is None
|
||||||
|
|
||||||
|
|
||||||
|
class TestElapsedTime:
|
||||||
|
"""Tests for _elapsed_ms helper."""
|
||||||
|
|
||||||
|
def test_elapsed_ms(self):
|
||||||
|
"""Test calculating elapsed time."""
|
||||||
|
wrapper = MCPSafetyWrapper()
|
||||||
|
|
||||||
|
start = datetime.utcnow() - timedelta(milliseconds=100)
|
||||||
|
elapsed = wrapper._elapsed_ms(start)
|
||||||
|
|
||||||
|
# Should be at least 100ms, but allow some tolerance
|
||||||
|
assert elapsed >= 99
|
||||||
|
assert elapsed < 200
|
||||||
|
|
||||||
|
|
||||||
|
class TestSafeToolExecutor:
|
||||||
|
"""Tests for SafeToolExecutor context manager."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_executor_execute(self):
|
||||||
|
"""Test executing within context manager."""
|
||||||
|
mock_guardian = AsyncMock()
|
||||||
|
mock_guardian.validate.return_value = MagicMock(
|
||||||
|
decision=SafetyDecision.ALLOW,
|
||||||
|
reasons=[],
|
||||||
|
approval_id=None,
|
||||||
|
checkpoint_id=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
mock_emergency = AsyncMock()
|
||||||
|
|
||||||
|
wrapper = MCPSafetyWrapper(
|
||||||
|
guardian=mock_guardian,
|
||||||
|
emergency_controls=mock_emergency,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def handler() -> str:
|
||||||
|
return "success"
|
||||||
|
|
||||||
|
wrapper.register_tool_handler("test_tool", handler)
|
||||||
|
|
||||||
|
call = MCPToolCall(tool_name="test_tool", arguments={})
|
||||||
|
|
||||||
|
async with SafeToolExecutor(wrapper, call, "agent-1") as executor:
|
||||||
|
result = await executor.execute()
|
||||||
|
|
||||||
|
assert result.success is True
|
||||||
|
assert result.result == "success"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_executor_result_property(self):
|
||||||
|
"""Test accessing result via property."""
|
||||||
|
mock_guardian = AsyncMock()
|
||||||
|
mock_guardian.validate.return_value = MagicMock(
|
||||||
|
decision=SafetyDecision.ALLOW,
|
||||||
|
reasons=[],
|
||||||
|
approval_id=None,
|
||||||
|
checkpoint_id=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
mock_emergency = AsyncMock()
|
||||||
|
|
||||||
|
wrapper = MCPSafetyWrapper(
|
||||||
|
guardian=mock_guardian,
|
||||||
|
emergency_controls=mock_emergency,
|
||||||
|
)
|
||||||
|
|
||||||
|
wrapper.register_tool_handler("tool", lambda: "data")
|
||||||
|
|
||||||
|
call = MCPToolCall(tool_name="tool", arguments={})
|
||||||
|
executor = SafeToolExecutor(wrapper, call, "agent-1")
|
||||||
|
|
||||||
|
# Before execution
|
||||||
|
assert executor.result is None
|
||||||
|
|
||||||
|
async with executor:
|
||||||
|
await executor.execute()
|
||||||
|
|
||||||
|
# After execution
|
||||||
|
assert executor.result is not None
|
||||||
|
assert executor.result.success is True
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_executor_with_autonomy_level(self):
|
||||||
|
"""Test executor with custom autonomy level."""
|
||||||
|
mock_guardian = AsyncMock()
|
||||||
|
mock_guardian.validate.return_value = MagicMock(
|
||||||
|
decision=SafetyDecision.ALLOW,
|
||||||
|
reasons=[],
|
||||||
|
approval_id=None,
|
||||||
|
checkpoint_id=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
mock_emergency = AsyncMock()
|
||||||
|
|
||||||
|
wrapper = MCPSafetyWrapper(
|
||||||
|
guardian=mock_guardian,
|
||||||
|
emergency_controls=mock_emergency,
|
||||||
|
)
|
||||||
|
|
||||||
|
wrapper.register_tool_handler("tool", lambda: None)
|
||||||
|
|
||||||
|
call = MCPToolCall(tool_name="tool", arguments={})
|
||||||
|
|
||||||
|
async with SafeToolExecutor(
|
||||||
|
wrapper, call, "agent-1", AutonomyLevel.AUTONOMOUS
|
||||||
|
) as executor:
|
||||||
|
await executor.execute()
|
||||||
|
|
||||||
|
# Check that guardian was called with correct autonomy level
|
||||||
|
mock_guardian.validate.assert_called_once()
|
||||||
|
action = mock_guardian.validate.call_args[0][0]
|
||||||
|
assert action.metadata.autonomy_level == AutonomyLevel.AUTONOMOUS
|
||||||
|
|
||||||
|
|
||||||
|
class TestCreateMCPWrapper:
|
||||||
|
"""Tests for create_mcp_wrapper factory function."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_create_wrapper_with_guardian(self):
|
||||||
|
"""Test creating wrapper with provided guardian."""
|
||||||
|
mock_guardian = AsyncMock()
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"app.services.safety.mcp.integration.get_emergency_controls"
|
||||||
|
) as mock_get_emergency:
|
||||||
|
mock_get_emergency.return_value = AsyncMock()
|
||||||
|
|
||||||
|
wrapper = await create_mcp_wrapper(guardian=mock_guardian)
|
||||||
|
|
||||||
|
assert wrapper._guardian is mock_guardian
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_create_wrapper_default_guardian(self):
|
||||||
|
"""Test creating wrapper with default guardian."""
|
||||||
|
with (
|
||||||
|
patch(
|
||||||
|
"app.services.safety.mcp.integration.get_safety_guardian"
|
||||||
|
) as mock_get_guardian,
|
||||||
|
patch(
|
||||||
|
"app.services.safety.mcp.integration.get_emergency_controls"
|
||||||
|
) as mock_get_emergency,
|
||||||
|
):
|
||||||
|
mock_guardian = AsyncMock()
|
||||||
|
mock_get_guardian.return_value = mock_guardian
|
||||||
|
mock_get_emergency.return_value = AsyncMock()
|
||||||
|
|
||||||
|
wrapper = await create_mcp_wrapper()
|
||||||
|
|
||||||
|
assert wrapper._guardian is mock_guardian
|
||||||
|
mock_get_guardian.assert_called_once()
|
||||||
|
|
||||||
|
|
||||||
|
class TestLazyGetters:
|
||||||
|
"""Tests for lazy getter methods."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_guardian_lazy(self):
|
||||||
|
"""Test lazy guardian initialization."""
|
||||||
|
wrapper = MCPSafetyWrapper()
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"app.services.safety.mcp.integration.get_safety_guardian"
|
||||||
|
) as mock_get:
|
||||||
|
mock_guardian = AsyncMock()
|
||||||
|
mock_get.return_value = mock_guardian
|
||||||
|
|
||||||
|
guardian = await wrapper._get_guardian()
|
||||||
|
|
||||||
|
assert guardian is mock_guardian
|
||||||
|
mock_get.assert_called_once()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_guardian_cached(self):
|
||||||
|
"""Test guardian is cached after first access."""
|
||||||
|
mock_guardian = AsyncMock()
|
||||||
|
wrapper = MCPSafetyWrapper(guardian=mock_guardian)
|
||||||
|
|
||||||
|
guardian = await wrapper._get_guardian()
|
||||||
|
|
||||||
|
assert guardian is mock_guardian
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_emergency_controls_lazy(self):
|
||||||
|
"""Test lazy emergency controls initialization."""
|
||||||
|
wrapper = MCPSafetyWrapper()
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"app.services.safety.mcp.integration.get_emergency_controls"
|
||||||
|
) as mock_get:
|
||||||
|
mock_emergency = AsyncMock()
|
||||||
|
mock_get.return_value = mock_emergency
|
||||||
|
|
||||||
|
emergency = await wrapper._get_emergency_controls()
|
||||||
|
|
||||||
|
assert emergency is mock_emergency
|
||||||
|
mock_get.assert_called_once()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_emergency_controls_cached(self):
|
||||||
|
"""Test emergency controls is cached after first access."""
|
||||||
|
mock_emergency = AsyncMock()
|
||||||
|
wrapper = MCPSafetyWrapper(emergency_controls=mock_emergency)
|
||||||
|
|
||||||
|
emergency = await wrapper._get_emergency_controls()
|
||||||
|
|
||||||
|
assert emergency is mock_emergency
|
||||||
|
|
||||||
|
|
||||||
|
class TestEdgeCases:
|
||||||
|
"""Tests for edge cases and error handling."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_execute_with_safety_error(self):
|
||||||
|
"""Test handling SafetyError from guardian."""
|
||||||
|
from app.services.safety.exceptions import SafetyError
|
||||||
|
|
||||||
|
mock_guardian = AsyncMock()
|
||||||
|
mock_guardian.validate.side_effect = SafetyError("Internal safety error")
|
||||||
|
|
||||||
|
mock_emergency = AsyncMock()
|
||||||
|
|
||||||
|
wrapper = MCPSafetyWrapper(
|
||||||
|
guardian=mock_guardian,
|
||||||
|
emergency_controls=mock_emergency,
|
||||||
|
)
|
||||||
|
|
||||||
|
call = MCPToolCall(tool_name="test", arguments={})
|
||||||
|
|
||||||
|
result = await wrapper.execute(call, "agent-1")
|
||||||
|
|
||||||
|
assert result.success is False
|
||||||
|
assert "Internal safety error" in result.error
|
||||||
|
assert result.safety_decision == SafetyDecision.DENY
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_execute_with_checkpoint_id(self):
|
||||||
|
"""Test that checkpoint_id is propagated to result."""
|
||||||
|
mock_guardian = AsyncMock()
|
||||||
|
mock_guardian.validate.return_value = MagicMock(
|
||||||
|
decision=SafetyDecision.ALLOW,
|
||||||
|
reasons=[],
|
||||||
|
approval_id=None,
|
||||||
|
checkpoint_id="checkpoint-abc",
|
||||||
|
)
|
||||||
|
|
||||||
|
mock_emergency = AsyncMock()
|
||||||
|
|
||||||
|
wrapper = MCPSafetyWrapper(
|
||||||
|
guardian=mock_guardian,
|
||||||
|
emergency_controls=mock_emergency,
|
||||||
|
)
|
||||||
|
|
||||||
|
wrapper.register_tool_handler("tool", lambda: "result")
|
||||||
|
|
||||||
|
call = MCPToolCall(tool_name="tool", arguments={})
|
||||||
|
|
||||||
|
result = await wrapper.execute(call, "agent-1")
|
||||||
|
|
||||||
|
assert result.success is True
|
||||||
|
assert result.checkpoint_id == "checkpoint-abc"
|
||||||
|
|
||||||
|
def test_destructive_tools_constant(self):
|
||||||
|
"""Test DESTRUCTIVE_TOOLS class constant."""
|
||||||
|
assert "file_write" in MCPSafetyWrapper.DESTRUCTIVE_TOOLS
|
||||||
|
assert "file_delete" in MCPSafetyWrapper.DESTRUCTIVE_TOOLS
|
||||||
|
assert "shell_execute" in MCPSafetyWrapper.DESTRUCTIVE_TOOLS
|
||||||
|
assert "git_push" in MCPSafetyWrapper.DESTRUCTIVE_TOOLS
|
||||||
|
|
||||||
|
def test_read_only_tools_constant(self):
|
||||||
|
"""Test READ_ONLY_TOOLS class constant."""
|
||||||
|
assert "file_read" in MCPSafetyWrapper.READ_ONLY_TOOLS
|
||||||
|
assert "database_query" in MCPSafetyWrapper.READ_ONLY_TOOLS
|
||||||
|
assert "git_status" in MCPSafetyWrapper.READ_ONLY_TOOLS
|
||||||
|
assert "search" in MCPSafetyWrapper.READ_ONLY_TOOLS
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_scope_with_project_id(self):
|
||||||
|
"""Test that scope is set correctly with project_id."""
|
||||||
|
mock_guardian = AsyncMock()
|
||||||
|
mock_guardian.validate.return_value = MagicMock(
|
||||||
|
decision=SafetyDecision.ALLOW,
|
||||||
|
reasons=[],
|
||||||
|
approval_id=None,
|
||||||
|
checkpoint_id=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
mock_emergency = AsyncMock()
|
||||||
|
|
||||||
|
wrapper = MCPSafetyWrapper(
|
||||||
|
guardian=mock_guardian,
|
||||||
|
emergency_controls=mock_emergency,
|
||||||
|
)
|
||||||
|
|
||||||
|
wrapper.register_tool_handler("tool", lambda: None)
|
||||||
|
|
||||||
|
call = MCPToolCall(
|
||||||
|
tool_name="tool",
|
||||||
|
arguments={},
|
||||||
|
project_id="proj-123",
|
||||||
|
)
|
||||||
|
|
||||||
|
await wrapper.execute(call, "agent-1")
|
||||||
|
|
||||||
|
# Verify emergency check was called with project scope
|
||||||
|
mock_emergency.check_allowed.assert_called_once()
|
||||||
|
call_kwargs = mock_emergency.check_allowed.call_args
|
||||||
|
assert "project:proj-123" in str(call_kwargs)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_scope_without_project_id(self):
|
||||||
|
"""Test that scope falls back to agent when no project_id."""
|
||||||
|
mock_guardian = AsyncMock()
|
||||||
|
mock_guardian.validate.return_value = MagicMock(
|
||||||
|
decision=SafetyDecision.ALLOW,
|
||||||
|
reasons=[],
|
||||||
|
approval_id=None,
|
||||||
|
checkpoint_id=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
mock_emergency = AsyncMock()
|
||||||
|
|
||||||
|
wrapper = MCPSafetyWrapper(
|
||||||
|
guardian=mock_guardian,
|
||||||
|
emergency_controls=mock_emergency,
|
||||||
|
)
|
||||||
|
|
||||||
|
wrapper.register_tool_handler("tool", lambda: None)
|
||||||
|
|
||||||
|
call = MCPToolCall(
|
||||||
|
tool_name="tool",
|
||||||
|
arguments={},
|
||||||
|
# No project_id
|
||||||
|
)
|
||||||
|
|
||||||
|
await wrapper.execute(call, "agent-555")
|
||||||
|
|
||||||
|
# Verify emergency check was called with agent scope
|
||||||
|
mock_emergency.check_allowed.assert_called_once()
|
||||||
|
call_kwargs = mock_emergency.check_allowed.call_args
|
||||||
|
assert "agent:agent-555" in str(call_kwargs)
|
||||||
747
backend/tests/services/safety/test_metrics.py
Normal file
747
backend/tests/services/safety/test_metrics.py
Normal file
@@ -0,0 +1,747 @@
|
|||||||
|
"""
|
||||||
|
Tests for Safety Metrics Collector.
|
||||||
|
|
||||||
|
Tests cover:
|
||||||
|
- MetricType, MetricValue, HistogramBucket data structures
|
||||||
|
- SafetyMetrics counters, gauges, histograms
|
||||||
|
- Prometheus format export
|
||||||
|
- Summary and reset operations
|
||||||
|
- Singleton pattern and convenience functions
|
||||||
|
"""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import pytest_asyncio
|
||||||
|
|
||||||
|
from app.services.safety.metrics.collector import (
|
||||||
|
HistogramBucket,
|
||||||
|
MetricType,
|
||||||
|
MetricValue,
|
||||||
|
SafetyMetrics,
|
||||||
|
get_safety_metrics,
|
||||||
|
record_mcp_call,
|
||||||
|
record_validation,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestMetricType:
|
||||||
|
"""Tests for MetricType enum."""
|
||||||
|
|
||||||
|
def test_metric_types_exist(self):
|
||||||
|
"""Test all metric types are defined."""
|
||||||
|
assert MetricType.COUNTER == "counter"
|
||||||
|
assert MetricType.GAUGE == "gauge"
|
||||||
|
assert MetricType.HISTOGRAM == "histogram"
|
||||||
|
|
||||||
|
def test_metric_type_is_string(self):
|
||||||
|
"""Test MetricType values are strings."""
|
||||||
|
assert isinstance(MetricType.COUNTER.value, str)
|
||||||
|
assert isinstance(MetricType.GAUGE.value, str)
|
||||||
|
assert isinstance(MetricType.HISTOGRAM.value, str)
|
||||||
|
|
||||||
|
|
||||||
|
class TestMetricValue:
|
||||||
|
"""Tests for MetricValue dataclass."""
|
||||||
|
|
||||||
|
def test_metric_value_creation(self):
|
||||||
|
"""Test creating a metric value."""
|
||||||
|
mv = MetricValue(
|
||||||
|
name="test_metric",
|
||||||
|
metric_type=MetricType.COUNTER,
|
||||||
|
value=42.0,
|
||||||
|
labels={"env": "test"},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert mv.name == "test_metric"
|
||||||
|
assert mv.metric_type == MetricType.COUNTER
|
||||||
|
assert mv.value == 42.0
|
||||||
|
assert mv.labels == {"env": "test"}
|
||||||
|
assert mv.timestamp is not None
|
||||||
|
|
||||||
|
def test_metric_value_defaults(self):
|
||||||
|
"""Test metric value default values."""
|
||||||
|
mv = MetricValue(
|
||||||
|
name="test",
|
||||||
|
metric_type=MetricType.GAUGE,
|
||||||
|
value=0.0,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert mv.labels == {}
|
||||||
|
assert mv.timestamp is not None
|
||||||
|
|
||||||
|
|
||||||
|
class TestHistogramBucket:
|
||||||
|
"""Tests for HistogramBucket dataclass."""
|
||||||
|
|
||||||
|
def test_histogram_bucket_creation(self):
|
||||||
|
"""Test creating a histogram bucket."""
|
||||||
|
bucket = HistogramBucket(le=0.5, count=10)
|
||||||
|
|
||||||
|
assert bucket.le == 0.5
|
||||||
|
assert bucket.count == 10
|
||||||
|
|
||||||
|
def test_histogram_bucket_defaults(self):
|
||||||
|
"""Test histogram bucket default count."""
|
||||||
|
bucket = HistogramBucket(le=1.0)
|
||||||
|
|
||||||
|
assert bucket.le == 1.0
|
||||||
|
assert bucket.count == 0
|
||||||
|
|
||||||
|
def test_histogram_bucket_infinity(self):
|
||||||
|
"""Test histogram bucket with infinity."""
|
||||||
|
bucket = HistogramBucket(le=float("inf"))
|
||||||
|
|
||||||
|
assert bucket.le == float("inf")
|
||||||
|
|
||||||
|
|
||||||
|
class TestSafetyMetricsCounters:
|
||||||
|
"""Tests for SafetyMetrics counter methods."""
|
||||||
|
|
||||||
|
@pytest_asyncio.fixture
|
||||||
|
async def metrics(self):
|
||||||
|
"""Create fresh metrics instance."""
|
||||||
|
return SafetyMetrics()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_inc_validations(self, metrics):
|
||||||
|
"""Test incrementing validations counter."""
|
||||||
|
await metrics.inc_validations("allow")
|
||||||
|
await metrics.inc_validations("allow")
|
||||||
|
await metrics.inc_validations("deny", agent_id="agent-1")
|
||||||
|
|
||||||
|
summary = await metrics.get_summary()
|
||||||
|
assert summary["total_validations"] == 3
|
||||||
|
assert summary["denied_validations"] == 1
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_inc_approvals_requested(self, metrics):
|
||||||
|
"""Test incrementing approval requests counter."""
|
||||||
|
await metrics.inc_approvals_requested("normal")
|
||||||
|
await metrics.inc_approvals_requested("urgent")
|
||||||
|
await metrics.inc_approvals_requested() # default
|
||||||
|
|
||||||
|
summary = await metrics.get_summary()
|
||||||
|
assert summary["approval_requests"] == 3
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_inc_approvals_granted(self, metrics):
|
||||||
|
"""Test incrementing approvals granted counter."""
|
||||||
|
await metrics.inc_approvals_granted()
|
||||||
|
await metrics.inc_approvals_granted()
|
||||||
|
|
||||||
|
summary = await metrics.get_summary()
|
||||||
|
assert summary["approvals_granted"] == 2
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_inc_approvals_denied(self, metrics):
|
||||||
|
"""Test incrementing approvals denied counter."""
|
||||||
|
await metrics.inc_approvals_denied("timeout")
|
||||||
|
await metrics.inc_approvals_denied("policy")
|
||||||
|
await metrics.inc_approvals_denied() # default manual
|
||||||
|
|
||||||
|
summary = await metrics.get_summary()
|
||||||
|
assert summary["approvals_denied"] == 3
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_inc_rate_limit_exceeded(self, metrics):
|
||||||
|
"""Test incrementing rate limit exceeded counter."""
|
||||||
|
await metrics.inc_rate_limit_exceeded("requests_per_minute")
|
||||||
|
await metrics.inc_rate_limit_exceeded("tokens_per_hour")
|
||||||
|
|
||||||
|
summary = await metrics.get_summary()
|
||||||
|
assert summary["rate_limit_hits"] == 2
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_inc_budget_exceeded(self, metrics):
|
||||||
|
"""Test incrementing budget exceeded counter."""
|
||||||
|
await metrics.inc_budget_exceeded("daily_cost")
|
||||||
|
await metrics.inc_budget_exceeded("monthly_tokens")
|
||||||
|
|
||||||
|
summary = await metrics.get_summary()
|
||||||
|
assert summary["budget_exceeded"] == 2
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_inc_loops_detected(self, metrics):
|
||||||
|
"""Test incrementing loops detected counter."""
|
||||||
|
await metrics.inc_loops_detected("repetition")
|
||||||
|
await metrics.inc_loops_detected("pattern")
|
||||||
|
|
||||||
|
summary = await metrics.get_summary()
|
||||||
|
assert summary["loops_detected"] == 2
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_inc_emergency_events(self, metrics):
|
||||||
|
"""Test incrementing emergency events counter."""
|
||||||
|
await metrics.inc_emergency_events("pause", "project-1")
|
||||||
|
await metrics.inc_emergency_events("stop", "agent-2")
|
||||||
|
|
||||||
|
summary = await metrics.get_summary()
|
||||||
|
assert summary["emergency_events"] == 2
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_inc_content_filtered(self, metrics):
|
||||||
|
"""Test incrementing content filtered counter."""
|
||||||
|
await metrics.inc_content_filtered("profanity", "blocked")
|
||||||
|
await metrics.inc_content_filtered("pii", "redacted")
|
||||||
|
|
||||||
|
summary = await metrics.get_summary()
|
||||||
|
assert summary["content_filtered"] == 2
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_inc_checkpoints_created(self, metrics):
|
||||||
|
"""Test incrementing checkpoints created counter."""
|
||||||
|
await metrics.inc_checkpoints_created()
|
||||||
|
await metrics.inc_checkpoints_created()
|
||||||
|
await metrics.inc_checkpoints_created()
|
||||||
|
|
||||||
|
summary = await metrics.get_summary()
|
||||||
|
assert summary["checkpoints_created"] == 3
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_inc_rollbacks_executed(self, metrics):
|
||||||
|
"""Test incrementing rollbacks executed counter."""
|
||||||
|
await metrics.inc_rollbacks_executed(success=True)
|
||||||
|
await metrics.inc_rollbacks_executed(success=False)
|
||||||
|
|
||||||
|
summary = await metrics.get_summary()
|
||||||
|
assert summary["rollbacks_executed"] == 2
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_inc_mcp_calls(self, metrics):
|
||||||
|
"""Test incrementing MCP calls counter."""
|
||||||
|
await metrics.inc_mcp_calls("search_knowledge", success=True)
|
||||||
|
await metrics.inc_mcp_calls("run_code", success=False)
|
||||||
|
|
||||||
|
summary = await metrics.get_summary()
|
||||||
|
assert summary["mcp_calls"] == 2
|
||||||
|
|
||||||
|
|
||||||
|
class TestSafetyMetricsGauges:
|
||||||
|
"""Tests for SafetyMetrics gauge methods."""
|
||||||
|
|
||||||
|
@pytest_asyncio.fixture
|
||||||
|
async def metrics(self):
|
||||||
|
"""Create fresh metrics instance."""
|
||||||
|
return SafetyMetrics()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_set_budget_remaining(self, metrics):
|
||||||
|
"""Test setting budget remaining gauge."""
|
||||||
|
await metrics.set_budget_remaining("project-1", "daily_cost", 50.0)
|
||||||
|
|
||||||
|
all_metrics = await metrics.get_all_metrics()
|
||||||
|
gauge_metrics = [m for m in all_metrics if m.name == "safety_budget_remaining"]
|
||||||
|
assert len(gauge_metrics) == 1
|
||||||
|
assert gauge_metrics[0].value == 50.0
|
||||||
|
assert gauge_metrics[0].labels["scope"] == "project-1"
|
||||||
|
assert gauge_metrics[0].labels["budget_type"] == "daily_cost"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_set_rate_limit_remaining(self, metrics):
|
||||||
|
"""Test setting rate limit remaining gauge."""
|
||||||
|
await metrics.set_rate_limit_remaining("agent-1", "requests_per_minute", 45)
|
||||||
|
|
||||||
|
all_metrics = await metrics.get_all_metrics()
|
||||||
|
gauge_metrics = [
|
||||||
|
m for m in all_metrics if m.name == "safety_rate_limit_remaining"
|
||||||
|
]
|
||||||
|
assert len(gauge_metrics) == 1
|
||||||
|
assert gauge_metrics[0].value == 45.0
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_set_pending_approvals(self, metrics):
|
||||||
|
"""Test setting pending approvals gauge."""
|
||||||
|
await metrics.set_pending_approvals(5)
|
||||||
|
|
||||||
|
summary = await metrics.get_summary()
|
||||||
|
assert summary["pending_approvals"] == 5
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_set_active_checkpoints(self, metrics):
|
||||||
|
"""Test setting active checkpoints gauge."""
|
||||||
|
await metrics.set_active_checkpoints(3)
|
||||||
|
|
||||||
|
summary = await metrics.get_summary()
|
||||||
|
assert summary["active_checkpoints"] == 3
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_set_emergency_state(self, metrics):
|
||||||
|
"""Test setting emergency state gauge."""
|
||||||
|
await metrics.set_emergency_state("project-1", "normal")
|
||||||
|
await metrics.set_emergency_state("project-2", "paused")
|
||||||
|
await metrics.set_emergency_state("project-3", "stopped")
|
||||||
|
await metrics.set_emergency_state("project-4", "unknown")
|
||||||
|
|
||||||
|
all_metrics = await metrics.get_all_metrics()
|
||||||
|
state_metrics = [m for m in all_metrics if m.name == "safety_emergency_state"]
|
||||||
|
assert len(state_metrics) == 4
|
||||||
|
|
||||||
|
# Check state values
|
||||||
|
values_by_scope = {m.labels["scope"]: m.value for m in state_metrics}
|
||||||
|
assert values_by_scope["project-1"] == 0.0 # normal
|
||||||
|
assert values_by_scope["project-2"] == 1.0 # paused
|
||||||
|
assert values_by_scope["project-3"] == 2.0 # stopped
|
||||||
|
assert values_by_scope["project-4"] == -1.0 # unknown
|
||||||
|
|
||||||
|
|
||||||
|
class TestSafetyMetricsHistograms:
|
||||||
|
"""Tests for SafetyMetrics histogram methods."""
|
||||||
|
|
||||||
|
@pytest_asyncio.fixture
|
||||||
|
async def metrics(self):
|
||||||
|
"""Create fresh metrics instance."""
|
||||||
|
return SafetyMetrics()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_observe_validation_latency(self, metrics):
|
||||||
|
"""Test observing validation latency."""
|
||||||
|
await metrics.observe_validation_latency(0.05)
|
||||||
|
await metrics.observe_validation_latency(0.15)
|
||||||
|
await metrics.observe_validation_latency(0.5)
|
||||||
|
|
||||||
|
all_metrics = await metrics.get_all_metrics()
|
||||||
|
|
||||||
|
count_metric = next(
|
||||||
|
(m for m in all_metrics if m.name == "validation_latency_seconds_count"),
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
assert count_metric is not None
|
||||||
|
assert count_metric.value == 3.0
|
||||||
|
|
||||||
|
sum_metric = next(
|
||||||
|
(m for m in all_metrics if m.name == "validation_latency_seconds_sum"),
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
assert sum_metric is not None
|
||||||
|
assert abs(sum_metric.value - 0.7) < 0.001
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_observe_approval_latency(self, metrics):
|
||||||
|
"""Test observing approval latency."""
|
||||||
|
await metrics.observe_approval_latency(1.5)
|
||||||
|
await metrics.observe_approval_latency(3.0)
|
||||||
|
|
||||||
|
all_metrics = await metrics.get_all_metrics()
|
||||||
|
|
||||||
|
count_metric = next(
|
||||||
|
(m for m in all_metrics if m.name == "approval_latency_seconds_count"),
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
assert count_metric is not None
|
||||||
|
assert count_metric.value == 2.0
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_observe_mcp_execution_latency(self, metrics):
|
||||||
|
"""Test observing MCP execution latency."""
|
||||||
|
await metrics.observe_mcp_execution_latency(0.02)
|
||||||
|
|
||||||
|
all_metrics = await metrics.get_all_metrics()
|
||||||
|
|
||||||
|
count_metric = next(
|
||||||
|
(m for m in all_metrics if m.name == "mcp_execution_latency_seconds_count"),
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
assert count_metric is not None
|
||||||
|
assert count_metric.value == 1.0
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_histogram_bucket_updates(self, metrics):
|
||||||
|
"""Test that histogram buckets are updated correctly."""
|
||||||
|
# Add values to test bucket distribution
|
||||||
|
await metrics.observe_validation_latency(0.005) # <= 0.01
|
||||||
|
await metrics.observe_validation_latency(0.03) # <= 0.05
|
||||||
|
await metrics.observe_validation_latency(0.07) # <= 0.1
|
||||||
|
await metrics.observe_validation_latency(15.0) # <= inf
|
||||||
|
|
||||||
|
prometheus = await metrics.get_prometheus_format()
|
||||||
|
|
||||||
|
# Check that bucket counts are in output
|
||||||
|
assert "validation_latency_seconds_bucket" in prometheus
|
||||||
|
assert "le=" in prometheus
|
||||||
|
|
||||||
|
|
||||||
|
class TestSafetyMetricsExport:
|
||||||
|
"""Tests for SafetyMetrics export methods."""
|
||||||
|
|
||||||
|
@pytest_asyncio.fixture
|
||||||
|
async def metrics(self):
|
||||||
|
"""Create fresh metrics instance with some data."""
|
||||||
|
m = SafetyMetrics()
|
||||||
|
|
||||||
|
# Add some counters
|
||||||
|
await m.inc_validations("allow")
|
||||||
|
await m.inc_validations("deny", agent_id="agent-1")
|
||||||
|
|
||||||
|
# Add some gauges
|
||||||
|
await m.set_pending_approvals(3)
|
||||||
|
await m.set_budget_remaining("proj-1", "daily", 100.0)
|
||||||
|
|
||||||
|
# Add some histogram values
|
||||||
|
await m.observe_validation_latency(0.1)
|
||||||
|
|
||||||
|
return m
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_all_metrics(self, metrics):
|
||||||
|
"""Test getting all metrics."""
|
||||||
|
all_metrics = await metrics.get_all_metrics()
|
||||||
|
|
||||||
|
assert len(all_metrics) > 0
|
||||||
|
assert all(isinstance(m, MetricValue) for m in all_metrics)
|
||||||
|
|
||||||
|
# Check we have different types
|
||||||
|
types = {m.metric_type for m in all_metrics}
|
||||||
|
assert MetricType.COUNTER in types
|
||||||
|
assert MetricType.GAUGE in types
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_prometheus_format(self, metrics):
|
||||||
|
"""Test Prometheus format export."""
|
||||||
|
output = await metrics.get_prometheus_format()
|
||||||
|
|
||||||
|
assert isinstance(output, str)
|
||||||
|
assert "# TYPE" in output
|
||||||
|
assert "counter" in output
|
||||||
|
assert "gauge" in output
|
||||||
|
assert "safety_validations_total" in output
|
||||||
|
assert "safety_pending_approvals" in output
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_prometheus_format_with_labels(self, metrics):
|
||||||
|
"""Test Prometheus format includes labels correctly."""
|
||||||
|
output = await metrics.get_prometheus_format()
|
||||||
|
|
||||||
|
# Counter with labels
|
||||||
|
assert "decision=allow" in output or "decision=deny" in output
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_prometheus_format_histogram_buckets(self, metrics):
|
||||||
|
"""Test Prometheus format includes histogram buckets."""
|
||||||
|
output = await metrics.get_prometheus_format()
|
||||||
|
|
||||||
|
assert "histogram" in output
|
||||||
|
assert "_bucket" in output
|
||||||
|
assert "le=" in output
|
||||||
|
assert "+Inf" in output
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_summary(self, metrics):
|
||||||
|
"""Test getting summary."""
|
||||||
|
summary = await metrics.get_summary()
|
||||||
|
|
||||||
|
assert "total_validations" in summary
|
||||||
|
assert "denied_validations" in summary
|
||||||
|
assert "approval_requests" in summary
|
||||||
|
assert "pending_approvals" in summary
|
||||||
|
assert "active_checkpoints" in summary
|
||||||
|
|
||||||
|
assert summary["total_validations"] == 2
|
||||||
|
assert summary["denied_validations"] == 1
|
||||||
|
assert summary["pending_approvals"] == 3
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_summary_empty_counters(self):
|
||||||
|
"""Test summary with no data."""
|
||||||
|
metrics = SafetyMetrics()
|
||||||
|
summary = await metrics.get_summary()
|
||||||
|
|
||||||
|
assert summary["total_validations"] == 0
|
||||||
|
assert summary["denied_validations"] == 0
|
||||||
|
assert summary["pending_approvals"] == 0
|
||||||
|
|
||||||
|
|
||||||
|
class TestSafetyMetricsReset:
|
||||||
|
"""Tests for SafetyMetrics reset."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_reset_clears_counters(self):
|
||||||
|
"""Test reset clears all counters."""
|
||||||
|
metrics = SafetyMetrics()
|
||||||
|
|
||||||
|
await metrics.inc_validations("allow")
|
||||||
|
await metrics.inc_approvals_granted()
|
||||||
|
await metrics.set_pending_approvals(5)
|
||||||
|
await metrics.observe_validation_latency(0.1)
|
||||||
|
|
||||||
|
await metrics.reset()
|
||||||
|
|
||||||
|
summary = await metrics.get_summary()
|
||||||
|
assert summary["total_validations"] == 0
|
||||||
|
assert summary["approvals_granted"] == 0
|
||||||
|
assert summary["pending_approvals"] == 0
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_reset_reinitializes_histogram_buckets(self):
|
||||||
|
"""Test reset reinitializes histogram buckets."""
|
||||||
|
metrics = SafetyMetrics()
|
||||||
|
|
||||||
|
await metrics.observe_validation_latency(0.1)
|
||||||
|
await metrics.reset()
|
||||||
|
|
||||||
|
# After reset, histogram buckets should be reinitialized
|
||||||
|
prometheus = await metrics.get_prometheus_format()
|
||||||
|
assert "validation_latency_seconds" in prometheus
|
||||||
|
|
||||||
|
|
||||||
|
class TestParseLabels:
|
||||||
|
"""Tests for _parse_labels helper method."""
|
||||||
|
|
||||||
|
def test_parse_empty_labels(self):
|
||||||
|
"""Test parsing empty labels string."""
|
||||||
|
metrics = SafetyMetrics()
|
||||||
|
result = metrics._parse_labels("")
|
||||||
|
assert result == {}
|
||||||
|
|
||||||
|
def test_parse_single_label(self):
|
||||||
|
"""Test parsing single label."""
|
||||||
|
metrics = SafetyMetrics()
|
||||||
|
result = metrics._parse_labels("key=value")
|
||||||
|
assert result == {"key": "value"}
|
||||||
|
|
||||||
|
def test_parse_multiple_labels(self):
|
||||||
|
"""Test parsing multiple labels."""
|
||||||
|
metrics = SafetyMetrics()
|
||||||
|
result = metrics._parse_labels("a=1,b=2,c=3")
|
||||||
|
assert result == {"a": "1", "b": "2", "c": "3"}
|
||||||
|
|
||||||
|
def test_parse_labels_with_spaces(self):
|
||||||
|
"""Test parsing labels with spaces."""
|
||||||
|
metrics = SafetyMetrics()
|
||||||
|
result = metrics._parse_labels(" key = value , foo = bar ")
|
||||||
|
assert result == {"key": "value", "foo": "bar"}
|
||||||
|
|
||||||
|
def test_parse_labels_with_equals_in_value(self):
|
||||||
|
"""Test parsing labels with = in value."""
|
||||||
|
metrics = SafetyMetrics()
|
||||||
|
result = metrics._parse_labels("query=a=b")
|
||||||
|
assert result == {"query": "a=b"}
|
||||||
|
|
||||||
|
def test_parse_invalid_label(self):
|
||||||
|
"""Test parsing invalid label without equals."""
|
||||||
|
metrics = SafetyMetrics()
|
||||||
|
result = metrics._parse_labels("no_equals")
|
||||||
|
assert result == {}
|
||||||
|
|
||||||
|
|
||||||
|
class TestHistogramBucketInit:
|
||||||
|
"""Tests for histogram bucket initialization."""
|
||||||
|
|
||||||
|
def test_histogram_buckets_initialized(self):
|
||||||
|
"""Test that histogram buckets are initialized."""
|
||||||
|
metrics = SafetyMetrics()
|
||||||
|
|
||||||
|
assert "validation_latency_seconds" in metrics._histogram_buckets
|
||||||
|
assert "approval_latency_seconds" in metrics._histogram_buckets
|
||||||
|
assert "mcp_execution_latency_seconds" in metrics._histogram_buckets
|
||||||
|
|
||||||
|
def test_histogram_buckets_have_correct_values(self):
|
||||||
|
"""Test histogram buckets have correct boundary values."""
|
||||||
|
metrics = SafetyMetrics()
|
||||||
|
|
||||||
|
buckets = metrics._histogram_buckets["validation_latency_seconds"]
|
||||||
|
|
||||||
|
# Check first few and last bucket
|
||||||
|
assert buckets[0].le == 0.01
|
||||||
|
assert buckets[1].le == 0.05
|
||||||
|
assert buckets[-1].le == float("inf")
|
||||||
|
|
||||||
|
# Check all have zero initial count
|
||||||
|
assert all(b.count == 0 for b in buckets)
|
||||||
|
|
||||||
|
|
||||||
|
class TestSingletonAndConvenience:
|
||||||
|
"""Tests for singleton pattern and convenience functions."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_safety_metrics_returns_same_instance(self):
|
||||||
|
"""Test get_safety_metrics returns singleton."""
|
||||||
|
# Reset the module-level singleton for this test
|
||||||
|
import app.services.safety.metrics.collector as collector_module
|
||||||
|
|
||||||
|
collector_module._metrics = None
|
||||||
|
|
||||||
|
m1 = await get_safety_metrics()
|
||||||
|
m2 = await get_safety_metrics()
|
||||||
|
|
||||||
|
assert m1 is m2
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_record_validation_convenience(self):
|
||||||
|
"""Test record_validation convenience function."""
|
||||||
|
import app.services.safety.metrics.collector as collector_module
|
||||||
|
|
||||||
|
collector_module._metrics = None # Reset
|
||||||
|
|
||||||
|
await record_validation("allow")
|
||||||
|
await record_validation("deny", agent_id="test-agent")
|
||||||
|
|
||||||
|
metrics = await get_safety_metrics()
|
||||||
|
summary = await metrics.get_summary()
|
||||||
|
|
||||||
|
assert summary["total_validations"] == 2
|
||||||
|
assert summary["denied_validations"] == 1
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_record_mcp_call_convenience(self):
|
||||||
|
"""Test record_mcp_call convenience function."""
|
||||||
|
import app.services.safety.metrics.collector as collector_module
|
||||||
|
|
||||||
|
collector_module._metrics = None # Reset
|
||||||
|
|
||||||
|
await record_mcp_call("search_knowledge", success=True, latency_ms=50)
|
||||||
|
await record_mcp_call("run_code", success=False, latency_ms=100)
|
||||||
|
|
||||||
|
metrics = await get_safety_metrics()
|
||||||
|
summary = await metrics.get_summary()
|
||||||
|
|
||||||
|
assert summary["mcp_calls"] == 2
|
||||||
|
|
||||||
|
|
||||||
|
class TestConcurrency:
|
||||||
|
"""Tests for concurrent metric updates."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_concurrent_counter_increments(self):
|
||||||
|
"""Test concurrent counter increments are safe."""
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
metrics = SafetyMetrics()
|
||||||
|
|
||||||
|
async def increment_many():
|
||||||
|
for _ in range(100):
|
||||||
|
await metrics.inc_validations("allow")
|
||||||
|
|
||||||
|
# Run 10 concurrent tasks each incrementing 100 times
|
||||||
|
await asyncio.gather(*[increment_many() for _ in range(10)])
|
||||||
|
|
||||||
|
summary = await metrics.get_summary()
|
||||||
|
assert summary["total_validations"] == 1000
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_concurrent_gauge_updates(self):
|
||||||
|
"""Test concurrent gauge updates are safe."""
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
metrics = SafetyMetrics()
|
||||||
|
|
||||||
|
async def update_gauge(value):
|
||||||
|
await metrics.set_pending_approvals(value)
|
||||||
|
|
||||||
|
# Run concurrent gauge updates
|
||||||
|
await asyncio.gather(*[update_gauge(i) for i in range(100)])
|
||||||
|
|
||||||
|
# Final value should be one of the updates (last one wins)
|
||||||
|
summary = await metrics.get_summary()
|
||||||
|
assert 0 <= summary["pending_approvals"] < 100
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_concurrent_histogram_observations(self):
|
||||||
|
"""Test concurrent histogram observations are safe."""
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
metrics = SafetyMetrics()
|
||||||
|
|
||||||
|
async def observe_many():
|
||||||
|
for i in range(100):
|
||||||
|
await metrics.observe_validation_latency(i / 1000)
|
||||||
|
|
||||||
|
await asyncio.gather(*[observe_many() for _ in range(10)])
|
||||||
|
|
||||||
|
all_metrics = await metrics.get_all_metrics()
|
||||||
|
count_metric = next(
|
||||||
|
(m for m in all_metrics if m.name == "validation_latency_seconds_count"),
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
assert count_metric is not None
|
||||||
|
assert count_metric.value == 1000.0
|
||||||
|
|
||||||
|
|
||||||
|
class TestEdgeCases:
|
||||||
|
"""Tests for edge cases."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_very_large_counter_value(self):
|
||||||
|
"""Test handling very large counter values."""
|
||||||
|
metrics = SafetyMetrics()
|
||||||
|
|
||||||
|
for _ in range(10000):
|
||||||
|
await metrics.inc_validations("allow")
|
||||||
|
|
||||||
|
summary = await metrics.get_summary()
|
||||||
|
assert summary["total_validations"] == 10000
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_zero_and_negative_gauge_values(self):
|
||||||
|
"""Test zero and negative gauge values."""
|
||||||
|
metrics = SafetyMetrics()
|
||||||
|
|
||||||
|
await metrics.set_budget_remaining("project", "cost", 0.0)
|
||||||
|
await metrics.set_budget_remaining("project2", "cost", -10.0)
|
||||||
|
|
||||||
|
all_metrics = await metrics.get_all_metrics()
|
||||||
|
gauges = [m for m in all_metrics if m.name == "safety_budget_remaining"]
|
||||||
|
|
||||||
|
values = {m.labels.get("scope"): m.value for m in gauges}
|
||||||
|
assert values["project"] == 0.0
|
||||||
|
assert values["project2"] == -10.0
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_very_small_histogram_values(self):
|
||||||
|
"""Test very small histogram values."""
|
||||||
|
metrics = SafetyMetrics()
|
||||||
|
|
||||||
|
await metrics.observe_validation_latency(0.0001) # 0.1ms
|
||||||
|
|
||||||
|
all_metrics = await metrics.get_all_metrics()
|
||||||
|
sum_metric = next(
|
||||||
|
(m for m in all_metrics if m.name == "validation_latency_seconds_sum"),
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
assert sum_metric is not None
|
||||||
|
assert abs(sum_metric.value - 0.0001) < 0.00001
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_special_characters_in_labels(self):
|
||||||
|
"""Test special characters in label values."""
|
||||||
|
metrics = SafetyMetrics()
|
||||||
|
|
||||||
|
await metrics.inc_validations("allow", agent_id="agent/with/slashes")
|
||||||
|
|
||||||
|
all_metrics = await metrics.get_all_metrics()
|
||||||
|
counters = [m for m in all_metrics if m.name == "safety_validations_total"]
|
||||||
|
|
||||||
|
# Should have the metric with special chars
|
||||||
|
assert len(counters) > 0
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_empty_histogram_export(self):
|
||||||
|
"""Test exporting histogram with no observations."""
|
||||||
|
metrics = SafetyMetrics()
|
||||||
|
|
||||||
|
# No observations, but histogram buckets should still exist
|
||||||
|
prometheus = await metrics.get_prometheus_format()
|
||||||
|
|
||||||
|
assert "validation_latency_seconds" in prometheus
|
||||||
|
assert "le=" in prometheus
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_prometheus_format_empty_label_value(self):
|
||||||
|
"""Test Prometheus format with empty label metrics."""
|
||||||
|
metrics = SafetyMetrics()
|
||||||
|
|
||||||
|
await metrics.inc_approvals_granted() # Uses empty string as label
|
||||||
|
|
||||||
|
prometheus = await metrics.get_prometheus_format()
|
||||||
|
assert "safety_approvals_granted_total" in prometheus
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_multiple_resets(self):
|
||||||
|
"""Test multiple resets don't cause issues."""
|
||||||
|
metrics = SafetyMetrics()
|
||||||
|
|
||||||
|
await metrics.inc_validations("allow")
|
||||||
|
await metrics.reset()
|
||||||
|
await metrics.reset()
|
||||||
|
await metrics.reset()
|
||||||
|
|
||||||
|
summary = await metrics.get_summary()
|
||||||
|
assert summary["total_validations"] == 0
|
||||||
933
backend/tests/services/safety/test_permissions.py
Normal file
933
backend/tests/services/safety/test_permissions.py
Normal file
@@ -0,0 +1,933 @@
|
|||||||
|
"""Tests for Permission Manager.
|
||||||
|
|
||||||
|
Tests cover:
|
||||||
|
- PermissionGrant: creation, expiry, matching, hierarchy
|
||||||
|
- PermissionManager: grant, revoke, check, require, list, defaults
|
||||||
|
- Edge cases: wildcards, expiration, default deny/allow
|
||||||
|
"""
|
||||||
|
|
||||||
|
from datetime import datetime, timedelta
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import pytest_asyncio
|
||||||
|
|
||||||
|
from app.services.safety.exceptions import PermissionDeniedError
|
||||||
|
from app.services.safety.models import (
|
||||||
|
ActionMetadata,
|
||||||
|
ActionRequest,
|
||||||
|
ActionType,
|
||||||
|
PermissionLevel,
|
||||||
|
ResourceType,
|
||||||
|
)
|
||||||
|
from app.services.safety.permissions.manager import PermissionGrant, PermissionManager
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# Fixtures
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def action_metadata() -> ActionMetadata:
|
||||||
|
"""Create standard action metadata for tests."""
|
||||||
|
return ActionMetadata(
|
||||||
|
agent_id="test-agent",
|
||||||
|
project_id="test-project",
|
||||||
|
session_id="test-session",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest_asyncio.fixture
|
||||||
|
async def permission_manager() -> PermissionManager:
|
||||||
|
"""Create a PermissionManager for testing."""
|
||||||
|
return PermissionManager(default_deny=True)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest_asyncio.fixture
|
||||||
|
async def permissive_manager() -> PermissionManager:
|
||||||
|
"""Create a PermissionManager with default_deny=False."""
|
||||||
|
return PermissionManager(default_deny=False)
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# PermissionGrant Tests
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
class TestPermissionGrant:
|
||||||
|
"""Tests for the PermissionGrant class."""
|
||||||
|
|
||||||
|
def test_grant_creation(self) -> None:
|
||||||
|
"""Test basic grant creation."""
|
||||||
|
grant = PermissionGrant(
|
||||||
|
agent_id="agent-1",
|
||||||
|
resource_pattern="/data/*",
|
||||||
|
resource_type=ResourceType.FILE,
|
||||||
|
level=PermissionLevel.READ,
|
||||||
|
granted_by="admin",
|
||||||
|
reason="Read access to data directory",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert grant.id is not None
|
||||||
|
assert grant.agent_id == "agent-1"
|
||||||
|
assert grant.resource_pattern == "/data/*"
|
||||||
|
assert grant.resource_type == ResourceType.FILE
|
||||||
|
assert grant.level == PermissionLevel.READ
|
||||||
|
assert grant.granted_by == "admin"
|
||||||
|
assert grant.reason == "Read access to data directory"
|
||||||
|
assert grant.expires_at is None
|
||||||
|
assert grant.created_at is not None
|
||||||
|
|
||||||
|
def test_grant_with_expiration(self) -> None:
|
||||||
|
"""Test grant with expiration time."""
|
||||||
|
future = datetime.utcnow() + timedelta(hours=1)
|
||||||
|
grant = PermissionGrant(
|
||||||
|
agent_id="agent-1",
|
||||||
|
resource_pattern="*",
|
||||||
|
resource_type=ResourceType.API,
|
||||||
|
level=PermissionLevel.EXECUTE,
|
||||||
|
expires_at=future,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert grant.expires_at == future
|
||||||
|
assert grant.is_expired() is False
|
||||||
|
|
||||||
|
def test_is_expired_no_expiration(self) -> None:
|
||||||
|
"""Test is_expired with no expiration set."""
|
||||||
|
grant = PermissionGrant(
|
||||||
|
agent_id="agent-1",
|
||||||
|
resource_pattern="*",
|
||||||
|
resource_type=ResourceType.FILE,
|
||||||
|
level=PermissionLevel.READ,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert grant.is_expired() is False
|
||||||
|
|
||||||
|
def test_is_expired_future(self) -> None:
|
||||||
|
"""Test is_expired with future expiration."""
|
||||||
|
grant = PermissionGrant(
|
||||||
|
agent_id="agent-1",
|
||||||
|
resource_pattern="*",
|
||||||
|
resource_type=ResourceType.FILE,
|
||||||
|
level=PermissionLevel.READ,
|
||||||
|
expires_at=datetime.utcnow() + timedelta(hours=1),
|
||||||
|
)
|
||||||
|
|
||||||
|
assert grant.is_expired() is False
|
||||||
|
|
||||||
|
def test_is_expired_past(self) -> None:
|
||||||
|
"""Test is_expired with past expiration."""
|
||||||
|
grant = PermissionGrant(
|
||||||
|
agent_id="agent-1",
|
||||||
|
resource_pattern="*",
|
||||||
|
resource_type=ResourceType.FILE,
|
||||||
|
level=PermissionLevel.READ,
|
||||||
|
expires_at=datetime.utcnow() - timedelta(hours=1),
|
||||||
|
)
|
||||||
|
|
||||||
|
assert grant.is_expired() is True
|
||||||
|
|
||||||
|
def test_matches_exact(self) -> None:
|
||||||
|
"""Test matching with exact pattern."""
|
||||||
|
grant = PermissionGrant(
|
||||||
|
agent_id="agent-1",
|
||||||
|
resource_pattern="/data/file.txt",
|
||||||
|
resource_type=ResourceType.FILE,
|
||||||
|
level=PermissionLevel.READ,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert grant.matches("/data/file.txt", ResourceType.FILE) is True
|
||||||
|
assert grant.matches("/data/other.txt", ResourceType.FILE) is False
|
||||||
|
|
||||||
|
def test_matches_wildcard(self) -> None:
|
||||||
|
"""Test matching with wildcard pattern."""
|
||||||
|
grant = PermissionGrant(
|
||||||
|
agent_id="agent-1",
|
||||||
|
resource_pattern="/data/*",
|
||||||
|
resource_type=ResourceType.FILE,
|
||||||
|
level=PermissionLevel.READ,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert grant.matches("/data/file.txt", ResourceType.FILE) is True
|
||||||
|
# fnmatch's * matches everything including /
|
||||||
|
assert grant.matches("/data/subdir/file.txt", ResourceType.FILE) is True
|
||||||
|
assert grant.matches("/other/file.txt", ResourceType.FILE) is False
|
||||||
|
|
||||||
|
def test_matches_recursive_wildcard(self) -> None:
|
||||||
|
"""Test matching with recursive pattern."""
|
||||||
|
grant = PermissionGrant(
|
||||||
|
agent_id="agent-1",
|
||||||
|
resource_pattern="/data/**",
|
||||||
|
resource_type=ResourceType.FILE,
|
||||||
|
level=PermissionLevel.READ,
|
||||||
|
)
|
||||||
|
|
||||||
|
# fnmatch treats ** similar to * - both match everything including /
|
||||||
|
assert grant.matches("/data/file.txt", ResourceType.FILE) is True
|
||||||
|
assert grant.matches("/data/subdir/file.txt", ResourceType.FILE) is True
|
||||||
|
|
||||||
|
def test_matches_wrong_resource_type(self) -> None:
|
||||||
|
"""Test matching fails with wrong resource type."""
|
||||||
|
grant = PermissionGrant(
|
||||||
|
agent_id="agent-1",
|
||||||
|
resource_pattern="/data/*",
|
||||||
|
resource_type=ResourceType.FILE,
|
||||||
|
level=PermissionLevel.READ,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Same pattern but different resource type
|
||||||
|
assert grant.matches("/data/table", ResourceType.DATABASE) is False
|
||||||
|
|
||||||
|
def test_allows_hierarchy(self) -> None:
|
||||||
|
"""Test permission level hierarchy."""
|
||||||
|
admin_grant = PermissionGrant(
|
||||||
|
agent_id="agent-1",
|
||||||
|
resource_pattern="*",
|
||||||
|
resource_type=ResourceType.FILE,
|
||||||
|
level=PermissionLevel.ADMIN,
|
||||||
|
)
|
||||||
|
|
||||||
|
# ADMIN allows all levels
|
||||||
|
assert admin_grant.allows(PermissionLevel.NONE) is True
|
||||||
|
assert admin_grant.allows(PermissionLevel.READ) is True
|
||||||
|
assert admin_grant.allows(PermissionLevel.WRITE) is True
|
||||||
|
assert admin_grant.allows(PermissionLevel.EXECUTE) is True
|
||||||
|
assert admin_grant.allows(PermissionLevel.DELETE) is True
|
||||||
|
assert admin_grant.allows(PermissionLevel.ADMIN) is True
|
||||||
|
|
||||||
|
def test_allows_read_only(self) -> None:
|
||||||
|
"""Test READ grant only allows READ and NONE."""
|
||||||
|
read_grant = PermissionGrant(
|
||||||
|
agent_id="agent-1",
|
||||||
|
resource_pattern="*",
|
||||||
|
resource_type=ResourceType.FILE,
|
||||||
|
level=PermissionLevel.READ,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert read_grant.allows(PermissionLevel.NONE) is True
|
||||||
|
assert read_grant.allows(PermissionLevel.READ) is True
|
||||||
|
assert read_grant.allows(PermissionLevel.WRITE) is False
|
||||||
|
assert read_grant.allows(PermissionLevel.EXECUTE) is False
|
||||||
|
assert read_grant.allows(PermissionLevel.DELETE) is False
|
||||||
|
assert read_grant.allows(PermissionLevel.ADMIN) is False
|
||||||
|
|
||||||
|
def test_allows_write_includes_read(self) -> None:
|
||||||
|
"""Test WRITE grant includes READ."""
|
||||||
|
write_grant = PermissionGrant(
|
||||||
|
agent_id="agent-1",
|
||||||
|
resource_pattern="*",
|
||||||
|
resource_type=ResourceType.FILE,
|
||||||
|
level=PermissionLevel.WRITE,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert write_grant.allows(PermissionLevel.READ) is True
|
||||||
|
assert write_grant.allows(PermissionLevel.WRITE) is True
|
||||||
|
assert write_grant.allows(PermissionLevel.EXECUTE) is False
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# PermissionManager Tests
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
class TestPermissionManager:
|
||||||
|
"""Tests for the PermissionManager class."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_grant_creates_permission(
|
||||||
|
self,
|
||||||
|
permission_manager: PermissionManager,
|
||||||
|
) -> None:
|
||||||
|
"""Test granting a permission."""
|
||||||
|
grant = await permission_manager.grant(
|
||||||
|
agent_id="agent-1",
|
||||||
|
resource_pattern="/data/*",
|
||||||
|
resource_type=ResourceType.FILE,
|
||||||
|
level=PermissionLevel.READ,
|
||||||
|
granted_by="admin",
|
||||||
|
reason="Read access",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert grant.id is not None
|
||||||
|
assert grant.agent_id == "agent-1"
|
||||||
|
assert grant.resource_pattern == "/data/*"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_grant_with_duration(
|
||||||
|
self,
|
||||||
|
permission_manager: PermissionManager,
|
||||||
|
) -> None:
|
||||||
|
"""Test granting a temporary permission."""
|
||||||
|
grant = await permission_manager.grant(
|
||||||
|
agent_id="agent-1",
|
||||||
|
resource_pattern="*",
|
||||||
|
resource_type=ResourceType.API,
|
||||||
|
level=PermissionLevel.EXECUTE,
|
||||||
|
duration_seconds=3600, # 1 hour
|
||||||
|
)
|
||||||
|
|
||||||
|
assert grant.expires_at is not None
|
||||||
|
assert grant.is_expired() is False
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_revoke_by_id(
|
||||||
|
self,
|
||||||
|
permission_manager: PermissionManager,
|
||||||
|
) -> None:
|
||||||
|
"""Test revoking a grant by ID."""
|
||||||
|
grant = await permission_manager.grant(
|
||||||
|
agent_id="agent-1",
|
||||||
|
resource_pattern="*",
|
||||||
|
resource_type=ResourceType.FILE,
|
||||||
|
level=PermissionLevel.READ,
|
||||||
|
)
|
||||||
|
|
||||||
|
success = await permission_manager.revoke(grant.id)
|
||||||
|
assert success is True
|
||||||
|
|
||||||
|
# Verify grant is removed
|
||||||
|
grants = await permission_manager.list_grants(agent_id="agent-1")
|
||||||
|
assert len(grants) == 0
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_revoke_nonexistent(
|
||||||
|
self,
|
||||||
|
permission_manager: PermissionManager,
|
||||||
|
) -> None:
|
||||||
|
"""Test revoking a non-existent grant."""
|
||||||
|
success = await permission_manager.revoke("nonexistent-id")
|
||||||
|
assert success is False
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_revoke_all_for_agent(
|
||||||
|
self,
|
||||||
|
permission_manager: PermissionManager,
|
||||||
|
) -> None:
|
||||||
|
"""Test revoking all permissions for an agent."""
|
||||||
|
# Grant multiple permissions
|
||||||
|
await permission_manager.grant(
|
||||||
|
agent_id="agent-1",
|
||||||
|
resource_pattern="/data/*",
|
||||||
|
resource_type=ResourceType.FILE,
|
||||||
|
level=PermissionLevel.READ,
|
||||||
|
)
|
||||||
|
await permission_manager.grant(
|
||||||
|
agent_id="agent-1",
|
||||||
|
resource_pattern="/api/*",
|
||||||
|
resource_type=ResourceType.API,
|
||||||
|
level=PermissionLevel.EXECUTE,
|
||||||
|
)
|
||||||
|
await permission_manager.grant(
|
||||||
|
agent_id="agent-2",
|
||||||
|
resource_pattern="*",
|
||||||
|
resource_type=ResourceType.FILE,
|
||||||
|
level=PermissionLevel.READ,
|
||||||
|
)
|
||||||
|
|
||||||
|
revoked = await permission_manager.revoke_all("agent-1")
|
||||||
|
assert revoked == 2
|
||||||
|
|
||||||
|
# Verify agent-1 grants are gone
|
||||||
|
grants = await permission_manager.list_grants(agent_id="agent-1")
|
||||||
|
assert len(grants) == 0
|
||||||
|
|
||||||
|
# Verify agent-2 grant remains
|
||||||
|
grants = await permission_manager.list_grants(agent_id="agent-2")
|
||||||
|
assert len(grants) == 1
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_revoke_all_no_grants(
|
||||||
|
self,
|
||||||
|
permission_manager: PermissionManager,
|
||||||
|
) -> None:
|
||||||
|
"""Test revoking all when no grants exist."""
|
||||||
|
revoked = await permission_manager.revoke_all("nonexistent-agent")
|
||||||
|
assert revoked == 0
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_check_granted(
|
||||||
|
self,
|
||||||
|
permission_manager: PermissionManager,
|
||||||
|
) -> None:
|
||||||
|
"""Test checking a granted permission."""
|
||||||
|
await permission_manager.grant(
|
||||||
|
agent_id="agent-1",
|
||||||
|
resource_pattern="/data/*",
|
||||||
|
resource_type=ResourceType.FILE,
|
||||||
|
level=PermissionLevel.READ,
|
||||||
|
)
|
||||||
|
|
||||||
|
allowed = await permission_manager.check(
|
||||||
|
agent_id="agent-1",
|
||||||
|
resource="/data/file.txt",
|
||||||
|
resource_type=ResourceType.FILE,
|
||||||
|
required_level=PermissionLevel.READ,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert allowed is True
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_check_denied_default_deny(
|
||||||
|
self,
|
||||||
|
permission_manager: PermissionManager,
|
||||||
|
) -> None:
|
||||||
|
"""Test checking denied with default_deny=True."""
|
||||||
|
# No grants, should be denied
|
||||||
|
allowed = await permission_manager.check(
|
||||||
|
agent_id="agent-1",
|
||||||
|
resource="/data/file.txt",
|
||||||
|
resource_type=ResourceType.FILE,
|
||||||
|
required_level=PermissionLevel.READ,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert allowed is False
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_check_uses_default_permissions(
|
||||||
|
self,
|
||||||
|
permissive_manager: PermissionManager,
|
||||||
|
) -> None:
|
||||||
|
"""Test that default permissions apply when default_deny=False."""
|
||||||
|
# No explicit grants, but FILE default is READ
|
||||||
|
allowed = await permissive_manager.check(
|
||||||
|
agent_id="agent-1",
|
||||||
|
resource="/data/file.txt",
|
||||||
|
resource_type=ResourceType.FILE,
|
||||||
|
required_level=PermissionLevel.READ,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert allowed is True
|
||||||
|
|
||||||
|
# But WRITE should fail
|
||||||
|
allowed = await permissive_manager.check(
|
||||||
|
agent_id="agent-1",
|
||||||
|
resource="/data/file.txt",
|
||||||
|
resource_type=ResourceType.FILE,
|
||||||
|
required_level=PermissionLevel.WRITE,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert allowed is False
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_check_shell_denied_by_default(
|
||||||
|
self,
|
||||||
|
permissive_manager: PermissionManager,
|
||||||
|
) -> None:
|
||||||
|
"""Test SHELL is denied by default (NONE level)."""
|
||||||
|
allowed = await permissive_manager.check(
|
||||||
|
agent_id="agent-1",
|
||||||
|
resource="rm -rf /",
|
||||||
|
resource_type=ResourceType.SHELL,
|
||||||
|
required_level=PermissionLevel.EXECUTE,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert allowed is False
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_check_expired_grant_ignored(
|
||||||
|
self,
|
||||||
|
permission_manager: PermissionManager,
|
||||||
|
) -> None:
|
||||||
|
"""Test that expired grants are ignored in checks."""
|
||||||
|
# Create an already-expired grant
|
||||||
|
grant = await permission_manager.grant(
|
||||||
|
agent_id="agent-1",
|
||||||
|
resource_pattern="/data/*",
|
||||||
|
resource_type=ResourceType.FILE,
|
||||||
|
level=PermissionLevel.READ,
|
||||||
|
duration_seconds=1, # Very short
|
||||||
|
)
|
||||||
|
|
||||||
|
# Manually expire it
|
||||||
|
grant.expires_at = datetime.utcnow() - timedelta(seconds=10)
|
||||||
|
|
||||||
|
allowed = await permission_manager.check(
|
||||||
|
agent_id="agent-1",
|
||||||
|
resource="/data/file.txt",
|
||||||
|
resource_type=ResourceType.FILE,
|
||||||
|
required_level=PermissionLevel.READ,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert allowed is False
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_check_insufficient_level(
|
||||||
|
self,
|
||||||
|
permission_manager: PermissionManager,
|
||||||
|
) -> None:
|
||||||
|
"""Test check fails when grant level is insufficient."""
|
||||||
|
await permission_manager.grant(
|
||||||
|
agent_id="agent-1",
|
||||||
|
resource_pattern="/data/*",
|
||||||
|
resource_type=ResourceType.FILE,
|
||||||
|
level=PermissionLevel.READ,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Try to get WRITE access with only READ grant
|
||||||
|
allowed = await permission_manager.check(
|
||||||
|
agent_id="agent-1",
|
||||||
|
resource="/data/file.txt",
|
||||||
|
resource_type=ResourceType.FILE,
|
||||||
|
required_level=PermissionLevel.WRITE,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert allowed is False
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_check_action_file_read(
|
||||||
|
self,
|
||||||
|
permission_manager: PermissionManager,
|
||||||
|
action_metadata: ActionMetadata,
|
||||||
|
) -> None:
|
||||||
|
"""Test check_action for file read."""
|
||||||
|
await permission_manager.grant(
|
||||||
|
agent_id="test-agent",
|
||||||
|
resource_pattern="/data/*",
|
||||||
|
resource_type=ResourceType.FILE,
|
||||||
|
level=PermissionLevel.READ,
|
||||||
|
)
|
||||||
|
|
||||||
|
action = ActionRequest(
|
||||||
|
action_type=ActionType.FILE_READ,
|
||||||
|
resource="/data/file.txt",
|
||||||
|
metadata=action_metadata,
|
||||||
|
)
|
||||||
|
|
||||||
|
allowed = await permission_manager.check_action(action)
|
||||||
|
assert allowed is True
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_check_action_file_write(
|
||||||
|
self,
|
||||||
|
permission_manager: PermissionManager,
|
||||||
|
action_metadata: ActionMetadata,
|
||||||
|
) -> None:
|
||||||
|
"""Test check_action for file write."""
|
||||||
|
await permission_manager.grant(
|
||||||
|
agent_id="test-agent",
|
||||||
|
resource_pattern="/data/*",
|
||||||
|
resource_type=ResourceType.FILE,
|
||||||
|
level=PermissionLevel.WRITE,
|
||||||
|
)
|
||||||
|
|
||||||
|
action = ActionRequest(
|
||||||
|
action_type=ActionType.FILE_WRITE,
|
||||||
|
resource="/data/file.txt",
|
||||||
|
metadata=action_metadata,
|
||||||
|
)
|
||||||
|
|
||||||
|
allowed = await permission_manager.check_action(action)
|
||||||
|
assert allowed is True
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_check_action_uses_tool_name_as_resource(
|
||||||
|
self,
|
||||||
|
permission_manager: PermissionManager,
|
||||||
|
action_metadata: ActionMetadata,
|
||||||
|
) -> None:
|
||||||
|
"""Test check_action uses tool_name when resource is None."""
|
||||||
|
await permission_manager.grant(
|
||||||
|
agent_id="test-agent",
|
||||||
|
resource_pattern="search_*",
|
||||||
|
resource_type=ResourceType.CUSTOM,
|
||||||
|
level=PermissionLevel.EXECUTE,
|
||||||
|
)
|
||||||
|
|
||||||
|
action = ActionRequest(
|
||||||
|
action_type=ActionType.TOOL_CALL,
|
||||||
|
tool_name="search_documents",
|
||||||
|
resource=None,
|
||||||
|
metadata=action_metadata,
|
||||||
|
)
|
||||||
|
|
||||||
|
allowed = await permission_manager.check_action(action)
|
||||||
|
assert allowed is True
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_require_permission_granted(
|
||||||
|
self,
|
||||||
|
permission_manager: PermissionManager,
|
||||||
|
) -> None:
|
||||||
|
"""Test require_permission doesn't raise when granted."""
|
||||||
|
await permission_manager.grant(
|
||||||
|
agent_id="agent-1",
|
||||||
|
resource_pattern="/data/*",
|
||||||
|
resource_type=ResourceType.FILE,
|
||||||
|
level=PermissionLevel.READ,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Should not raise
|
||||||
|
await permission_manager.require_permission(
|
||||||
|
agent_id="agent-1",
|
||||||
|
resource="/data/file.txt",
|
||||||
|
resource_type=ResourceType.FILE,
|
||||||
|
required_level=PermissionLevel.READ,
|
||||||
|
)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_require_permission_denied(
|
||||||
|
self,
|
||||||
|
permission_manager: PermissionManager,
|
||||||
|
) -> None:
|
||||||
|
"""Test require_permission raises when denied."""
|
||||||
|
with pytest.raises(PermissionDeniedError) as exc_info:
|
||||||
|
await permission_manager.require_permission(
|
||||||
|
agent_id="agent-1",
|
||||||
|
resource="/secret/file.txt",
|
||||||
|
resource_type=ResourceType.FILE,
|
||||||
|
required_level=PermissionLevel.READ,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert "/secret/file.txt" in str(exc_info.value)
|
||||||
|
assert exc_info.value.agent_id == "agent-1"
|
||||||
|
assert exc_info.value.required_permission == "read"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_list_grants_all(
|
||||||
|
self,
|
||||||
|
permission_manager: PermissionManager,
|
||||||
|
) -> None:
|
||||||
|
"""Test listing all grants."""
|
||||||
|
await permission_manager.grant(
|
||||||
|
agent_id="agent-1",
|
||||||
|
resource_pattern="/data/*",
|
||||||
|
resource_type=ResourceType.FILE,
|
||||||
|
level=PermissionLevel.READ,
|
||||||
|
)
|
||||||
|
await permission_manager.grant(
|
||||||
|
agent_id="agent-2",
|
||||||
|
resource_pattern="/api/*",
|
||||||
|
resource_type=ResourceType.API,
|
||||||
|
level=PermissionLevel.EXECUTE,
|
||||||
|
)
|
||||||
|
|
||||||
|
grants = await permission_manager.list_grants()
|
||||||
|
assert len(grants) == 2
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_list_grants_by_agent(
|
||||||
|
self,
|
||||||
|
permission_manager: PermissionManager,
|
||||||
|
) -> None:
|
||||||
|
"""Test listing grants filtered by agent."""
|
||||||
|
await permission_manager.grant(
|
||||||
|
agent_id="agent-1",
|
||||||
|
resource_pattern="/data/*",
|
||||||
|
resource_type=ResourceType.FILE,
|
||||||
|
level=PermissionLevel.READ,
|
||||||
|
)
|
||||||
|
await permission_manager.grant(
|
||||||
|
agent_id="agent-2",
|
||||||
|
resource_pattern="/api/*",
|
||||||
|
resource_type=ResourceType.API,
|
||||||
|
level=PermissionLevel.EXECUTE,
|
||||||
|
)
|
||||||
|
|
||||||
|
grants = await permission_manager.list_grants(agent_id="agent-1")
|
||||||
|
assert len(grants) == 1
|
||||||
|
assert grants[0].agent_id == "agent-1"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_list_grants_by_resource_type(
|
||||||
|
self,
|
||||||
|
permission_manager: PermissionManager,
|
||||||
|
) -> None:
|
||||||
|
"""Test listing grants filtered by resource type."""
|
||||||
|
await permission_manager.grant(
|
||||||
|
agent_id="agent-1",
|
||||||
|
resource_pattern="/data/*",
|
||||||
|
resource_type=ResourceType.FILE,
|
||||||
|
level=PermissionLevel.READ,
|
||||||
|
)
|
||||||
|
await permission_manager.grant(
|
||||||
|
agent_id="agent-1",
|
||||||
|
resource_pattern="/api/*",
|
||||||
|
resource_type=ResourceType.API,
|
||||||
|
level=PermissionLevel.EXECUTE,
|
||||||
|
)
|
||||||
|
|
||||||
|
grants = await permission_manager.list_grants(resource_type=ResourceType.FILE)
|
||||||
|
assert len(grants) == 1
|
||||||
|
assert grants[0].resource_type == ResourceType.FILE
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_list_grants_excludes_expired(
|
||||||
|
self,
|
||||||
|
permission_manager: PermissionManager,
|
||||||
|
) -> None:
|
||||||
|
"""Test that list_grants excludes expired grants."""
|
||||||
|
# Create expired grant
|
||||||
|
grant = await permission_manager.grant(
|
||||||
|
agent_id="agent-1",
|
||||||
|
resource_pattern="/old/*",
|
||||||
|
resource_type=ResourceType.FILE,
|
||||||
|
level=PermissionLevel.READ,
|
||||||
|
duration_seconds=1,
|
||||||
|
)
|
||||||
|
grant.expires_at = datetime.utcnow() - timedelta(seconds=10)
|
||||||
|
|
||||||
|
# Create valid grant
|
||||||
|
await permission_manager.grant(
|
||||||
|
agent_id="agent-1",
|
||||||
|
resource_pattern="/new/*",
|
||||||
|
resource_type=ResourceType.FILE,
|
||||||
|
level=PermissionLevel.READ,
|
||||||
|
)
|
||||||
|
|
||||||
|
grants = await permission_manager.list_grants()
|
||||||
|
assert len(grants) == 1
|
||||||
|
assert grants[0].resource_pattern == "/new/*"
|
||||||
|
|
||||||
|
def test_set_default_permission(
|
||||||
|
self,
|
||||||
|
) -> None:
|
||||||
|
"""Test setting default permission level."""
|
||||||
|
manager = PermissionManager(default_deny=False)
|
||||||
|
|
||||||
|
# Default for SHELL is NONE
|
||||||
|
assert manager._default_permissions[ResourceType.SHELL] == PermissionLevel.NONE
|
||||||
|
|
||||||
|
# Change it
|
||||||
|
manager.set_default_permission(ResourceType.SHELL, PermissionLevel.EXECUTE)
|
||||||
|
assert (
|
||||||
|
manager._default_permissions[ResourceType.SHELL] == PermissionLevel.EXECUTE
|
||||||
|
)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_set_default_permission_affects_checks(
|
||||||
|
self,
|
||||||
|
permissive_manager: PermissionManager,
|
||||||
|
) -> None:
|
||||||
|
"""Test that changing default permissions affects checks."""
|
||||||
|
# Initially SHELL is NONE
|
||||||
|
allowed = await permissive_manager.check(
|
||||||
|
agent_id="agent-1",
|
||||||
|
resource="ls",
|
||||||
|
resource_type=ResourceType.SHELL,
|
||||||
|
required_level=PermissionLevel.EXECUTE,
|
||||||
|
)
|
||||||
|
assert allowed is False
|
||||||
|
|
||||||
|
# Change default
|
||||||
|
permissive_manager.set_default_permission(
|
||||||
|
ResourceType.SHELL, PermissionLevel.EXECUTE
|
||||||
|
)
|
||||||
|
|
||||||
|
# Now should be allowed
|
||||||
|
allowed = await permissive_manager.check(
|
||||||
|
agent_id="agent-1",
|
||||||
|
resource="ls",
|
||||||
|
resource_type=ResourceType.SHELL,
|
||||||
|
required_level=PermissionLevel.EXECUTE,
|
||||||
|
)
|
||||||
|
assert allowed is True
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# Edge Cases
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
class TestPermissionEdgeCases:
|
||||||
|
"""Edge cases that could reveal hidden bugs."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_multiple_matching_grants(
|
||||||
|
self,
|
||||||
|
permission_manager: PermissionManager,
|
||||||
|
) -> None:
|
||||||
|
"""Test when multiple grants match - first sufficient one wins."""
|
||||||
|
# Grant READ on all files
|
||||||
|
await permission_manager.grant(
|
||||||
|
agent_id="agent-1",
|
||||||
|
resource_pattern="/data/*",
|
||||||
|
resource_type=ResourceType.FILE,
|
||||||
|
level=PermissionLevel.READ,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Also grant WRITE on specific path
|
||||||
|
await permission_manager.grant(
|
||||||
|
agent_id="agent-1",
|
||||||
|
resource_pattern="/data/writable/*",
|
||||||
|
resource_type=ResourceType.FILE,
|
||||||
|
level=PermissionLevel.WRITE,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Write on writable path should work
|
||||||
|
allowed = await permission_manager.check(
|
||||||
|
agent_id="agent-1",
|
||||||
|
resource="/data/writable/file.txt",
|
||||||
|
resource_type=ResourceType.FILE,
|
||||||
|
required_level=PermissionLevel.WRITE,
|
||||||
|
)
|
||||||
|
assert allowed is True
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_wildcard_all_pattern(
|
||||||
|
self,
|
||||||
|
permission_manager: PermissionManager,
|
||||||
|
) -> None:
|
||||||
|
"""Test * pattern matches everything."""
|
||||||
|
await permission_manager.grant(
|
||||||
|
agent_id="agent-1",
|
||||||
|
resource_pattern="*",
|
||||||
|
resource_type=ResourceType.FILE,
|
||||||
|
level=PermissionLevel.ADMIN,
|
||||||
|
)
|
||||||
|
|
||||||
|
allowed = await permission_manager.check(
|
||||||
|
agent_id="agent-1",
|
||||||
|
resource="/any/path/anywhere/file.txt",
|
||||||
|
resource_type=ResourceType.FILE,
|
||||||
|
required_level=PermissionLevel.DELETE,
|
||||||
|
)
|
||||||
|
|
||||||
|
# fnmatch's * matches everything including /
|
||||||
|
assert allowed is True
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_question_mark_wildcard(
|
||||||
|
self,
|
||||||
|
permission_manager: PermissionManager,
|
||||||
|
) -> None:
|
||||||
|
"""Test ? wildcard matches single character."""
|
||||||
|
await permission_manager.grant(
|
||||||
|
agent_id="agent-1",
|
||||||
|
resource_pattern="file?.txt",
|
||||||
|
resource_type=ResourceType.FILE,
|
||||||
|
level=PermissionLevel.READ,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert (
|
||||||
|
await permission_manager.check(
|
||||||
|
agent_id="agent-1",
|
||||||
|
resource="file1.txt",
|
||||||
|
resource_type=ResourceType.FILE,
|
||||||
|
required_level=PermissionLevel.READ,
|
||||||
|
)
|
||||||
|
is True
|
||||||
|
)
|
||||||
|
|
||||||
|
assert (
|
||||||
|
await permission_manager.check(
|
||||||
|
agent_id="agent-1",
|
||||||
|
resource="file10.txt", # Two characters, won't match
|
||||||
|
resource_type=ResourceType.FILE,
|
||||||
|
required_level=PermissionLevel.READ,
|
||||||
|
)
|
||||||
|
is False
|
||||||
|
)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_concurrent_grant_revoke(
|
||||||
|
self,
|
||||||
|
permission_manager: PermissionManager,
|
||||||
|
) -> None:
|
||||||
|
"""Test concurrent grant and revoke operations."""
|
||||||
|
|
||||||
|
async def grant_many():
|
||||||
|
grants = []
|
||||||
|
for i in range(10):
|
||||||
|
g = await permission_manager.grant(
|
||||||
|
agent_id="agent-1",
|
||||||
|
resource_pattern=f"/path{i}/*",
|
||||||
|
resource_type=ResourceType.FILE,
|
||||||
|
level=PermissionLevel.READ,
|
||||||
|
)
|
||||||
|
grants.append(g)
|
||||||
|
return grants
|
||||||
|
|
||||||
|
async def revoke_many(grants):
|
||||||
|
for g in grants:
|
||||||
|
await permission_manager.revoke(g.id)
|
||||||
|
|
||||||
|
grants = await grant_many()
|
||||||
|
await revoke_many(grants)
|
||||||
|
|
||||||
|
# All should be revoked
|
||||||
|
remaining = await permission_manager.list_grants(agent_id="agent-1")
|
||||||
|
assert len(remaining) == 0
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_check_action_with_no_resource_or_tool(
|
||||||
|
self,
|
||||||
|
permission_manager: PermissionManager,
|
||||||
|
action_metadata: ActionMetadata,
|
||||||
|
) -> None:
|
||||||
|
"""Test check_action when both resource and tool_name are None."""
|
||||||
|
await permission_manager.grant(
|
||||||
|
agent_id="test-agent",
|
||||||
|
resource_pattern="*",
|
||||||
|
resource_type=ResourceType.LLM,
|
||||||
|
level=PermissionLevel.EXECUTE,
|
||||||
|
)
|
||||||
|
|
||||||
|
action = ActionRequest(
|
||||||
|
action_type=ActionType.LLM_CALL,
|
||||||
|
resource=None,
|
||||||
|
tool_name=None,
|
||||||
|
metadata=action_metadata,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Should use "*" as fallback
|
||||||
|
allowed = await permission_manager.check_action(action)
|
||||||
|
assert allowed is True
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_cleanup_expired_called_on_check(
|
||||||
|
self,
|
||||||
|
permission_manager: PermissionManager,
|
||||||
|
) -> None:
|
||||||
|
"""Test that expired grants are cleaned up during check."""
|
||||||
|
# Create expired grant
|
||||||
|
grant = await permission_manager.grant(
|
||||||
|
agent_id="agent-1",
|
||||||
|
resource_pattern="/old/*",
|
||||||
|
resource_type=ResourceType.FILE,
|
||||||
|
level=PermissionLevel.READ,
|
||||||
|
duration_seconds=1,
|
||||||
|
)
|
||||||
|
grant.expires_at = datetime.utcnow() - timedelta(seconds=10)
|
||||||
|
|
||||||
|
# Create valid grant
|
||||||
|
await permission_manager.grant(
|
||||||
|
agent_id="agent-1",
|
||||||
|
resource_pattern="/new/*",
|
||||||
|
resource_type=ResourceType.FILE,
|
||||||
|
level=PermissionLevel.READ,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Run a check - this should trigger cleanup
|
||||||
|
await permission_manager.check(
|
||||||
|
agent_id="agent-1",
|
||||||
|
resource="/new/file.txt",
|
||||||
|
resource_type=ResourceType.FILE,
|
||||||
|
required_level=PermissionLevel.READ,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Now verify expired grant was cleaned up
|
||||||
|
async with permission_manager._lock:
|
||||||
|
assert len(permission_manager._grants) == 1
|
||||||
|
assert permission_manager._grants[0].resource_pattern == "/new/*"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_check_wrong_agent_id(
|
||||||
|
self,
|
||||||
|
permission_manager: PermissionManager,
|
||||||
|
) -> None:
|
||||||
|
"""Test check fails for different agent."""
|
||||||
|
await permission_manager.grant(
|
||||||
|
agent_id="agent-1",
|
||||||
|
resource_pattern="/data/*",
|
||||||
|
resource_type=ResourceType.FILE,
|
||||||
|
level=PermissionLevel.READ,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Different agent should not have access
|
||||||
|
allowed = await permission_manager.check(
|
||||||
|
agent_id="agent-2",
|
||||||
|
resource="/data/file.txt",
|
||||||
|
resource_type=ResourceType.FILE,
|
||||||
|
required_level=PermissionLevel.READ,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert allowed is False
|
||||||
823
backend/tests/services/safety/test_rollback.py
Normal file
823
backend/tests/services/safety/test_rollback.py
Normal file
@@ -0,0 +1,823 @@
|
|||||||
|
"""Tests for Rollback Manager.
|
||||||
|
|
||||||
|
Tests cover:
|
||||||
|
- FileCheckpoint: state storage
|
||||||
|
- RollbackManager: checkpoint, rollback, cleanup
|
||||||
|
- TransactionContext: auto-rollback, commit, manual rollback
|
||||||
|
- Edge cases: non-existent files, partial failures, expiration
|
||||||
|
"""
|
||||||
|
|
||||||
|
import tempfile
|
||||||
|
from datetime import datetime, timedelta
|
||||||
|
from pathlib import Path
|
||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import pytest_asyncio
|
||||||
|
|
||||||
|
from app.services.safety.exceptions import RollbackError
|
||||||
|
from app.services.safety.models import (
|
||||||
|
ActionMetadata,
|
||||||
|
ActionRequest,
|
||||||
|
ActionType,
|
||||||
|
CheckpointType,
|
||||||
|
)
|
||||||
|
from app.services.safety.rollback.manager import (
|
||||||
|
FileCheckpoint,
|
||||||
|
RollbackManager,
|
||||||
|
TransactionContext,
|
||||||
|
)
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# Fixtures
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def action_metadata() -> ActionMetadata:
|
||||||
|
"""Create standard action metadata for tests."""
|
||||||
|
return ActionMetadata(
|
||||||
|
agent_id="test-agent",
|
||||||
|
project_id="test-project",
|
||||||
|
session_id="test-session",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def action_request(action_metadata: ActionMetadata) -> ActionRequest:
|
||||||
|
"""Create a standard action request for tests."""
|
||||||
|
return ActionRequest(
|
||||||
|
id="action-123",
|
||||||
|
action_type=ActionType.FILE_WRITE,
|
||||||
|
tool_name="file_write",
|
||||||
|
resource="/tmp/test_file.txt", # noqa: S108
|
||||||
|
metadata=action_metadata,
|
||||||
|
is_destructive=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest_asyncio.fixture
|
||||||
|
async def rollback_manager() -> RollbackManager:
|
||||||
|
"""Create a RollbackManager for testing."""
|
||||||
|
with tempfile.TemporaryDirectory() as tmpdir:
|
||||||
|
with patch("app.services.safety.rollback.manager.get_safety_config") as mock:
|
||||||
|
mock.return_value = MagicMock(
|
||||||
|
checkpoint_dir=tmpdir,
|
||||||
|
checkpoint_retention_hours=24,
|
||||||
|
)
|
||||||
|
manager = RollbackManager(checkpoint_dir=tmpdir, retention_hours=24)
|
||||||
|
yield manager
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def temp_dir() -> Path:
|
||||||
|
"""Create a temporary directory for file operations."""
|
||||||
|
with tempfile.TemporaryDirectory() as tmpdir:
|
||||||
|
yield Path(tmpdir)
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# FileCheckpoint Tests
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
class TestFileCheckpoint:
|
||||||
|
"""Tests for the FileCheckpoint class."""
|
||||||
|
|
||||||
|
def test_file_checkpoint_creation(self) -> None:
|
||||||
|
"""Test creating a file checkpoint."""
|
||||||
|
fc = FileCheckpoint(
|
||||||
|
checkpoint_id="cp-123",
|
||||||
|
file_path="/path/to/file.txt",
|
||||||
|
original_content=b"original content",
|
||||||
|
existed=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert fc.checkpoint_id == "cp-123"
|
||||||
|
assert fc.file_path == "/path/to/file.txt"
|
||||||
|
assert fc.original_content == b"original content"
|
||||||
|
assert fc.existed is True
|
||||||
|
assert fc.created_at is not None
|
||||||
|
|
||||||
|
def test_file_checkpoint_nonexistent_file(self) -> None:
|
||||||
|
"""Test checkpoint for non-existent file."""
|
||||||
|
fc = FileCheckpoint(
|
||||||
|
checkpoint_id="cp-123",
|
||||||
|
file_path="/path/to/new_file.txt",
|
||||||
|
original_content=None,
|
||||||
|
existed=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert fc.original_content is None
|
||||||
|
assert fc.existed is False
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# RollbackManager Tests
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
class TestRollbackManager:
|
||||||
|
"""Tests for the RollbackManager class."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_create_checkpoint(
|
||||||
|
self,
|
||||||
|
rollback_manager: RollbackManager,
|
||||||
|
action_request: ActionRequest,
|
||||||
|
) -> None:
|
||||||
|
"""Test creating a checkpoint."""
|
||||||
|
checkpoint = await rollback_manager.create_checkpoint(
|
||||||
|
action=action_request,
|
||||||
|
checkpoint_type=CheckpointType.FILE,
|
||||||
|
description="Test checkpoint",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert checkpoint.id is not None
|
||||||
|
assert checkpoint.action_id == action_request.id
|
||||||
|
assert checkpoint.checkpoint_type == CheckpointType.FILE
|
||||||
|
assert checkpoint.description == "Test checkpoint"
|
||||||
|
assert checkpoint.expires_at is not None
|
||||||
|
assert checkpoint.is_valid is True
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_create_checkpoint_default_description(
|
||||||
|
self,
|
||||||
|
rollback_manager: RollbackManager,
|
||||||
|
action_request: ActionRequest,
|
||||||
|
) -> None:
|
||||||
|
"""Test checkpoint with default description."""
|
||||||
|
checkpoint = await rollback_manager.create_checkpoint(action=action_request)
|
||||||
|
|
||||||
|
assert "file_write" in checkpoint.description
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_checkpoint_file_exists(
|
||||||
|
self,
|
||||||
|
rollback_manager: RollbackManager,
|
||||||
|
action_request: ActionRequest,
|
||||||
|
temp_dir: Path,
|
||||||
|
) -> None:
|
||||||
|
"""Test checkpointing an existing file."""
|
||||||
|
# Create a file
|
||||||
|
test_file = temp_dir / "test.txt"
|
||||||
|
test_file.write_text("original content")
|
||||||
|
|
||||||
|
checkpoint = await rollback_manager.create_checkpoint(action=action_request)
|
||||||
|
await rollback_manager.checkpoint_file(checkpoint.id, str(test_file))
|
||||||
|
|
||||||
|
# Verify checkpoint was stored
|
||||||
|
async with rollback_manager._lock:
|
||||||
|
file_checkpoints = rollback_manager._file_checkpoints.get(checkpoint.id, [])
|
||||||
|
assert len(file_checkpoints) == 1
|
||||||
|
assert file_checkpoints[0].existed is True
|
||||||
|
assert file_checkpoints[0].original_content == b"original content"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_checkpoint_file_not_exists(
|
||||||
|
self,
|
||||||
|
rollback_manager: RollbackManager,
|
||||||
|
action_request: ActionRequest,
|
||||||
|
temp_dir: Path,
|
||||||
|
) -> None:
|
||||||
|
"""Test checkpointing a non-existent file."""
|
||||||
|
test_file = temp_dir / "new_file.txt"
|
||||||
|
assert not test_file.exists()
|
||||||
|
|
||||||
|
checkpoint = await rollback_manager.create_checkpoint(action=action_request)
|
||||||
|
await rollback_manager.checkpoint_file(checkpoint.id, str(test_file))
|
||||||
|
|
||||||
|
# Verify checkpoint was stored
|
||||||
|
async with rollback_manager._lock:
|
||||||
|
file_checkpoints = rollback_manager._file_checkpoints.get(checkpoint.id, [])
|
||||||
|
assert len(file_checkpoints) == 1
|
||||||
|
assert file_checkpoints[0].existed is False
|
||||||
|
assert file_checkpoints[0].original_content is None
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_checkpoint_files_multiple(
|
||||||
|
self,
|
||||||
|
rollback_manager: RollbackManager,
|
||||||
|
action_request: ActionRequest,
|
||||||
|
temp_dir: Path,
|
||||||
|
) -> None:
|
||||||
|
"""Test checkpointing multiple files."""
|
||||||
|
# Create files
|
||||||
|
file1 = temp_dir / "file1.txt"
|
||||||
|
file2 = temp_dir / "file2.txt"
|
||||||
|
file1.write_text("content 1")
|
||||||
|
file2.write_text("content 2")
|
||||||
|
|
||||||
|
checkpoint = await rollback_manager.create_checkpoint(action=action_request)
|
||||||
|
await rollback_manager.checkpoint_files(
|
||||||
|
checkpoint.id,
|
||||||
|
[str(file1), str(file2)],
|
||||||
|
)
|
||||||
|
|
||||||
|
async with rollback_manager._lock:
|
||||||
|
file_checkpoints = rollback_manager._file_checkpoints.get(checkpoint.id, [])
|
||||||
|
assert len(file_checkpoints) == 2
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_rollback_restore_modified_file(
|
||||||
|
self,
|
||||||
|
rollback_manager: RollbackManager,
|
||||||
|
action_request: ActionRequest,
|
||||||
|
temp_dir: Path,
|
||||||
|
) -> None:
|
||||||
|
"""Test rollback restores modified file content."""
|
||||||
|
test_file = temp_dir / "test.txt"
|
||||||
|
test_file.write_text("original content")
|
||||||
|
|
||||||
|
# Create checkpoint
|
||||||
|
checkpoint = await rollback_manager.create_checkpoint(action=action_request)
|
||||||
|
await rollback_manager.checkpoint_file(checkpoint.id, str(test_file))
|
||||||
|
|
||||||
|
# Modify file
|
||||||
|
test_file.write_text("modified content")
|
||||||
|
assert test_file.read_text() == "modified content"
|
||||||
|
|
||||||
|
# Rollback
|
||||||
|
result = await rollback_manager.rollback(checkpoint.id)
|
||||||
|
|
||||||
|
assert result.success is True
|
||||||
|
assert len(result.actions_rolled_back) == 1
|
||||||
|
assert test_file.read_text() == "original content"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_rollback_delete_new_file(
|
||||||
|
self,
|
||||||
|
rollback_manager: RollbackManager,
|
||||||
|
action_request: ActionRequest,
|
||||||
|
temp_dir: Path,
|
||||||
|
) -> None:
|
||||||
|
"""Test rollback deletes file that didn't exist before."""
|
||||||
|
test_file = temp_dir / "new_file.txt"
|
||||||
|
assert not test_file.exists()
|
||||||
|
|
||||||
|
# Create checkpoint before file exists
|
||||||
|
checkpoint = await rollback_manager.create_checkpoint(action=action_request)
|
||||||
|
await rollback_manager.checkpoint_file(checkpoint.id, str(test_file))
|
||||||
|
|
||||||
|
# Create the file
|
||||||
|
test_file.write_text("new content")
|
||||||
|
assert test_file.exists()
|
||||||
|
|
||||||
|
# Rollback
|
||||||
|
result = await rollback_manager.rollback(checkpoint.id)
|
||||||
|
|
||||||
|
assert result.success is True
|
||||||
|
assert not test_file.exists()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_rollback_not_found(
|
||||||
|
self,
|
||||||
|
rollback_manager: RollbackManager,
|
||||||
|
) -> None:
|
||||||
|
"""Test rollback with non-existent checkpoint."""
|
||||||
|
with pytest.raises(RollbackError) as exc_info:
|
||||||
|
await rollback_manager.rollback("nonexistent-id")
|
||||||
|
|
||||||
|
assert "not found" in str(exc_info.value)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_rollback_invalid_checkpoint(
|
||||||
|
self,
|
||||||
|
rollback_manager: RollbackManager,
|
||||||
|
action_request: ActionRequest,
|
||||||
|
temp_dir: Path,
|
||||||
|
) -> None:
|
||||||
|
"""Test rollback with invalidated checkpoint."""
|
||||||
|
test_file = temp_dir / "test.txt"
|
||||||
|
test_file.write_text("original")
|
||||||
|
|
||||||
|
checkpoint = await rollback_manager.create_checkpoint(action=action_request)
|
||||||
|
await rollback_manager.checkpoint_file(checkpoint.id, str(test_file))
|
||||||
|
|
||||||
|
# Rollback once (invalidates checkpoint)
|
||||||
|
await rollback_manager.rollback(checkpoint.id)
|
||||||
|
|
||||||
|
# Try to rollback again
|
||||||
|
with pytest.raises(RollbackError) as exc_info:
|
||||||
|
await rollback_manager.rollback(checkpoint.id)
|
||||||
|
|
||||||
|
assert "no longer valid" in str(exc_info.value)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_discard_checkpoint(
|
||||||
|
self,
|
||||||
|
rollback_manager: RollbackManager,
|
||||||
|
action_request: ActionRequest,
|
||||||
|
) -> None:
|
||||||
|
"""Test discarding a checkpoint."""
|
||||||
|
checkpoint = await rollback_manager.create_checkpoint(action=action_request)
|
||||||
|
|
||||||
|
result = await rollback_manager.discard_checkpoint(checkpoint.id)
|
||||||
|
assert result is True
|
||||||
|
|
||||||
|
# Verify it's gone
|
||||||
|
cp = await rollback_manager.get_checkpoint(checkpoint.id)
|
||||||
|
assert cp is None
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_discard_checkpoint_nonexistent(
|
||||||
|
self,
|
||||||
|
rollback_manager: RollbackManager,
|
||||||
|
) -> None:
|
||||||
|
"""Test discarding a non-existent checkpoint."""
|
||||||
|
result = await rollback_manager.discard_checkpoint("nonexistent-id")
|
||||||
|
assert result is False
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_checkpoint(
|
||||||
|
self,
|
||||||
|
rollback_manager: RollbackManager,
|
||||||
|
action_request: ActionRequest,
|
||||||
|
) -> None:
|
||||||
|
"""Test getting a checkpoint by ID."""
|
||||||
|
checkpoint = await rollback_manager.create_checkpoint(action=action_request)
|
||||||
|
|
||||||
|
retrieved = await rollback_manager.get_checkpoint(checkpoint.id)
|
||||||
|
assert retrieved is not None
|
||||||
|
assert retrieved.id == checkpoint.id
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_checkpoint_nonexistent(
|
||||||
|
self,
|
||||||
|
rollback_manager: RollbackManager,
|
||||||
|
) -> None:
|
||||||
|
"""Test getting a non-existent checkpoint."""
|
||||||
|
retrieved = await rollback_manager.get_checkpoint("nonexistent-id")
|
||||||
|
assert retrieved is None
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_list_checkpoints(
|
||||||
|
self,
|
||||||
|
rollback_manager: RollbackManager,
|
||||||
|
action_request: ActionRequest,
|
||||||
|
) -> None:
|
||||||
|
"""Test listing checkpoints."""
|
||||||
|
await rollback_manager.create_checkpoint(action=action_request)
|
||||||
|
await rollback_manager.create_checkpoint(action=action_request)
|
||||||
|
|
||||||
|
checkpoints = await rollback_manager.list_checkpoints()
|
||||||
|
assert len(checkpoints) == 2
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_list_checkpoints_by_action(
|
||||||
|
self,
|
||||||
|
rollback_manager: RollbackManager,
|
||||||
|
action_metadata: ActionMetadata,
|
||||||
|
) -> None:
|
||||||
|
"""Test listing checkpoints filtered by action."""
|
||||||
|
action1 = ActionRequest(
|
||||||
|
id="action-1",
|
||||||
|
action_type=ActionType.FILE_WRITE,
|
||||||
|
metadata=action_metadata,
|
||||||
|
)
|
||||||
|
action2 = ActionRequest(
|
||||||
|
id="action-2",
|
||||||
|
action_type=ActionType.FILE_WRITE,
|
||||||
|
metadata=action_metadata,
|
||||||
|
)
|
||||||
|
|
||||||
|
await rollback_manager.create_checkpoint(action=action1)
|
||||||
|
await rollback_manager.create_checkpoint(action=action2)
|
||||||
|
|
||||||
|
checkpoints = await rollback_manager.list_checkpoints(action_id="action-1")
|
||||||
|
assert len(checkpoints) == 1
|
||||||
|
assert checkpoints[0].action_id == "action-1"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_list_checkpoints_excludes_expired(
|
||||||
|
self,
|
||||||
|
rollback_manager: RollbackManager,
|
||||||
|
action_request: ActionRequest,
|
||||||
|
) -> None:
|
||||||
|
"""Test list_checkpoints excludes expired by default."""
|
||||||
|
checkpoint = await rollback_manager.create_checkpoint(action=action_request)
|
||||||
|
|
||||||
|
# Manually expire it
|
||||||
|
async with rollback_manager._lock:
|
||||||
|
rollback_manager._checkpoints[checkpoint.id].expires_at = (
|
||||||
|
datetime.utcnow() - timedelta(hours=1)
|
||||||
|
)
|
||||||
|
|
||||||
|
checkpoints = await rollback_manager.list_checkpoints()
|
||||||
|
assert len(checkpoints) == 0
|
||||||
|
|
||||||
|
# With include_expired=True
|
||||||
|
checkpoints = await rollback_manager.list_checkpoints(include_expired=True)
|
||||||
|
assert len(checkpoints) == 1
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_cleanup_expired(
|
||||||
|
self,
|
||||||
|
rollback_manager: RollbackManager,
|
||||||
|
action_request: ActionRequest,
|
||||||
|
temp_dir: Path,
|
||||||
|
) -> None:
|
||||||
|
"""Test cleaning up expired checkpoints."""
|
||||||
|
# Create checkpoints
|
||||||
|
checkpoint = await rollback_manager.create_checkpoint(action=action_request)
|
||||||
|
test_file = temp_dir / "test.txt"
|
||||||
|
test_file.write_text("content")
|
||||||
|
await rollback_manager.checkpoint_file(checkpoint.id, str(test_file))
|
||||||
|
|
||||||
|
# Expire it
|
||||||
|
async with rollback_manager._lock:
|
||||||
|
rollback_manager._checkpoints[checkpoint.id].expires_at = (
|
||||||
|
datetime.utcnow() - timedelta(hours=1)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Cleanup
|
||||||
|
count = await rollback_manager.cleanup_expired()
|
||||||
|
assert count == 1
|
||||||
|
|
||||||
|
# Verify it's gone
|
||||||
|
async with rollback_manager._lock:
|
||||||
|
assert checkpoint.id not in rollback_manager._checkpoints
|
||||||
|
assert checkpoint.id not in rollback_manager._file_checkpoints
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# TransactionContext Tests
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
class TestTransactionContext:
|
||||||
|
"""Tests for the TransactionContext class."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_context_creates_checkpoint(
|
||||||
|
self,
|
||||||
|
rollback_manager: RollbackManager,
|
||||||
|
action_request: ActionRequest,
|
||||||
|
) -> None:
|
||||||
|
"""Test that entering context creates a checkpoint."""
|
||||||
|
async with TransactionContext(rollback_manager, action_request) as tx:
|
||||||
|
assert tx.checkpoint_id is not None
|
||||||
|
|
||||||
|
# Verify checkpoint exists
|
||||||
|
cp = await rollback_manager.get_checkpoint(tx.checkpoint_id)
|
||||||
|
assert cp is not None
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_context_checkpoint_file(
|
||||||
|
self,
|
||||||
|
rollback_manager: RollbackManager,
|
||||||
|
action_request: ActionRequest,
|
||||||
|
temp_dir: Path,
|
||||||
|
) -> None:
|
||||||
|
"""Test checkpointing files through context."""
|
||||||
|
test_file = temp_dir / "test.txt"
|
||||||
|
test_file.write_text("original")
|
||||||
|
|
||||||
|
async with TransactionContext(rollback_manager, action_request) as tx:
|
||||||
|
await tx.checkpoint_file(str(test_file))
|
||||||
|
|
||||||
|
# Modify file
|
||||||
|
test_file.write_text("modified")
|
||||||
|
|
||||||
|
# Manual rollback
|
||||||
|
result = await tx.rollback()
|
||||||
|
assert result is not None
|
||||||
|
assert result.success is True
|
||||||
|
|
||||||
|
assert test_file.read_text() == "original"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_context_checkpoint_files(
|
||||||
|
self,
|
||||||
|
rollback_manager: RollbackManager,
|
||||||
|
action_request: ActionRequest,
|
||||||
|
temp_dir: Path,
|
||||||
|
) -> None:
|
||||||
|
"""Test checkpointing multiple files through context."""
|
||||||
|
file1 = temp_dir / "file1.txt"
|
||||||
|
file2 = temp_dir / "file2.txt"
|
||||||
|
file1.write_text("content 1")
|
||||||
|
file2.write_text("content 2")
|
||||||
|
|
||||||
|
async with TransactionContext(rollback_manager, action_request) as tx:
|
||||||
|
await tx.checkpoint_files([str(file1), str(file2)])
|
||||||
|
|
||||||
|
cp_id = tx.checkpoint_id
|
||||||
|
async with rollback_manager._lock:
|
||||||
|
file_cps = rollback_manager._file_checkpoints.get(cp_id, [])
|
||||||
|
assert len(file_cps) == 2
|
||||||
|
|
||||||
|
tx.commit()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_context_auto_rollback_on_exception(
|
||||||
|
self,
|
||||||
|
rollback_manager: RollbackManager,
|
||||||
|
action_request: ActionRequest,
|
||||||
|
temp_dir: Path,
|
||||||
|
) -> None:
|
||||||
|
"""Test auto-rollback when exception occurs."""
|
||||||
|
test_file = temp_dir / "test.txt"
|
||||||
|
test_file.write_text("original")
|
||||||
|
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
async with TransactionContext(rollback_manager, action_request) as tx:
|
||||||
|
await tx.checkpoint_file(str(test_file))
|
||||||
|
test_file.write_text("modified")
|
||||||
|
raise ValueError("Simulated error")
|
||||||
|
|
||||||
|
# Should have been rolled back
|
||||||
|
assert test_file.read_text() == "original"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_context_commit_prevents_rollback(
|
||||||
|
self,
|
||||||
|
rollback_manager: RollbackManager,
|
||||||
|
action_request: ActionRequest,
|
||||||
|
temp_dir: Path,
|
||||||
|
) -> None:
|
||||||
|
"""Test that commit prevents auto-rollback."""
|
||||||
|
test_file = temp_dir / "test.txt"
|
||||||
|
test_file.write_text("original")
|
||||||
|
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
async with TransactionContext(rollback_manager, action_request) as tx:
|
||||||
|
await tx.checkpoint_file(str(test_file))
|
||||||
|
test_file.write_text("modified")
|
||||||
|
tx.commit()
|
||||||
|
raise ValueError("Simulated error after commit")
|
||||||
|
|
||||||
|
# Should NOT have been rolled back
|
||||||
|
assert test_file.read_text() == "modified"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_context_discards_checkpoint_on_commit(
|
||||||
|
self,
|
||||||
|
rollback_manager: RollbackManager,
|
||||||
|
action_request: ActionRequest,
|
||||||
|
) -> None:
|
||||||
|
"""Test that checkpoint is discarded after successful commit."""
|
||||||
|
checkpoint_id = None
|
||||||
|
|
||||||
|
async with TransactionContext(rollback_manager, action_request) as tx:
|
||||||
|
checkpoint_id = tx.checkpoint_id
|
||||||
|
tx.commit()
|
||||||
|
|
||||||
|
# Checkpoint should be discarded
|
||||||
|
cp = await rollback_manager.get_checkpoint(checkpoint_id)
|
||||||
|
assert cp is None
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_context_no_auto_rollback_when_disabled(
|
||||||
|
self,
|
||||||
|
rollback_manager: RollbackManager,
|
||||||
|
action_request: ActionRequest,
|
||||||
|
temp_dir: Path,
|
||||||
|
) -> None:
|
||||||
|
"""Test that auto_rollback=False disables auto-rollback."""
|
||||||
|
test_file = temp_dir / "test.txt"
|
||||||
|
test_file.write_text("original")
|
||||||
|
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
async with TransactionContext(
|
||||||
|
rollback_manager,
|
||||||
|
action_request,
|
||||||
|
auto_rollback=False,
|
||||||
|
) as tx:
|
||||||
|
await tx.checkpoint_file(str(test_file))
|
||||||
|
test_file.write_text("modified")
|
||||||
|
raise ValueError("Simulated error")
|
||||||
|
|
||||||
|
# Should NOT have been rolled back
|
||||||
|
assert test_file.read_text() == "modified"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_context_manual_rollback(
|
||||||
|
self,
|
||||||
|
rollback_manager: RollbackManager,
|
||||||
|
action_request: ActionRequest,
|
||||||
|
temp_dir: Path,
|
||||||
|
) -> None:
|
||||||
|
"""Test manual rollback within context."""
|
||||||
|
test_file = temp_dir / "test.txt"
|
||||||
|
test_file.write_text("original")
|
||||||
|
|
||||||
|
async with TransactionContext(rollback_manager, action_request) as tx:
|
||||||
|
await tx.checkpoint_file(str(test_file))
|
||||||
|
test_file.write_text("modified")
|
||||||
|
|
||||||
|
# Manual rollback
|
||||||
|
result = await tx.rollback()
|
||||||
|
assert result is not None
|
||||||
|
assert result.success is True
|
||||||
|
|
||||||
|
assert test_file.read_text() == "original"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_context_rollback_without_checkpoint(
|
||||||
|
self,
|
||||||
|
rollback_manager: RollbackManager,
|
||||||
|
action_request: ActionRequest,
|
||||||
|
) -> None:
|
||||||
|
"""Test rollback when checkpoint is None."""
|
||||||
|
tx = TransactionContext(rollback_manager, action_request)
|
||||||
|
# Don't enter context, so _checkpoint is None
|
||||||
|
result = await tx.rollback()
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_context_checkpoint_file_without_checkpoint(
|
||||||
|
self,
|
||||||
|
rollback_manager: RollbackManager,
|
||||||
|
action_request: ActionRequest,
|
||||||
|
temp_dir: Path,
|
||||||
|
) -> None:
|
||||||
|
"""Test checkpoint_file when checkpoint is None (no-op)."""
|
||||||
|
tx = TransactionContext(rollback_manager, action_request)
|
||||||
|
test_file = temp_dir / "test.txt"
|
||||||
|
test_file.write_text("content")
|
||||||
|
|
||||||
|
# Should not raise - just a no-op
|
||||||
|
await tx.checkpoint_file(str(test_file))
|
||||||
|
await tx.checkpoint_files([str(test_file)])
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# Edge Cases
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
class TestRollbackEdgeCases:
|
||||||
|
"""Edge cases that could reveal hidden bugs."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_checkpoint_file_for_unknown_checkpoint(
|
||||||
|
self,
|
||||||
|
rollback_manager: RollbackManager,
|
||||||
|
temp_dir: Path,
|
||||||
|
) -> None:
|
||||||
|
"""Test checkpointing file for non-existent checkpoint."""
|
||||||
|
test_file = temp_dir / "test.txt"
|
||||||
|
test_file.write_text("content")
|
||||||
|
|
||||||
|
# Should create the list if it doesn't exist
|
||||||
|
await rollback_manager.checkpoint_file("unknown-checkpoint", str(test_file))
|
||||||
|
|
||||||
|
async with rollback_manager._lock:
|
||||||
|
assert "unknown-checkpoint" in rollback_manager._file_checkpoints
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_rollback_with_partial_failure(
|
||||||
|
self,
|
||||||
|
rollback_manager: RollbackManager,
|
||||||
|
action_request: ActionRequest,
|
||||||
|
temp_dir: Path,
|
||||||
|
) -> None:
|
||||||
|
"""Test rollback when some files fail to restore."""
|
||||||
|
file1 = temp_dir / "file1.txt"
|
||||||
|
file1.write_text("original 1")
|
||||||
|
|
||||||
|
checkpoint = await rollback_manager.create_checkpoint(action=action_request)
|
||||||
|
await rollback_manager.checkpoint_file(checkpoint.id, str(file1))
|
||||||
|
|
||||||
|
# Add a file checkpoint with a path that will fail
|
||||||
|
async with rollback_manager._lock:
|
||||||
|
# Create a checkpoint for a file in a non-writable location
|
||||||
|
bad_fc = FileCheckpoint(
|
||||||
|
checkpoint_id=checkpoint.id,
|
||||||
|
file_path="/nonexistent/path/file.txt",
|
||||||
|
original_content=b"content",
|
||||||
|
existed=True,
|
||||||
|
)
|
||||||
|
rollback_manager._file_checkpoints[checkpoint.id].append(bad_fc)
|
||||||
|
|
||||||
|
# Rollback - partial failure expected
|
||||||
|
result = await rollback_manager.rollback(checkpoint.id)
|
||||||
|
|
||||||
|
assert result.success is False
|
||||||
|
assert len(result.actions_rolled_back) == 1
|
||||||
|
assert len(result.failed_actions) == 1
|
||||||
|
assert "Failed to rollback" in result.error
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_rollback_file_creates_parent_dirs(
|
||||||
|
self,
|
||||||
|
rollback_manager: RollbackManager,
|
||||||
|
action_request: ActionRequest,
|
||||||
|
temp_dir: Path,
|
||||||
|
) -> None:
|
||||||
|
"""Test that rollback creates parent directories if needed."""
|
||||||
|
nested_file = temp_dir / "subdir" / "nested" / "file.txt"
|
||||||
|
nested_file.parent.mkdir(parents=True)
|
||||||
|
nested_file.write_text("original")
|
||||||
|
|
||||||
|
checkpoint = await rollback_manager.create_checkpoint(action=action_request)
|
||||||
|
await rollback_manager.checkpoint_file(checkpoint.id, str(nested_file))
|
||||||
|
|
||||||
|
# Delete the entire directory structure
|
||||||
|
nested_file.unlink()
|
||||||
|
(temp_dir / "subdir" / "nested").rmdir()
|
||||||
|
(temp_dir / "subdir").rmdir()
|
||||||
|
|
||||||
|
# Rollback should recreate
|
||||||
|
result = await rollback_manager.rollback(checkpoint.id)
|
||||||
|
|
||||||
|
assert result.success is True
|
||||||
|
assert nested_file.exists()
|
||||||
|
assert nested_file.read_text() == "original"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_rollback_file_already_correct(
|
||||||
|
self,
|
||||||
|
rollback_manager: RollbackManager,
|
||||||
|
action_request: ActionRequest,
|
||||||
|
temp_dir: Path,
|
||||||
|
) -> None:
|
||||||
|
"""Test rollback when file already has correct content."""
|
||||||
|
test_file = temp_dir / "test.txt"
|
||||||
|
test_file.write_text("original")
|
||||||
|
|
||||||
|
checkpoint = await rollback_manager.create_checkpoint(action=action_request)
|
||||||
|
await rollback_manager.checkpoint_file(checkpoint.id, str(test_file))
|
||||||
|
|
||||||
|
# Don't modify file - rollback should still succeed
|
||||||
|
result = await rollback_manager.rollback(checkpoint.id)
|
||||||
|
|
||||||
|
assert result.success is True
|
||||||
|
assert test_file.read_text() == "original"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_checkpoint_with_none_expires_at(
|
||||||
|
self,
|
||||||
|
rollback_manager: RollbackManager,
|
||||||
|
action_request: ActionRequest,
|
||||||
|
) -> None:
|
||||||
|
"""Test list_checkpoints handles None expires_at."""
|
||||||
|
checkpoint = await rollback_manager.create_checkpoint(action=action_request)
|
||||||
|
|
||||||
|
# Set expires_at to None
|
||||||
|
async with rollback_manager._lock:
|
||||||
|
rollback_manager._checkpoints[checkpoint.id].expires_at = None
|
||||||
|
|
||||||
|
# Should still be listed
|
||||||
|
checkpoints = await rollback_manager.list_checkpoints()
|
||||||
|
assert len(checkpoints) == 1
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_auto_rollback_failure_logged(
|
||||||
|
self,
|
||||||
|
rollback_manager: RollbackManager,
|
||||||
|
action_request: ActionRequest,
|
||||||
|
temp_dir: Path,
|
||||||
|
) -> None:
|
||||||
|
"""Test that auto-rollback failure is logged, not raised."""
|
||||||
|
test_file = temp_dir / "test.txt"
|
||||||
|
test_file.write_text("original")
|
||||||
|
|
||||||
|
with patch.object(
|
||||||
|
rollback_manager, "rollback", side_effect=Exception("Rollback failed!")
|
||||||
|
):
|
||||||
|
with patch("app.services.safety.rollback.manager.logger") as mock_logger:
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
async with TransactionContext(
|
||||||
|
rollback_manager, action_request
|
||||||
|
) as tx:
|
||||||
|
await tx.checkpoint_file(str(test_file))
|
||||||
|
test_file.write_text("modified")
|
||||||
|
raise ValueError("Original error")
|
||||||
|
|
||||||
|
# Rollback error should be logged
|
||||||
|
mock_logger.error.assert_called()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_multiple_checkpoints_same_action(
|
||||||
|
self,
|
||||||
|
rollback_manager: RollbackManager,
|
||||||
|
action_request: ActionRequest,
|
||||||
|
) -> None:
|
||||||
|
"""Test creating multiple checkpoints for the same action."""
|
||||||
|
cp1 = await rollback_manager.create_checkpoint(action=action_request)
|
||||||
|
cp2 = await rollback_manager.create_checkpoint(action=action_request)
|
||||||
|
|
||||||
|
assert cp1.id != cp2.id
|
||||||
|
|
||||||
|
checkpoints = await rollback_manager.list_checkpoints(
|
||||||
|
action_id=action_request.id
|
||||||
|
)
|
||||||
|
assert len(checkpoints) == 2
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_cleanup_expired_with_no_expired(
|
||||||
|
self,
|
||||||
|
rollback_manager: RollbackManager,
|
||||||
|
action_request: ActionRequest,
|
||||||
|
) -> None:
|
||||||
|
"""Test cleanup when no checkpoints are expired."""
|
||||||
|
await rollback_manager.create_checkpoint(action=action_request)
|
||||||
|
|
||||||
|
count = await rollback_manager.cleanup_expired()
|
||||||
|
assert count == 0
|
||||||
|
|
||||||
|
# Checkpoint should still exist
|
||||||
|
checkpoints = await rollback_manager.list_checkpoints()
|
||||||
|
assert len(checkpoints) == 1
|
||||||
@@ -363,6 +363,365 @@ class TestValidationBatch:
|
|||||||
assert results[1].decision == SafetyDecision.DENY
|
assert results[1].decision == SafetyDecision.DENY
|
||||||
|
|
||||||
|
|
||||||
|
class TestValidationCache:
|
||||||
|
"""Tests for ValidationCache class."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_cache_get_miss(self) -> None:
|
||||||
|
"""Test cache miss."""
|
||||||
|
from app.services.safety.validation.validator import ValidationCache
|
||||||
|
|
||||||
|
cache = ValidationCache(max_size=10, ttl_seconds=60)
|
||||||
|
result = await cache.get("nonexistent")
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_cache_get_hit(self) -> None:
|
||||||
|
"""Test cache hit."""
|
||||||
|
from app.services.safety.models import ValidationResult
|
||||||
|
from app.services.safety.validation.validator import ValidationCache
|
||||||
|
|
||||||
|
cache = ValidationCache(max_size=10, ttl_seconds=60)
|
||||||
|
vr = ValidationResult(
|
||||||
|
action_id="action-1",
|
||||||
|
decision=SafetyDecision.ALLOW,
|
||||||
|
applied_rules=[],
|
||||||
|
reasons=["test"],
|
||||||
|
)
|
||||||
|
await cache.set("key1", vr)
|
||||||
|
|
||||||
|
result = await cache.get("key1")
|
||||||
|
assert result is not None
|
||||||
|
assert result.action_id == "action-1"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_cache_ttl_expiry(self) -> None:
|
||||||
|
"""Test cache TTL expiry."""
|
||||||
|
import time
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
from app.services.safety.models import ValidationResult
|
||||||
|
from app.services.safety.validation.validator import ValidationCache
|
||||||
|
|
||||||
|
cache = ValidationCache(max_size=10, ttl_seconds=1)
|
||||||
|
vr = ValidationResult(
|
||||||
|
action_id="action-1",
|
||||||
|
decision=SafetyDecision.ALLOW,
|
||||||
|
applied_rules=[],
|
||||||
|
reasons=["test"],
|
||||||
|
)
|
||||||
|
await cache.set("key1", vr)
|
||||||
|
|
||||||
|
# Advance time past TTL
|
||||||
|
with patch("time.time", return_value=time.time() + 2):
|
||||||
|
result = await cache.get("key1")
|
||||||
|
assert result is None # Should be expired
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_cache_eviction_on_full(self) -> None:
|
||||||
|
"""Test cache eviction when full."""
|
||||||
|
from app.services.safety.models import ValidationResult
|
||||||
|
from app.services.safety.validation.validator import ValidationCache
|
||||||
|
|
||||||
|
cache = ValidationCache(max_size=2, ttl_seconds=60)
|
||||||
|
|
||||||
|
vr1 = ValidationResult(action_id="a1", decision=SafetyDecision.ALLOW)
|
||||||
|
vr2 = ValidationResult(action_id="a2", decision=SafetyDecision.ALLOW)
|
||||||
|
vr3 = ValidationResult(action_id="a3", decision=SafetyDecision.ALLOW)
|
||||||
|
|
||||||
|
await cache.set("key1", vr1)
|
||||||
|
await cache.set("key2", vr2)
|
||||||
|
await cache.set("key3", vr3) # Should evict key1
|
||||||
|
|
||||||
|
# key1 should be evicted
|
||||||
|
assert await cache.get("key1") is None
|
||||||
|
assert await cache.get("key2") is not None
|
||||||
|
assert await cache.get("key3") is not None
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_cache_update_existing_key(self) -> None:
|
||||||
|
"""Test updating existing key in cache."""
|
||||||
|
from app.services.safety.models import ValidationResult
|
||||||
|
from app.services.safety.validation.validator import ValidationCache
|
||||||
|
|
||||||
|
cache = ValidationCache(max_size=10, ttl_seconds=60)
|
||||||
|
|
||||||
|
vr1 = ValidationResult(action_id="a1", decision=SafetyDecision.ALLOW)
|
||||||
|
vr2 = ValidationResult(action_id="a1-updated", decision=SafetyDecision.DENY)
|
||||||
|
|
||||||
|
await cache.set("key1", vr1)
|
||||||
|
await cache.set("key1", vr2) # Should update, not add
|
||||||
|
|
||||||
|
result = await cache.get("key1")
|
||||||
|
assert result is not None
|
||||||
|
assert result.action_id == "a1" # Still old value since we move_to_end
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_cache_clear(self) -> None:
|
||||||
|
"""Test clearing cache."""
|
||||||
|
from app.services.safety.models import ValidationResult
|
||||||
|
from app.services.safety.validation.validator import ValidationCache
|
||||||
|
|
||||||
|
cache = ValidationCache(max_size=10, ttl_seconds=60)
|
||||||
|
|
||||||
|
vr = ValidationResult(action_id="a1", decision=SafetyDecision.ALLOW)
|
||||||
|
await cache.set("key1", vr)
|
||||||
|
await cache.set("key2", vr)
|
||||||
|
|
||||||
|
await cache.clear()
|
||||||
|
|
||||||
|
assert await cache.get("key1") is None
|
||||||
|
assert await cache.get("key2") is None
|
||||||
|
|
||||||
|
|
||||||
|
class TestValidatorCaching:
|
||||||
|
"""Tests for validator caching functionality."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_cache_hit(self) -> None:
|
||||||
|
"""Test that cache is used for repeated validations."""
|
||||||
|
validator = ActionValidator(cache_enabled=True, cache_ttl=60)
|
||||||
|
|
||||||
|
metadata = ActionMetadata(agent_id="test-agent", session_id="session-1")
|
||||||
|
action = ActionRequest(
|
||||||
|
action_type=ActionType.FILE_READ,
|
||||||
|
tool_name="file_read",
|
||||||
|
resource="/tmp/test.txt", # noqa: S108
|
||||||
|
metadata=metadata,
|
||||||
|
)
|
||||||
|
|
||||||
|
# First call populates cache
|
||||||
|
result1 = await validator.validate(action)
|
||||||
|
# Second call should use cache
|
||||||
|
result2 = await validator.validate(action)
|
||||||
|
|
||||||
|
assert result1.decision == result2.decision
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_clear_cache(self) -> None:
|
||||||
|
"""Test clearing the validation cache."""
|
||||||
|
validator = ActionValidator(cache_enabled=True)
|
||||||
|
|
||||||
|
metadata = ActionMetadata(agent_id="test-agent", session_id="session-1")
|
||||||
|
action = ActionRequest(
|
||||||
|
action_type=ActionType.FILE_READ,
|
||||||
|
tool_name="file_read",
|
||||||
|
metadata=metadata,
|
||||||
|
)
|
||||||
|
|
||||||
|
await validator.validate(action)
|
||||||
|
await validator.clear_cache()
|
||||||
|
|
||||||
|
# Cache should be empty now (no error)
|
||||||
|
result = await validator.validate(action)
|
||||||
|
assert result.decision == SafetyDecision.ALLOW
|
||||||
|
|
||||||
|
|
||||||
|
class TestRuleMatching:
|
||||||
|
"""Tests for rule matching edge cases."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_action_type_mismatch(self) -> None:
|
||||||
|
"""Test that rule doesn't match when action type doesn't match."""
|
||||||
|
validator = ActionValidator(cache_enabled=False)
|
||||||
|
validator.add_rule(
|
||||||
|
ValidationRule(
|
||||||
|
name="file_only",
|
||||||
|
action_types=[ActionType.FILE_READ],
|
||||||
|
decision=SafetyDecision.DENY,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
metadata = ActionMetadata(agent_id="test-agent")
|
||||||
|
action = ActionRequest(
|
||||||
|
action_type=ActionType.SHELL_COMMAND, # Different type
|
||||||
|
tool_name="shell_exec",
|
||||||
|
metadata=metadata,
|
||||||
|
)
|
||||||
|
|
||||||
|
result = await validator.validate(action)
|
||||||
|
assert result.decision == SafetyDecision.ALLOW # Rule didn't match
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_tool_pattern_no_tool_name(self) -> None:
|
||||||
|
"""Test rule with tool pattern when action has no tool_name."""
|
||||||
|
validator = ActionValidator(cache_enabled=False)
|
||||||
|
validator.add_rule(
|
||||||
|
create_deny_rule(
|
||||||
|
name="deny_files",
|
||||||
|
tool_patterns=["file_*"],
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
metadata = ActionMetadata(agent_id="test-agent")
|
||||||
|
action = ActionRequest(
|
||||||
|
action_type=ActionType.FILE_READ,
|
||||||
|
tool_name=None, # No tool name
|
||||||
|
metadata=metadata,
|
||||||
|
)
|
||||||
|
|
||||||
|
result = await validator.validate(action)
|
||||||
|
assert result.decision == SafetyDecision.ALLOW # Rule didn't match
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_resource_pattern_no_resource(self) -> None:
|
||||||
|
"""Test rule with resource pattern when action has no resource."""
|
||||||
|
validator = ActionValidator(cache_enabled=False)
|
||||||
|
validator.add_rule(
|
||||||
|
create_deny_rule(
|
||||||
|
name="deny_secrets",
|
||||||
|
resource_patterns=["/secret/*"],
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
metadata = ActionMetadata(agent_id="test-agent")
|
||||||
|
action = ActionRequest(
|
||||||
|
action_type=ActionType.FILE_READ,
|
||||||
|
tool_name="file_read",
|
||||||
|
resource=None, # No resource
|
||||||
|
metadata=metadata,
|
||||||
|
)
|
||||||
|
|
||||||
|
result = await validator.validate(action)
|
||||||
|
assert result.decision == SafetyDecision.ALLOW # Rule didn't match
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_resource_pattern_no_match(self) -> None:
|
||||||
|
"""Test rule with resource pattern that doesn't match."""
|
||||||
|
validator = ActionValidator(cache_enabled=False)
|
||||||
|
validator.add_rule(
|
||||||
|
create_deny_rule(
|
||||||
|
name="deny_secrets",
|
||||||
|
resource_patterns=["/secret/*"],
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
metadata = ActionMetadata(agent_id="test-agent")
|
||||||
|
action = ActionRequest(
|
||||||
|
action_type=ActionType.FILE_READ,
|
||||||
|
tool_name="file_read",
|
||||||
|
resource="/public/file.txt", # Doesn't match
|
||||||
|
metadata=metadata,
|
||||||
|
)
|
||||||
|
|
||||||
|
result = await validator.validate(action)
|
||||||
|
assert result.decision == SafetyDecision.ALLOW # Pattern didn't match
|
||||||
|
|
||||||
|
|
||||||
|
class TestPolicyLoading:
|
||||||
|
"""Tests for policy loading edge cases."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_load_rules_from_policy_with_validation_rules(self) -> None:
|
||||||
|
"""Test loading policy with explicit validation rules."""
|
||||||
|
validator = ActionValidator(cache_enabled=False)
|
||||||
|
|
||||||
|
rule = ValidationRule(
|
||||||
|
name="policy_rule",
|
||||||
|
tool_patterns=["test_*"],
|
||||||
|
decision=SafetyDecision.DENY,
|
||||||
|
reason="From policy",
|
||||||
|
)
|
||||||
|
policy = SafetyPolicy(
|
||||||
|
name="test",
|
||||||
|
validation_rules=[rule],
|
||||||
|
require_approval_for=[], # Clear defaults
|
||||||
|
denied_tools=[], # Clear defaults
|
||||||
|
)
|
||||||
|
|
||||||
|
validator.load_rules_from_policy(policy)
|
||||||
|
|
||||||
|
assert len(validator._rules) == 1
|
||||||
|
assert validator._rules[0].name == "policy_rule"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_load_approval_all_pattern(self) -> None:
|
||||||
|
"""Test loading policy with * approval pattern (all actions)."""
|
||||||
|
validator = ActionValidator(cache_enabled=False)
|
||||||
|
|
||||||
|
policy = SafetyPolicy(
|
||||||
|
name="test",
|
||||||
|
require_approval_for=["*"], # All actions require approval
|
||||||
|
denied_tools=[], # Clear defaults
|
||||||
|
)
|
||||||
|
|
||||||
|
validator.load_rules_from_policy(policy)
|
||||||
|
|
||||||
|
approval_rules = [
|
||||||
|
r for r in validator._rules if r.decision == SafetyDecision.REQUIRE_APPROVAL
|
||||||
|
]
|
||||||
|
assert len(approval_rules) == 1
|
||||||
|
assert approval_rules[0].name == "require_approval_all"
|
||||||
|
assert approval_rules[0].action_types == list(ActionType)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_validate_with_policy_loads_rules(self) -> None:
|
||||||
|
"""Test that validate() loads rules from policy if none exist."""
|
||||||
|
validator = ActionValidator(cache_enabled=False)
|
||||||
|
|
||||||
|
policy = SafetyPolicy(
|
||||||
|
name="test",
|
||||||
|
denied_tools=["dangerous_*"],
|
||||||
|
)
|
||||||
|
|
||||||
|
metadata = ActionMetadata(agent_id="test-agent")
|
||||||
|
action = ActionRequest(
|
||||||
|
action_type=ActionType.SHELL_COMMAND,
|
||||||
|
tool_name="dangerous_exec",
|
||||||
|
metadata=metadata,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Validate with policy - should load rules
|
||||||
|
result = await validator.validate(action, policy=policy)
|
||||||
|
|
||||||
|
assert result.decision == SafetyDecision.DENY
|
||||||
|
|
||||||
|
|
||||||
|
class TestCacheKeyGeneration:
|
||||||
|
"""Tests for cache key generation."""
|
||||||
|
|
||||||
|
def test_get_cache_key(self) -> None:
|
||||||
|
"""Test cache key generation."""
|
||||||
|
validator = ActionValidator(cache_enabled=True)
|
||||||
|
|
||||||
|
metadata = ActionMetadata(
|
||||||
|
agent_id="test-agent",
|
||||||
|
autonomy_level=AutonomyLevel.MILESTONE,
|
||||||
|
)
|
||||||
|
action = ActionRequest(
|
||||||
|
action_type=ActionType.FILE_READ,
|
||||||
|
tool_name="file_read",
|
||||||
|
resource="/tmp/test.txt", # noqa: S108
|
||||||
|
metadata=metadata,
|
||||||
|
)
|
||||||
|
|
||||||
|
key = validator._get_cache_key(action)
|
||||||
|
|
||||||
|
assert "file_read" in key
|
||||||
|
assert "file_read" in key
|
||||||
|
assert "/tmp/test.txt" in key # noqa: S108
|
||||||
|
assert "test-agent" in key
|
||||||
|
assert "milestone" in key
|
||||||
|
|
||||||
|
def test_get_cache_key_no_resource(self) -> None:
|
||||||
|
"""Test cache key generation without resource."""
|
||||||
|
validator = ActionValidator(cache_enabled=True)
|
||||||
|
|
||||||
|
metadata = ActionMetadata(agent_id="agent-1")
|
||||||
|
action = ActionRequest(
|
||||||
|
action_type=ActionType.SHELL_COMMAND,
|
||||||
|
tool_name="shell_exec",
|
||||||
|
resource=None,
|
||||||
|
metadata=metadata,
|
||||||
|
)
|
||||||
|
|
||||||
|
key = validator._get_cache_key(action)
|
||||||
|
|
||||||
|
# Should not error with None resource
|
||||||
|
assert "shell" in key
|
||||||
|
assert "agent-1" in key
|
||||||
|
|
||||||
|
|
||||||
class TestHelperFunctions:
|
class TestHelperFunctions:
|
||||||
"""Tests for rule creation helper functions."""
|
"""Tests for rule creation helper functions."""
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user