refactor(safety): apply consistent formatting across services and tests

Improved code readability and uniformity by standardizing line breaks, indentation, and inline conditions across safety-related services, models, and tests, including content filters, validation rules, and emergency controls.
This commit is contained in:
2026-01-03 16:23:39 +01:00
parent 065e43c5a9
commit 520c06175e
23 changed files with 123 additions and 81 deletions

View File

@@ -74,7 +74,9 @@ class ToolInfoResponse(BaseModel):
name: str = Field(..., description="Tool name") name: str = Field(..., description="Tool name")
description: str | None = Field(None, description="Tool description") description: str | None = Field(None, description="Tool description")
server_name: str | None = Field(None, description="Server providing the tool") server_name: str | None = Field(None, description="Server providing the tool")
input_schema: dict[str, Any] | None = Field(None, description="JSON schema for input") input_schema: dict[str, Any] | None = Field(
None, description="JSON schema for input"
)
class ToolListResponse(BaseModel): class ToolListResponse(BaseModel):

View File

@@ -158,9 +158,7 @@ class MCPConfig(BaseModel):
def get_enabled_servers(self) -> dict[str, MCPServerConfig]: def get_enabled_servers(self) -> dict[str, MCPServerConfig]:
"""Get all enabled server configurations.""" """Get all enabled server configurations."""
return { return {
name: config name: config for name, config in self.mcp_servers.items() if config.enabled
for name, config in self.mcp_servers.items()
if config.enabled
} }
def list_server_names(self) -> list[str]: def list_server_names(self) -> list[str]:

View File

@@ -196,9 +196,7 @@ class AuditLogger:
) -> AuditEvent: ) -> AuditEvent:
"""Log an action execution result.""" """Log an action execution result."""
event_type = ( event_type = (
AuditEventType.ACTION_EXECUTED AuditEventType.ACTION_EXECUTED if success else AuditEventType.ACTION_FAILED
if success
else AuditEventType.ACTION_FAILED
) )
return await self.log( return await self.log(
@@ -477,9 +475,7 @@ class AuditLogger:
"user_id": event.user_id, "user_id": event.user_id,
"decision": event.decision.value if event.decision else None, "decision": event.decision.value if event.decision else None,
"details": { "details": {
k: v k: v for k, v in event.details.items() if not k.startswith("_")
for k, v in event.details.items()
if not k.startswith("_")
}, },
"correlation_id": event.correlation_id, "correlation_id": event.correlation_id,
} }

View File

@@ -31,9 +31,7 @@ class SafetyConfig(BaseSettings):
# General settings # General settings
enabled: bool = Field(True, description="Enable safety framework") enabled: bool = Field(True, description="Enable safety framework")
strict_mode: bool = Field( strict_mode: bool = Field(True, description="Strict mode (fail closed on errors)")
True, description="Strict mode (fail closed on errors)"
)
log_level: str = Field("INFO", description="Logging level") log_level: str = Field("INFO", description="Logging level")
# Default autonomy level # Default autonomy level
@@ -255,7 +253,8 @@ def get_policy_for_autonomy_level(level: AutonomyLevel) -> SafetyPolicy:
max_tokens_per_day=base_policy.max_tokens_per_day // 10, max_tokens_per_day=base_policy.max_tokens_per_day // 10,
max_actions_per_minute=base_policy.max_actions_per_minute // 2, max_actions_per_minute=base_policy.max_actions_per_minute // 2,
max_llm_calls_per_minute=base_policy.max_llm_calls_per_minute // 2, max_llm_calls_per_minute=base_policy.max_llm_calls_per_minute // 2,
max_file_operations_per_minute=base_policy.max_file_operations_per_minute // 2, max_file_operations_per_minute=base_policy.max_file_operations_per_minute
// 2,
denied_tools=["delete_*", "destroy_*", "drop_*"], denied_tools=["delete_*", "destroy_*", "drop_*"],
) )
@@ -294,7 +293,8 @@ def get_policy_for_autonomy_level(level: AutonomyLevel) -> SafetyPolicy:
max_tokens_per_day=base_policy.max_tokens_per_day * 5, max_tokens_per_day=base_policy.max_tokens_per_day * 5,
max_actions_per_minute=base_policy.max_actions_per_minute * 2, max_actions_per_minute=base_policy.max_actions_per_minute * 2,
max_llm_calls_per_minute=base_policy.max_llm_calls_per_minute * 2, max_llm_calls_per_minute=base_policy.max_llm_calls_per_minute * 2,
max_file_operations_per_minute=base_policy.max_file_operations_per_minute * 2, max_file_operations_per_minute=base_policy.max_file_operations_per_minute
* 2,
) )

