forked from cardosofelipe/fast-next-template
refactor(safety): apply consistent formatting across services and tests
Improved code readability and uniformity by standardizing line breaks, indentation, and inline conditions across safety-related services, models, and tests, including content filters, validation rules, and emergency controls.
This commit is contained in:
@@ -74,7 +74,9 @@ class ToolInfoResponse(BaseModel):
|
|||||||
name: str = Field(..., description="Tool name")
|
name: str = Field(..., description="Tool name")
|
||||||
description: str | None = Field(None, description="Tool description")
|
description: str | None = Field(None, description="Tool description")
|
||||||
server_name: str | None = Field(None, description="Server providing the tool")
|
server_name: str | None = Field(None, description="Server providing the tool")
|
||||||
input_schema: dict[str, Any] | None = Field(None, description="JSON schema for input")
|
input_schema: dict[str, Any] | None = Field(
|
||||||
|
None, description="JSON schema for input"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class ToolListResponse(BaseModel):
|
class ToolListResponse(BaseModel):
|
||||||
|
|||||||
@@ -158,9 +158,7 @@ class MCPConfig(BaseModel):
|
|||||||
def get_enabled_servers(self) -> dict[str, MCPServerConfig]:
|
def get_enabled_servers(self) -> dict[str, MCPServerConfig]:
|
||||||
"""Get all enabled server configurations."""
|
"""Get all enabled server configurations."""
|
||||||
return {
|
return {
|
||||||
name: config
|
name: config for name, config in self.mcp_servers.items() if config.enabled
|
||||||
for name, config in self.mcp_servers.items()
|
|
||||||
if config.enabled
|
|
||||||
}
|
}
|
||||||
|
|
||||||
def list_server_names(self) -> list[str]:
|
def list_server_names(self) -> list[str]:
|
||||||
|
|||||||
@@ -196,9 +196,7 @@ class AuditLogger:
|
|||||||
) -> AuditEvent:
|
) -> AuditEvent:
|
||||||
"""Log an action execution result."""
|
"""Log an action execution result."""
|
||||||
event_type = (
|
event_type = (
|
||||||
AuditEventType.ACTION_EXECUTED
|
AuditEventType.ACTION_EXECUTED if success else AuditEventType.ACTION_FAILED
|
||||||
if success
|
|
||||||
else AuditEventType.ACTION_FAILED
|
|
||||||
)
|
)
|
||||||
|
|
||||||
return await self.log(
|
return await self.log(
|
||||||
@@ -477,9 +475,7 @@ class AuditLogger:
|
|||||||
"user_id": event.user_id,
|
"user_id": event.user_id,
|
||||||
"decision": event.decision.value if event.decision else None,
|
"decision": event.decision.value if event.decision else None,
|
||||||
"details": {
|
"details": {
|
||||||
k: v
|
k: v for k, v in event.details.items() if not k.startswith("_")
|
||||||
for k, v in event.details.items()
|
|
||||||
if not k.startswith("_")
|
|
||||||
},
|
},
|
||||||
"correlation_id": event.correlation_id,
|
"correlation_id": event.correlation_id,
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -31,9 +31,7 @@ class SafetyConfig(BaseSettings):
|
|||||||
|
|
||||||
# General settings
|
# General settings
|
||||||
enabled: bool = Field(True, description="Enable safety framework")
|
enabled: bool = Field(True, description="Enable safety framework")
|
||||||
strict_mode: bool = Field(
|
strict_mode: bool = Field(True, description="Strict mode (fail closed on errors)")
|
||||||
True, description="Strict mode (fail closed on errors)"
|
|
||||||
)
|
|
||||||
log_level: str = Field("INFO", description="Logging level")
|
log_level: str = Field("INFO", description="Logging level")
|
||||||
|
|
||||||
# Default autonomy level
|
# Default autonomy level
|
||||||
@@ -255,7 +253,8 @@ def get_policy_for_autonomy_level(level: AutonomyLevel) -> SafetyPolicy:
|
|||||||
max_tokens_per_day=base_policy.max_tokens_per_day // 10,
|
max_tokens_per_day=base_policy.max_tokens_per_day // 10,
|
||||||
max_actions_per_minute=base_policy.max_actions_per_minute // 2,
|
max_actions_per_minute=base_policy.max_actions_per_minute // 2,
|
||||||
max_llm_calls_per_minute=base_policy.max_llm_calls_per_minute // 2,
|
max_llm_calls_per_minute=base_policy.max_llm_calls_per_minute // 2,
|
||||||
max_file_operations_per_minute=base_policy.max_file_operations_per_minute // 2,
|
max_file_operations_per_minute=base_policy.max_file_operations_per_minute
|
||||||
|
// 2,
|
||||||
denied_tools=["delete_*", "destroy_*", "drop_*"],
|
denied_tools=["delete_*", "destroy_*", "drop_*"],
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -294,7 +293,8 @@ def get_policy_for_autonomy_level(level: AutonomyLevel) -> SafetyPolicy:
|
|||||||
max_tokens_per_day=base_policy.max_tokens_per_day * 5,
|
max_tokens_per_day=base_policy.max_tokens_per_day * 5,
|
||||||
max_actions_per_minute=base_policy.max_actions_per_minute * 2,
|
max_actions_per_minute=base_policy.max_actions_per_minute * 2,
|
||||||
max_llm_calls_per_minute=base_policy.max_llm_calls_per_minute * 2,
|
max_llm_calls_per_minute=base_policy.max_llm_calls_per_minute * 2,
|
||||||
max_file_operations_per_minute=base_policy.max_file_operations_per_minute * 2,
|
max_file_operations_per_minute=base_policy.max_file_operations_per_minute
|
||||||
|
* 2,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -260,9 +260,15 @@ class ContentFilter:
|
|||||||
continue
|
continue
|
||||||
if pattern.category == ContentCategory.SECRETS and not enable_secret_filter:
|
if pattern.category == ContentCategory.SECRETS and not enable_secret_filter:
|
||||||
continue
|
continue
|
||||||
if pattern.category == ContentCategory.CREDENTIALS and not enable_secret_filter:
|
if (
|
||||||
|
pattern.category == ContentCategory.CREDENTIALS
|
||||||
|
and not enable_secret_filter
|
||||||
|
):
|
||||||
continue
|
continue
|
||||||
if pattern.category == ContentCategory.INJECTION and not enable_injection_filter:
|
if (
|
||||||
|
pattern.category == ContentCategory.INJECTION
|
||||||
|
and not enable_injection_filter
|
||||||
|
):
|
||||||
continue
|
continue
|
||||||
self._patterns.append(replace(pattern))
|
self._patterns.append(replace(pattern))
|
||||||
|
|
||||||
@@ -343,7 +349,10 @@ class ContentFilter:
|
|||||||
filtered_content = content
|
filtered_content = content
|
||||||
for match in all_matches:
|
for match in all_matches:
|
||||||
matched_pattern = self._get_pattern(match.pattern_name)
|
matched_pattern = self._get_pattern(match.pattern_name)
|
||||||
if matched_pattern and matched_pattern.action in (FilterAction.REDACT, FilterAction.BLOCK):
|
if matched_pattern and matched_pattern.action in (
|
||||||
|
FilterAction.REDACT,
|
||||||
|
FilterAction.BLOCK,
|
||||||
|
):
|
||||||
filtered_content = (
|
filtered_content = (
|
||||||
filtered_content[: match.start_pos]
|
filtered_content[: match.start_pos]
|
||||||
+ (match.redacted_text or "[REDACTED]")
|
+ (match.redacted_text or "[REDACTED]")
|
||||||
@@ -371,8 +380,12 @@ class ContentFilter:
|
|||||||
if raise_on_block:
|
if raise_on_block:
|
||||||
raise ContentFilterError(
|
raise ContentFilterError(
|
||||||
block_reason or "Content blocked",
|
block_reason or "Content blocked",
|
||||||
filter_type=all_matches[0].category.value if all_matches else "unknown",
|
filter_type=all_matches[0].category.value
|
||||||
detected_patterns=[m.pattern_name for m in all_matches] if all_matches else [],
|
if all_matches
|
||||||
|
else "unknown",
|
||||||
|
detected_patterns=[m.pattern_name for m in all_matches]
|
||||||
|
if all_matches
|
||||||
|
else [],
|
||||||
)
|
)
|
||||||
elif all_matches:
|
elif all_matches:
|
||||||
logger.debug(
|
logger.debug(
|
||||||
@@ -480,9 +493,13 @@ class ContentFilter:
|
|||||||
matches = pattern.find_matches(content)
|
matches = pattern.find_matches(content)
|
||||||
for match in matches:
|
for match in matches:
|
||||||
if pattern.action == FilterAction.BLOCK:
|
if pattern.action == FilterAction.BLOCK:
|
||||||
issues.append(f"Blocked: {pattern.name} at position {match.start_pos}")
|
issues.append(
|
||||||
|
f"Blocked: {pattern.name} at position {match.start_pos}"
|
||||||
|
)
|
||||||
elif pattern.action == FilterAction.WARN and not allow_warnings:
|
elif pattern.action == FilterAction.WARN and not allow_warnings:
|
||||||
issues.append(f"Warning: {pattern.name} at position {match.start_pos}")
|
issues.append(
|
||||||
|
f"Warning: {pattern.name} at position {match.start_pos}"
|
||||||
|
)
|
||||||
|
|
||||||
return len(issues) == 0, issues
|
return len(issues) == 0, issues
|
||||||
|
|
||||||
|
|||||||
@@ -69,7 +69,9 @@ class BudgetTracker:
|
|||||||
else 0
|
else 0
|
||||||
)
|
)
|
||||||
|
|
||||||
is_warning = max(token_usage_ratio, cost_usage_ratio) >= self.warning_threshold
|
is_warning = (
|
||||||
|
max(token_usage_ratio, cost_usage_ratio) >= self.warning_threshold
|
||||||
|
)
|
||||||
is_exceeded = (
|
is_exceeded = (
|
||||||
self._tokens_used >= self.tokens_limit
|
self._tokens_used >= self.tokens_limit
|
||||||
or self._cost_used_usd >= self.cost_limit_usd
|
or self._cost_used_usd >= self.cost_limit_usd
|
||||||
@@ -94,12 +96,16 @@ class BudgetTracker:
|
|||||||
reset_at=reset_at,
|
reset_at=reset_at,
|
||||||
)
|
)
|
||||||
|
|
||||||
async def check_budget(self, estimated_tokens: int, estimated_cost_usd: float) -> bool:
|
async def check_budget(
|
||||||
|
self, estimated_tokens: int, estimated_cost_usd: float
|
||||||
|
) -> bool:
|
||||||
"""Check if there's enough budget for an operation."""
|
"""Check if there's enough budget for an operation."""
|
||||||
async with self._lock:
|
async with self._lock:
|
||||||
self._check_reset()
|
self._check_reset()
|
||||||
|
|
||||||
would_exceed_tokens = (self._tokens_used + estimated_tokens) > self.tokens_limit
|
would_exceed_tokens = (
|
||||||
|
self._tokens_used + estimated_tokens
|
||||||
|
) > self.tokens_limit
|
||||||
would_exceed_cost = (
|
would_exceed_cost = (
|
||||||
self._cost_used_usd + estimated_cost_usd
|
self._cost_used_usd + estimated_cost_usd
|
||||||
) > self.cost_limit_usd
|
) > self.cost_limit_usd
|
||||||
@@ -241,13 +247,13 @@ class CostController:
|
|||||||
session_tracker = await self.get_or_create_tracker(
|
session_tracker = await self.get_or_create_tracker(
|
||||||
BudgetScope.SESSION, session_id
|
BudgetScope.SESSION, session_id
|
||||||
)
|
)
|
||||||
if not await session_tracker.check_budget(estimated_tokens, estimated_cost_usd):
|
if not await session_tracker.check_budget(
|
||||||
|
estimated_tokens, estimated_cost_usd
|
||||||
|
):
|
||||||
return False
|
return False
|
||||||
|
|
||||||
# Check agent daily budget
|
# Check agent daily budget
|
||||||
agent_tracker = await self.get_or_create_tracker(
|
agent_tracker = await self.get_or_create_tracker(BudgetScope.DAILY, agent_id)
|
||||||
BudgetScope.DAILY, agent_id
|
|
||||||
)
|
|
||||||
if not await agent_tracker.check_budget(estimated_tokens, estimated_cost_usd):
|
if not await agent_tracker.check_budget(estimated_tokens, estimated_cost_usd):
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|||||||
@@ -253,7 +253,9 @@ class EmergencyControls:
|
|||||||
self._on_resume_callbacks,
|
self._on_resume_callbacks,
|
||||||
{"scope": scope, "resumed_by": resumed_by},
|
{"scope": scope, "resumed_by": resumed_by},
|
||||||
)
|
)
|
||||||
await self._notify_handlers("resume", {"scope": scope, "resumed_by": resumed_by})
|
await self._notify_handlers(
|
||||||
|
"resume", {"scope": scope, "resumed_by": resumed_by}
|
||||||
|
)
|
||||||
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
|||||||
@@ -266,9 +266,7 @@ class SafetyGuardian:
|
|||||||
|
|
||||||
except SafetyError as e:
|
except SafetyError as e:
|
||||||
# Known safety error
|
# Known safety error
|
||||||
return await self._create_denial_result(
|
return await self._create_denial_result(action, [str(e)], audit_events)
|
||||||
action, [str(e)], audit_events
|
|
||||||
)
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
# Unknown error - fail closed in strict mode
|
# Unknown error - fail closed in strict mode
|
||||||
logger.error("Unexpected error in safety validation: %s", e)
|
logger.error("Unexpected error in safety validation: %s", e)
|
||||||
@@ -391,7 +389,9 @@ class SafetyGuardian:
|
|||||||
if action.tool_name:
|
if action.tool_name:
|
||||||
for pattern in policy.denied_tools:
|
for pattern in policy.denied_tools:
|
||||||
if self._matches_pattern(action.tool_name, pattern):
|
if self._matches_pattern(action.tool_name, pattern):
|
||||||
reasons.append(f"Tool '{action.tool_name}' denied by pattern '{pattern}'")
|
reasons.append(
|
||||||
|
f"Tool '{action.tool_name}' denied by pattern '{pattern}'"
|
||||||
|
)
|
||||||
return GuardianResult(
|
return GuardianResult(
|
||||||
action_id=action.id,
|
action_id=action.id,
|
||||||
allowed=False,
|
allowed=False,
|
||||||
@@ -419,7 +419,9 @@ class SafetyGuardian:
|
|||||||
if action.resource:
|
if action.resource:
|
||||||
for pattern in policy.denied_file_patterns:
|
for pattern in policy.denied_file_patterns:
|
||||||
if self._matches_pattern(action.resource, pattern):
|
if self._matches_pattern(action.resource, pattern):
|
||||||
reasons.append(f"Resource '{action.resource}' denied by pattern '{pattern}'")
|
reasons.append(
|
||||||
|
f"Resource '{action.resource}' denied by pattern '{pattern}'"
|
||||||
|
)
|
||||||
return GuardianResult(
|
return GuardianResult(
|
||||||
action_id=action.id,
|
action_id=action.id,
|
||||||
allowed=False,
|
allowed=False,
|
||||||
|
|||||||
@@ -134,7 +134,9 @@ class LoopDetector:
|
|||||||
raise LoopDetectedError(
|
raise LoopDetectedError(
|
||||||
f"Loop detected: {loop_type}",
|
f"Loop detected: {loop_type}",
|
||||||
loop_type=loop_type or "unknown",
|
loop_type=loop_type or "unknown",
|
||||||
repetition_count=self._max_exact if loop_type == "exact" else self._max_semantic,
|
repetition_count=self._max_exact
|
||||||
|
if loop_type == "exact"
|
||||||
|
else self._max_semantic,
|
||||||
action_pattern=[signature.semantic_key()],
|
action_pattern=[signature.semantic_key()],
|
||||||
agent_id=action.metadata.agent_id,
|
agent_id=action.metadata.agent_id,
|
||||||
action_id=action.id,
|
action_id=action.id,
|
||||||
|
|||||||
@@ -293,7 +293,9 @@ class MCPSafetyWrapper:
|
|||||||
action_type=action_type,
|
action_type=action_type,
|
||||||
tool_name=tool_call.tool_name,
|
tool_name=tool_call.tool_name,
|
||||||
arguments=tool_call.arguments,
|
arguments=tool_call.arguments,
|
||||||
resource=tool_call.arguments.get("path", tool_call.arguments.get("resource")),
|
resource=tool_call.arguments.get(
|
||||||
|
"path", tool_call.arguments.get("resource")
|
||||||
|
),
|
||||||
metadata=metadata,
|
metadata=metadata,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -302,7 +304,9 @@ class MCPSafetyWrapper:
|
|||||||
tool_lower = tool_name.lower()
|
tool_lower = tool_name.lower()
|
||||||
|
|
||||||
# Check destructive patterns
|
# Check destructive patterns
|
||||||
if any(d in tool_lower for d in ["write", "create", "delete", "remove", "update"]):
|
if any(
|
||||||
|
d in tool_lower for d in ["write", "create", "delete", "remove", "update"]
|
||||||
|
):
|
||||||
if "file" in tool_lower:
|
if "file" in tool_lower:
|
||||||
if "delete" in tool_lower or "remove" in tool_lower:
|
if "delete" in tool_lower or "remove" in tool_lower:
|
||||||
return ActionType.FILE_DELETE
|
return ActionType.FILE_DELETE
|
||||||
|
|||||||
@@ -69,7 +69,18 @@ class SafetyMetrics:
|
|||||||
|
|
||||||
def _init_histogram_buckets(self) -> None:
|
def _init_histogram_buckets(self) -> None:
|
||||||
"""Initialize histogram buckets for latency metrics."""
|
"""Initialize histogram buckets for latency metrics."""
|
||||||
latency_buckets = [0.01, 0.05, 0.1, 0.25, 0.5, 1.0, 2.5, 5.0, 10.0, float("inf")]
|
latency_buckets = [
|
||||||
|
0.01,
|
||||||
|
0.05,
|
||||||
|
0.1,
|
||||||
|
0.25,
|
||||||
|
0.5,
|
||||||
|
1.0,
|
||||||
|
2.5,
|
||||||
|
5.0,
|
||||||
|
10.0,
|
||||||
|
float("inf"),
|
||||||
|
]
|
||||||
|
|
||||||
for name in [
|
for name in [
|
||||||
"validation_latency_seconds",
|
"validation_latency_seconds",
|
||||||
@@ -321,7 +332,8 @@ class SafetyMetrics:
|
|||||||
async with self._lock:
|
async with self._lock:
|
||||||
total_validations = sum(self._counters["safety_validations_total"].values())
|
total_validations = sum(self._counters["safety_validations_total"].values())
|
||||||
denied_validations = sum(
|
denied_validations = sum(
|
||||||
v for k, v in self._counters["safety_validations_total"].items()
|
v
|
||||||
|
for k, v in self._counters["safety_validations_total"].items()
|
||||||
if "decision=deny" in k
|
if "decision=deny" in k
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -358,11 +370,13 @@ class SafetyMetrics:
|
|||||||
"rollbacks_executed": sum(
|
"rollbacks_executed": sum(
|
||||||
self._counters["safety_rollbacks_total"].values()
|
self._counters["safety_rollbacks_total"].values()
|
||||||
),
|
),
|
||||||
"mcp_calls": sum(
|
"mcp_calls": sum(self._counters["safety_mcp_calls_total"].values()),
|
||||||
self._counters["safety_mcp_calls_total"].values()
|
"pending_approvals": self._gauges.get(
|
||||||
),
|
"safety_pending_approvals", {}
|
||||||
"pending_approvals": self._gauges.get("safety_pending_approvals", {}).get("", 0),
|
).get("", 0),
|
||||||
"active_checkpoints": self._gauges.get("safety_active_checkpoints", {}).get("", 0),
|
"active_checkpoints": self._gauges.get(
|
||||||
|
"safety_active_checkpoints", {}
|
||||||
|
).get("", 0),
|
||||||
}
|
}
|
||||||
|
|
||||||
async def reset(self) -> None:
|
async def reset(self) -> None:
|
||||||
|
|||||||
@@ -212,9 +212,7 @@ class ValidationResult(BaseModel):
|
|||||||
applied_rules: list[str] = Field(
|
applied_rules: list[str] = Field(
|
||||||
default_factory=list, description="IDs of applied rules"
|
default_factory=list, description="IDs of applied rules"
|
||||||
)
|
)
|
||||||
reasons: list[str] = Field(
|
reasons: list[str] = Field(default_factory=list, description="Reasons for decision")
|
||||||
default_factory=list, description="Reasons for decision"
|
|
||||||
)
|
|
||||||
approval_id: str | None = Field(None, description="Approval request ID if needed")
|
approval_id: str | None = Field(None, description="Approval request ID if needed")
|
||||||
retry_after_seconds: float | None = Field(
|
retry_after_seconds: float | None = Field(
|
||||||
None, description="Retry delay if rate limited"
|
None, description="Retry delay if rate limited"
|
||||||
@@ -267,9 +265,7 @@ class RateLimitConfig(BaseModel):
|
|||||||
limit: int = Field(..., description="Maximum allowed in window")
|
limit: int = Field(..., description="Maximum allowed in window")
|
||||||
window_seconds: int = Field(60, description="Time window in seconds")
|
window_seconds: int = Field(60, description="Time window in seconds")
|
||||||
burst_limit: int | None = Field(None, description="Burst allowance")
|
burst_limit: int | None = Field(None, description="Burst allowance")
|
||||||
slowdown_threshold: float = Field(
|
slowdown_threshold: float = Field(0.8, description="Start slowing at this fraction")
|
||||||
0.8, description="Start slowing at this fraction"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class RateLimitStatus(BaseModel):
|
class RateLimitStatus(BaseModel):
|
||||||
|
|||||||
@@ -66,9 +66,7 @@ class RollbackManager:
|
|||||||
"""
|
"""
|
||||||
config = get_safety_config()
|
config = get_safety_config()
|
||||||
|
|
||||||
self._checkpoint_dir = Path(
|
self._checkpoint_dir = Path(checkpoint_dir or config.checkpoint_dir)
|
||||||
checkpoint_dir or config.checkpoint_dir
|
|
||||||
)
|
|
||||||
self._retention_hours = retention_hours or config.checkpoint_retention_hours
|
self._retention_hours = retention_hours or config.checkpoint_retention_hours
|
||||||
|
|
||||||
self._checkpoints: dict[str, Checkpoint] = {}
|
self._checkpoints: dict[str, Checkpoint] = {}
|
||||||
@@ -231,7 +229,9 @@ class RollbackManager:
|
|||||||
success=success,
|
success=success,
|
||||||
actions_rolled_back=actions_rolled_back,
|
actions_rolled_back=actions_rolled_back,
|
||||||
failed_actions=failed_actions,
|
failed_actions=failed_actions,
|
||||||
error=None if success else f"Failed to rollback {len(failed_actions)} items",
|
error=None
|
||||||
|
if success
|
||||||
|
else f"Failed to rollback {len(failed_actions)} items",
|
||||||
)
|
)
|
||||||
|
|
||||||
if success:
|
if success:
|
||||||
@@ -294,8 +294,7 @@ class RollbackManager:
|
|||||||
|
|
||||||
if not include_expired:
|
if not include_expired:
|
||||||
checkpoints = [
|
checkpoints = [
|
||||||
c for c in checkpoints
|
c for c in checkpoints if c.expires_at is None or c.expires_at > now
|
||||||
if c.expires_at is None or c.expires_at > now
|
|
||||||
]
|
]
|
||||||
|
|
||||||
return checkpoints
|
return checkpoints
|
||||||
|
|||||||
@@ -113,7 +113,9 @@ class ActionValidator:
|
|||||||
self._rules.append(rule)
|
self._rules.append(rule)
|
||||||
# Re-sort by priority (higher first)
|
# Re-sort by priority (higher first)
|
||||||
self._rules.sort(key=lambda r: r.priority, reverse=True)
|
self._rules.sort(key=lambda r: r.priority, reverse=True)
|
||||||
logger.debug("Added validation rule: %s (priority %d)", rule.name, rule.priority)
|
logger.debug(
|
||||||
|
"Added validation rule: %s (priority %d)", rule.name, rule.priority
|
||||||
|
)
|
||||||
|
|
||||||
def remove_rule(self, rule_id: str) -> bool:
|
def remove_rule(self, rule_id: str) -> bool:
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -44,8 +44,8 @@ def mock_superuser():
|
|||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def client(mock_mcp_client, mock_superuser):
|
def client(mock_mcp_client, mock_superuser):
|
||||||
"""Create a FastAPI test client with mocked dependencies."""
|
"""Create a FastAPI test client with mocked dependencies."""
|
||||||
from app.api.routes.mcp import get_mcp_client
|
|
||||||
from app.api.dependencies.permissions import require_superuser
|
from app.api.dependencies.permissions import require_superuser
|
||||||
|
from app.api.routes.mcp import get_mcp_client
|
||||||
|
|
||||||
# Override dependencies
|
# Override dependencies
|
||||||
async def override_get_mcp_client():
|
async def override_get_mcp_client():
|
||||||
|
|||||||
@@ -14,7 +14,6 @@ from app.services.mcp.client_manager import (
|
|||||||
shutdown_mcp_client,
|
shutdown_mcp_client,
|
||||||
)
|
)
|
||||||
from app.services.mcp.config import MCPConfig, MCPServerConfig
|
from app.services.mcp.config import MCPConfig, MCPServerConfig
|
||||||
from app.services.mcp.connection import ConnectionState
|
|
||||||
from app.services.mcp.exceptions import MCPServerNotFoundError
|
from app.services.mcp.exceptions import MCPServerNotFoundError
|
||||||
from app.services.mcp.registry import MCPServerRegistry
|
from app.services.mcp.registry import MCPServerRegistry
|
||||||
from app.services.mcp.routing import ToolInfo, ToolResult
|
from app.services.mcp.routing import ToolInfo, ToolResult
|
||||||
|
|||||||
@@ -4,10 +4,8 @@ Tests for MCP Configuration System
|
|||||||
|
|
||||||
import os
|
import os
|
||||||
import tempfile
|
import tempfile
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
import yaml
|
|
||||||
|
|
||||||
from app.services.mcp.config import (
|
from app.services.mcp.config import (
|
||||||
MCPConfig,
|
MCPConfig,
|
||||||
@@ -217,9 +215,7 @@ mcp_servers:
|
|||||||
default_timeout: 60
|
default_timeout: 60
|
||||||
connection_pool_size: 20
|
connection_pool_size: 20
|
||||||
"""
|
"""
|
||||||
with tempfile.NamedTemporaryFile(
|
with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f:
|
||||||
mode="w", suffix=".yaml", delete=False
|
|
||||||
) as f:
|
|
||||||
f.write(yaml_content)
|
f.write(yaml_content)
|
||||||
f.flush()
|
f.flush()
|
||||||
|
|
||||||
@@ -248,9 +244,7 @@ mcp_servers:
|
|||||||
explicit-server:
|
explicit-server:
|
||||||
url: http://explicit:8000
|
url: http://explicit:8000
|
||||||
"""
|
"""
|
||||||
with tempfile.NamedTemporaryFile(
|
with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f:
|
||||||
mode="w", suffix=".yaml", delete=False
|
|
||||||
) as f:
|
|
||||||
f.write(yaml_content)
|
f.write(yaml_content)
|
||||||
f.flush()
|
f.flush()
|
||||||
|
|
||||||
@@ -267,9 +261,7 @@ mcp_servers:
|
|||||||
env-server:
|
env-server:
|
||||||
url: http://env:8000
|
url: http://env:8000
|
||||||
"""
|
"""
|
||||||
with tempfile.NamedTemporaryFile(
|
with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f:
|
||||||
mode="w", suffix=".yaml", delete=False
|
|
||||||
) as f:
|
|
||||||
f.write(yaml_content)
|
f.write(yaml_content)
|
||||||
f.flush()
|
f.flush()
|
||||||
|
|
||||||
|
|||||||
@@ -220,9 +220,7 @@ class TestMCPConnection:
|
|||||||
MockClient.return_value = mock_client
|
MockClient.return_value = mock_client
|
||||||
|
|
||||||
await conn.connect()
|
await conn.connect()
|
||||||
result = await conn.execute_request(
|
result = await conn.execute_request("POST", "/mcp", data={"method": "test"})
|
||||||
"POST", "/mcp", data={"method": "test"}
|
|
||||||
)
|
|
||||||
|
|
||||||
assert result == {"result": "success"}
|
assert result == {"result": "success"}
|
||||||
|
|
||||||
|
|||||||
@@ -160,11 +160,21 @@ class TestMCPToolNotFoundError:
|
|||||||
"""Test tool not found with available tools listed."""
|
"""Test tool not found with available tools listed."""
|
||||||
error = MCPToolNotFoundError(
|
error = MCPToolNotFoundError(
|
||||||
"unknown-tool",
|
"unknown-tool",
|
||||||
available_tools=["tool-1", "tool-2", "tool-3", "tool-4", "tool-5", "tool-6"],
|
available_tools=[
|
||||||
|
"tool-1",
|
||||||
|
"tool-2",
|
||||||
|
"tool-3",
|
||||||
|
"tool-4",
|
||||||
|
"tool-5",
|
||||||
|
"tool-6",
|
||||||
|
],
|
||||||
)
|
)
|
||||||
assert len(error.available_tools) == 6
|
assert len(error.available_tools) == 6
|
||||||
# Should show first 5 tools with ellipsis
|
# Should show first 5 tools with ellipsis
|
||||||
assert "available_tools=['tool-1', 'tool-2', 'tool-3', 'tool-4', 'tool-5']..." in str(error)
|
assert (
|
||||||
|
"available_tools=['tool-1', 'tool-2', 'tool-3', 'tool-4', 'tool-5']..."
|
||||||
|
in str(error)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class TestMCPCircuitOpenError:
|
class TestMCPCircuitOpenError:
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ Tests for MCP Server Registry
|
|||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from app.services.mcp.config import MCPConfig, MCPServerConfig, TransportType
|
from app.services.mcp.config import MCPConfig, MCPServerConfig
|
||||||
from app.services.mcp.exceptions import MCPServerNotFoundError
|
from app.services.mcp.exceptions import MCPServerNotFoundError
|
||||||
from app.services.mcp.registry import (
|
from app.services.mcp.registry import (
|
||||||
MCPServerRegistry,
|
MCPServerRegistry,
|
||||||
|
|||||||
@@ -220,7 +220,9 @@ class TestScan:
|
|||||||
filter_all: ContentFilter,
|
filter_all: ContentFilter,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Test scanning for specific categories only."""
|
"""Test scanning for specific categories only."""
|
||||||
content = "Email: test@example.com, token: ghp_abc123456789012345678901234567890123"
|
content = (
|
||||||
|
"Email: test@example.com, token: ghp_abc123456789012345678901234567890123"
|
||||||
|
)
|
||||||
|
|
||||||
# Scan only for secrets
|
# Scan only for secrets
|
||||||
matches = await filter_all.scan(
|
matches = await filter_all.scan(
|
||||||
|
|||||||
@@ -321,8 +321,7 @@ class TestLoadRulesFromPolicy:
|
|||||||
validator.load_rules_from_policy(policy)
|
validator.load_rules_from_policy(policy)
|
||||||
|
|
||||||
approval_rules = [
|
approval_rules = [
|
||||||
r for r in validator._rules
|
r for r in validator._rules if r.decision == SafetyDecision.REQUIRE_APPROVAL
|
||||||
if r.decision == SafetyDecision.REQUIRE_APPROVAL
|
|
||||||
]
|
]
|
||||||
assert len(approval_rules) == 1
|
assert len(approval_rules) == 1
|
||||||
|
|
||||||
|
|||||||
@@ -162,7 +162,9 @@ export default function ProjectSettingsPage({ params }: ProjectSettingsPageProps
|
|||||||
<Card>
|
<Card>
|
||||||
<CardHeader>
|
<CardHeader>
|
||||||
<CardTitle>Autonomy Level</CardTitle>
|
<CardTitle>Autonomy Level</CardTitle>
|
||||||
<CardDescription>Control how much oversight you want over agent actions</CardDescription>
|
<CardDescription>
|
||||||
|
Control how much oversight you want over agent actions
|
||||||
|
</CardDescription>
|
||||||
</CardHeader>
|
</CardHeader>
|
||||||
<CardContent className="space-y-4">
|
<CardContent className="space-y-4">
|
||||||
<div className="space-y-2">
|
<div className="space-y-2">
|
||||||
|
|||||||
Reference in New Issue
Block a user