View File

@@ -260,9 +260,15 @@ class ContentFilter:
continue continue
if pattern.category == ContentCategory.SECRETS and not enable_secret_filter: if pattern.category == ContentCategory.SECRETS and not enable_secret_filter:
continue continue
if pattern.category == ContentCategory.CREDENTIALS and not enable_secret_filter: if (
pattern.category == ContentCategory.CREDENTIALS
and not enable_secret_filter
):
continue continue
if pattern.category == ContentCategory.INJECTION and not enable_injection_filter: if (
pattern.category == ContentCategory.INJECTION
and not enable_injection_filter
):
continue continue
self._patterns.append(replace(pattern)) self._patterns.append(replace(pattern))
@@ -343,7 +349,10 @@ class ContentFilter:
filtered_content = content filtered_content = content
for match in all_matches: for match in all_matches:
matched_pattern = self._get_pattern(match.pattern_name) matched_pattern = self._get_pattern(match.pattern_name)
if matched_pattern and matched_pattern.action in (FilterAction.REDACT, FilterAction.BLOCK): if matched_pattern and matched_pattern.action in (
FilterAction.REDACT,
FilterAction.BLOCK,
):
filtered_content = ( filtered_content = (
filtered_content[: match.start_pos] filtered_content[: match.start_pos]
+ (match.redacted_text or "[REDACTED]") + (match.redacted_text or "[REDACTED]")
@@ -371,8 +380,12 @@ class ContentFilter:
if raise_on_block: if raise_on_block:
raise ContentFilterError( raise ContentFilterError(
block_reason or "Content blocked", block_reason or "Content blocked",
filter_type=all_matches[0].category.value if all_matches else "unknown", filter_type=all_matches[0].category.value
detected_patterns=[m.pattern_name for m in all_matches] if all_matches else [], if all_matches
else "unknown",
detected_patterns=[m.pattern_name for m in all_matches]
if all_matches
else [],
) )
elif all_matches: elif all_matches:
logger.debug( logger.debug(
@@ -480,9 +493,13 @@ class ContentFilter:
matches = pattern.find_matches(content) matches = pattern.find_matches(content)
for match in matches: for match in matches:
if pattern.action == FilterAction.BLOCK: if pattern.action == FilterAction.BLOCK:
issues.append(f"Blocked: {pattern.name} at position {match.start_pos}") issues.append(
f"Blocked: {pattern.name} at position {match.start_pos}"
)
elif pattern.action == FilterAction.WARN and not allow_warnings: elif pattern.action == FilterAction.WARN and not allow_warnings:
issues.append(f"Warning: {pattern.name} at position {match.start_pos}") issues.append(
f"Warning: {pattern.name} at position {match.start_pos}"
)
return len(issues) == 0, issues return len(issues) == 0, issues

View File

@@ -69,7 +69,9 @@ class BudgetTracker:
else 0 else 0
) )
is_warning = max(token_usage_ratio, cost_usage_ratio) >= self.warning_threshold is_warning = (
max(token_usage_ratio, cost_usage_ratio) >= self.warning_threshold
)
is_exceeded = ( is_exceeded = (
self._tokens_used >= self.tokens_limit self._tokens_used >= self.tokens_limit
or self._cost_used_usd >= self.cost_limit_usd or self._cost_used_usd >= self.cost_limit_usd
@@ -94,12 +96,16 @@ class BudgetTracker:
reset_at=reset_at, reset_at=reset_at,
) )
async def check_budget(self, estimated_tokens: int, estimated_cost_usd: float) -> bool: async def check_budget(
self, estimated_tokens: int, estimated_cost_usd: float
) -> bool:
"""Check if there's enough budget for an operation.""" """Check if there's enough budget for an operation."""
async with self._lock: async with self._lock:
self._check_reset() self._check_reset()
would_exceed_tokens = (self._tokens_used + estimated_tokens) > self.tokens_limit would_exceed_tokens = (
self._tokens_used + estimated_tokens
) > self.tokens_limit
would_exceed_cost = ( would_exceed_cost = (
self._cost_used_usd + estimated_cost_usd self._cost_used_usd + estimated_cost_usd
) > self.cost_limit_usd ) > self.cost_limit_usd
@@ -241,13 +247,13 @@ class CostController:
session_tracker = await self.get_or_create_tracker( session_tracker = await self.get_or_create_tracker(
BudgetScope.SESSION, session_id BudgetScope.SESSION, session_id
) )
if not await session_tracker.check_budget(estimated_tokens, estimated_cost_usd): if not await session_tracker.check_budget(
estimated_tokens, estimated_cost_usd
):
return False return False
# Check agent daily budget # Check agent daily budget
agent_tracker = await self.get_or_create_tracker( agent_tracker = await self.get_or_create_tracker(BudgetScope.DAILY, agent_id)
BudgetScope.DAILY, agent_id
)
if not await agent_tracker.check_budget(estimated_tokens, estimated_cost_usd): if not await agent_tracker.check_budget(estimated_tokens, estimated_cost_usd):
return False return False

View File

@@ -253,7 +253,9 @@ class EmergencyControls:
self._on_resume_callbacks, self._on_resume_callbacks,
{"scope": scope, "resumed_by": resumed_by}, {"scope": scope, "resumed_by": resumed_by},
) )
await self._notify_handlers("resume", {"scope": scope, "resumed_by": resumed_by}) await self._notify_handlers(
"resume", {"scope": scope, "resumed_by": resumed_by}
)
return True return True

View File

@@ -266,9 +266,7 @@ class SafetyGuardian:
except SafetyError as e: except SafetyError as e:
# Known safety error # Known safety error
return await self._create_denial_result( return await self._create_denial_result(action, [str(e)], audit_events)
action, [str(e)], audit_events
)
except Exception as e: except Exception as e:
# Unknown error - fail closed in strict mode # Unknown error - fail closed in strict mode
logger.error("Unexpected error in safety validation: %s", e) logger.error("Unexpected error in safety validation: %s", e)
@@ -391,7 +389,9 @@ class SafetyGuardian:
if action.tool_name: if action.tool_name:
for pattern in policy.denied_tools: for pattern in policy.denied_tools:
if self._matches_pattern(action.tool_name, pattern): if self._matches_pattern(action.tool_name, pattern):
reasons.append(f"Tool '{action.tool_name}' denied by pattern '{pattern}'") reasons.append(
f"Tool '{action.tool_name}' denied by pattern '{pattern}'"
)
return GuardianResult( return GuardianResult(
action_id=action.id, action_id=action.id,
allowed=False, allowed=False,
@@ -419,7 +419,9 @@ class SafetyGuardian:
if action.resource: if action.resource:
for pattern in policy.denied_file_patterns: for pattern in policy.denied_file_patterns:
if self._matches_pattern(action.resource, pattern): if self._matches_pattern(action.resource, pattern):
reasons.append(f"Resource '{action.resource}' denied by pattern '{pattern}'") reasons.append(
f"Resource '{action.resource}' denied by pattern '{pattern}'"
)
return GuardianResult( return GuardianResult(
action_id=action.id, action_id=action.id,
allowed=False, allowed=False,

View File

@@ -134,7 +134,9 @@ class LoopDetector:
raise LoopDetectedError( raise LoopDetectedError(
f"Loop detected: {loop_type}", f"Loop detected: {loop_type}",
loop_type=loop_type or "unknown", loop_type=loop_type or "unknown",
repetition_count=self._max_exact if loop_type == "exact" else self._max_semantic, repetition_count=self._max_exact
if loop_type == "exact"
else self._max_semantic,
action_pattern=[signature.semantic_key()], action_pattern=[signature.semantic_key()],
agent_id=action.metadata.agent_id, agent_id=action.metadata.agent_id,
action_id=action.id, action_id=action.id,

View File

@@ -293,7 +293,9 @@ class MCPSafetyWrapper:
action_type=action_type, action_type=action_type,
tool_name=tool_call.tool_name, tool_name=tool_call.tool_name,
arguments=tool_call.arguments, arguments=tool_call.arguments,
resource=tool_call.arguments.get("path", tool_call.arguments.get("resource")), resource=tool_call.arguments.get(
"path", tool_call.arguments.get("resource")
),
metadata=metadata, metadata=metadata,
) )
@@ -302,7 +304,9 @@ class MCPSafetyWrapper:
tool_lower = tool_name.lower() tool_lower = tool_name.lower()
# Check destructive patterns # Check destructive patterns
if any(d in tool_lower for d in ["write", "create", "delete", "remove", "update"]): if any(
d in tool_lower for d in ["write", "create", "delete", "remove", "update"]
):
if "file" in tool_lower: if "file" in tool_lower:
if "delete" in tool_lower or "remove" in tool_lower: if "delete" in tool_lower or "remove" in tool_lower:
return ActionType.FILE_DELETE return ActionType.FILE_DELETE

View File

@@ -69,7 +69,18 @@ class SafetyMetrics:
def _init_histogram_buckets(self) -> None: def _init_histogram_buckets(self) -> None:
"""Initialize histogram buckets for latency metrics.""" """Initialize histogram buckets for latency metrics."""
latency_buckets = [0.01, 0.05, 0.1, 0.25, 0.5, 1.0, 2.5, 5.0, 10.0, float("inf")] latency_buckets = [
0.01,
0.05,
0.1,
0.25,
0.5,
1.0,
2.5,
5.0,
10.0,
float("inf"),
]
for name in [ for name in [
"validation_latency_seconds", "validation_latency_seconds",
@@ -321,7 +332,8 @@ class SafetyMetrics:
async with self._lock: async with self._lock:
total_validations = sum(self._counters["safety_validations_total"].values()) total_validations = sum(self._counters["safety_validations_total"].values())
denied_validations = sum( denied_validations = sum(
v for k, v in self._counters["safety_validations_total"].items() v
for k, v in self._counters["safety_validations_total"].items()
if "decision=deny" in k if "decision=deny" in k
) )
@@ -358,11 +370,13 @@ class SafetyMetrics:
"rollbacks_executed": sum( "rollbacks_executed": sum(
self._counters["safety_rollbacks_total"].values() self._counters["safety_rollbacks_total"].values()
), ),
"mcp_calls": sum( "mcp_calls": sum(self._counters["safety_mcp_calls_total"].values()),
self._counters["safety_mcp_calls_total"].values() "pending_approvals": self._gauges.get(
), "safety_pending_approvals", {}
"pending_approvals": self._gauges.get("safety_pending_approvals", {}).get("", 0), ).get("", 0),
"active_checkpoints": self._gauges.get("safety_active_checkpoints", {}).get("", 0), "active_checkpoints": self._gauges.get(
"safety_active_checkpoints", {}
).get("", 0),
} }
async def reset(self) -> None: async def reset(self) -> None:

View File

@@ -212,9 +212,7 @@ class ValidationResult(BaseModel):
applied_rules: list[str] = Field( applied_rules: list[str] = Field(
default_factory=list, description="IDs of applied rules" default_factory=list, description="IDs of applied rules"
) )
reasons: list[str] = Field( reasons: list[str] = Field(default_factory=list, description="Reasons for decision")
default_factory=list, description="Reasons for decision"
)
approval_id: str | None = Field(None, description="Approval request ID if needed") approval_id: str | None = Field(None, description="Approval request ID if needed")
retry_after_seconds: float | None = Field( retry_after_seconds: float | None = Field(
None, description="Retry delay if rate limited" None, description="Retry delay if rate limited"
@@ -267,9 +265,7 @@ class RateLimitConfig(BaseModel):
limit: int = Field(..., description="Maximum allowed in window") limit: int = Field(..., description="Maximum allowed in window")
window_seconds: int = Field(60, description="Time window in seconds") window_seconds: int = Field(60, description="Time window in seconds")
burst_limit: int | None = Field(None, description="Burst allowance") burst_limit: int | None = Field(None, description="Burst allowance")
slowdown_threshold: float = Field( slowdown_threshold: float = Field(0.8, description="Start slowing at this fraction")
0.8, description="Start slowing at this fraction"
)
class RateLimitStatus(BaseModel): class RateLimitStatus(BaseModel):

View File

@@ -66,9 +66,7 @@ class RollbackManager:
""" """
config = get_safety_config() config = get_safety_config()
self._checkpoint_dir = Path( self._checkpoint_dir = Path(checkpoint_dir or config.checkpoint_dir)
checkpoint_dir or config.checkpoint_dir
)
self._retention_hours = retention_hours or config.checkpoint_retention_hours self._retention_hours = retention_hours or config.checkpoint_retention_hours
self._checkpoints: dict[str, Checkpoint] = {} self._checkpoints: dict[str, Checkpoint] = {}
@@ -231,7 +229,9 @@ class RollbackManager:
success=success, success=success,
actions_rolled_back=actions_rolled_back, actions_rolled_back=actions_rolled_back,
failed_actions=failed_actions, failed_actions=failed_actions,
error=None if success else f"Failed to rollback {len(failed_actions)} items", error=None
if success
else f"Failed to rollback {len(failed_actions)} items",
) )
if success: if success:
@@ -294,8 +294,7 @@ class RollbackManager:
if not include_expired: if not include_expired:
checkpoints = [ checkpoints = [
c for c in checkpoints c for c in checkpoints if c.expires_at is None or c.expires_at > now
if c.expires_at is None or c.expires_at > now
] ]
return checkpoints return checkpoints

View File

@@ -113,7 +113,9 @@ class ActionValidator:
self._rules.append(rule) self._rules.append(rule)
# Re-sort by priority (higher first) # Re-sort by priority (higher first)
self._rules.sort(key=lambda r: r.priority, reverse=True) self._rules.sort(key=lambda r: r.priority, reverse=True)
logger.debug("Added validation rule: %s (priority %d)", rule.name, rule.priority) logger.debug(
"Added validation rule: %s (priority %d)", rule.name, rule.priority
)
def remove_rule(self, rule_id: str) -> bool: def remove_rule(self, rule_id: str) -> bool:
""" """

View File

@@ -44,8 +44,8 @@ def mock_superuser():
@pytest.fixture @pytest.fixture
def client(mock_mcp_client, mock_superuser): def client(mock_mcp_client, mock_superuser):
"""Create a FastAPI test client with mocked dependencies.""" """Create a FastAPI test client with mocked dependencies."""
from app.api.routes.mcp import get_mcp_client
from app.api.dependencies.permissions import require_superuser from app.api.dependencies.permissions import require_superuser
from app.api.routes.mcp import get_mcp_client
# Override dependencies # Override dependencies
async def override_get_mcp_client(): async def override_get_mcp_client():

View File

@@ -14,7 +14,6 @@ from app.services.mcp.client_manager import (
shutdown_mcp_client, shutdown_mcp_client,
) )
from app.services.mcp.config import MCPConfig, MCPServerConfig from app.services.mcp.config import MCPConfig, MCPServerConfig
from app.services.mcp.connection import ConnectionState
from app.services.mcp.exceptions import MCPServerNotFoundError from app.services.mcp.exceptions import MCPServerNotFoundError
from app.services.mcp.registry import MCPServerRegistry from app.services.mcp.registry import MCPServerRegistry
from app.services.mcp.routing import ToolInfo, ToolResult from app.services.mcp.routing import ToolInfo, ToolResult

View File

@@ -4,10 +4,8 @@ Tests for MCP Configuration System
import os import os
import tempfile import tempfile
from pathlib import Path
import pytest import pytest
import yaml
from app.services.mcp.config import ( from app.services.mcp.config import (
MCPConfig, MCPConfig,
@@ -217,9 +215,7 @@ mcp_servers:
default_timeout: 60 default_timeout: 60
connection_pool_size: 20 connection_pool_size: 20
""" """
with tempfile.NamedTemporaryFile( with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f:
mode="w", suffix=".yaml", delete=False
) as f:
f.write(yaml_content) f.write(yaml_content)
f.flush() f.flush()
@@ -248,9 +244,7 @@ mcp_servers:
explicit-server: explicit-server:
url: http://explicit:8000 url: http://explicit:8000
""" """
with tempfile.NamedTemporaryFile( with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f:
mode="w", suffix=".yaml", delete=False
) as f:
f.write(yaml_content) f.write(yaml_content)
f.flush() f.flush()
@@ -267,9 +261,7 @@ mcp_servers:
env-server: env-server:
url: http://env:8000 url: http://env:8000
""" """
with tempfile.NamedTemporaryFile( with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f:
mode="w", suffix=".yaml", delete=False
) as f:
f.write(yaml_content) f.write(yaml_content)
f.flush() f.flush()

View File

@@ -220,9 +220,7 @@ class TestMCPConnection:
MockClient.return_value = mock_client MockClient.return_value = mock_client
await conn.connect() await conn.connect()
result = await conn.execute_request( result = await conn.execute_request("POST", "/mcp", data={"method": "test"})
"POST", "/mcp", data={"method": "test"}
)
assert result == {"result": "success"} assert result == {"result": "success"}

View File

@@ -160,11 +160,21 @@ class TestMCPToolNotFoundError:
"""Test tool not found with available tools listed.""" """Test tool not found with available tools listed."""
error = MCPToolNotFoundError( error = MCPToolNotFoundError(
"unknown-tool", "unknown-tool",
available_tools=["tool-1", "tool-2", "tool-3", "tool-4", "tool-5", "tool-6"], available_tools=[
"tool-1",
"tool-2",
"tool-3",
"tool-4",
"tool-5",
"tool-6",
],
) )
assert len(error.available_tools) == 6 assert len(error.available_tools) == 6
# Should show first 5 tools with ellipsis # Should show first 5 tools with ellipsis
assert "available_tools=['tool-1', 'tool-2', 'tool-3', 'tool-4', 'tool-5']..." in str(error) assert (
"available_tools=['tool-1', 'tool-2', 'tool-3', 'tool-4', 'tool-5']..."
in str(error)
)
class TestMCPCircuitOpenError: class TestMCPCircuitOpenError:

View File

@@ -4,7 +4,7 @@ Tests for MCP Server Registry
import pytest import pytest
from app.services.mcp.config import MCPConfig, MCPServerConfig, TransportType from app.services.mcp.config import MCPConfig, MCPServerConfig
from app.services.mcp.exceptions import MCPServerNotFoundError from app.services.mcp.exceptions import MCPServerNotFoundError
from app.services.mcp.registry import ( from app.services.mcp.registry import (
MCPServerRegistry, MCPServerRegistry,

View File

@@ -220,7 +220,9 @@ class TestScan:
filter_all: ContentFilter, filter_all: ContentFilter,
) -> None: ) -> None:
"""Test scanning for specific categories only.""" """Test scanning for specific categories only."""
content = "Email: test@example.com, token: ghp_abc123456789012345678901234567890123" content = (
"Email: test@example.com, token: ghp_abc123456789012345678901234567890123"
)
# Scan only for secrets # Scan only for secrets
matches = await filter_all.scan( matches = await filter_all.scan(

View File

@@ -321,8 +321,7 @@ class TestLoadRulesFromPolicy:
validator.load_rules_from_policy(policy) validator.load_rules_from_policy(policy)
approval_rules = [ approval_rules = [
r for r in validator._rules r for r in validator._rules if r.decision == SafetyDecision.REQUIRE_APPROVAL
if r.decision == SafetyDecision.REQUIRE_APPROVAL
] ]
assert len(approval_rules) == 1 assert len(approval_rules) == 1

View File

@@ -162,7 +162,9 @@ export default function ProjectSettingsPage({ params }: ProjectSettingsPageProps
<Card> <Card>
<CardHeader> <CardHeader>
<CardTitle>Autonomy Level</CardTitle> <CardTitle>Autonomy Level</CardTitle>
<CardDescription>Control how much oversight you want over agent actions</CardDescription> <CardDescription>
Control how much oversight you want over agent actions
</CardDescription>
</CardHeader> </CardHeader>
<CardContent className="space-y-4"> <CardContent className="space-y-4">
<div className="space-y-2"> <div className="space-y-2">