refactor(safety): apply consistent formatting across services and tests

Improved code readability and uniformity by standardizing line breaks, indentation, and inline conditions across safety-related services, models, and tests, including content filters, validation rules, and emergency controls.
This commit is contained in:
2026-01-03 16:23:39 +01:00
parent 065e43c5a9
commit 520c06175e
23 changed files with 123 additions and 81 deletions

View File

@@ -74,7 +74,9 @@ class ToolInfoResponse(BaseModel):
name: str = Field(..., description="Tool name")
description: str | None = Field(None, description="Tool description")
server_name: str | None = Field(None, description="Server providing the tool")
input_schema: dict[str, Any] | None = Field(None, description="JSON schema for input")
input_schema: dict[str, Any] | None = Field(
None, description="JSON schema for input"
)
class ToolListResponse(BaseModel):

View File

@@ -158,9 +158,7 @@ class MCPConfig(BaseModel):
def get_enabled_servers(self) -> dict[str, MCPServerConfig]:
"""Get all enabled server configurations."""
return {
name: config
for name, config in self.mcp_servers.items()
if config.enabled
name: config for name, config in self.mcp_servers.items() if config.enabled
}
def list_server_names(self) -> list[str]:

View File

@@ -196,9 +196,7 @@ class AuditLogger:
) -> AuditEvent:
"""Log an action execution result."""
event_type = (
AuditEventType.ACTION_EXECUTED
if success
else AuditEventType.ACTION_FAILED
AuditEventType.ACTION_EXECUTED if success else AuditEventType.ACTION_FAILED
)
return await self.log(
@@ -477,9 +475,7 @@ class AuditLogger:
"user_id": event.user_id,
"decision": event.decision.value if event.decision else None,
"details": {
k: v
for k, v in event.details.items()
if not k.startswith("_")
k: v for k, v in event.details.items() if not k.startswith("_")
},
"correlation_id": event.correlation_id,
}

View File

@@ -31,9 +31,7 @@ class SafetyConfig(BaseSettings):
# General settings
enabled: bool = Field(True, description="Enable safety framework")
strict_mode: bool = Field(
True, description="Strict mode (fail closed on errors)"
)
strict_mode: bool = Field(True, description="Strict mode (fail closed on errors)")
log_level: str = Field("INFO", description="Logging level")
# Default autonomy level
@@ -255,7 +253,8 @@ def get_policy_for_autonomy_level(level: AutonomyLevel) -> SafetyPolicy:
max_tokens_per_day=base_policy.max_tokens_per_day // 10,
max_actions_per_minute=base_policy.max_actions_per_minute // 2,
max_llm_calls_per_minute=base_policy.max_llm_calls_per_minute // 2,
max_file_operations_per_minute=base_policy.max_file_operations_per_minute // 2,
max_file_operations_per_minute=base_policy.max_file_operations_per_minute
// 2,
denied_tools=["delete_*", "destroy_*", "drop_*"],
)
@@ -294,7 +293,8 @@ def get_policy_for_autonomy_level(level: AutonomyLevel) -> SafetyPolicy:
max_tokens_per_day=base_policy.max_tokens_per_day * 5,
max_actions_per_minute=base_policy.max_actions_per_minute * 2,
max_llm_calls_per_minute=base_policy.max_llm_calls_per_minute * 2,
max_file_operations_per_minute=base_policy.max_file_operations_per_minute * 2,
max_file_operations_per_minute=base_policy.max_file_operations_per_minute
* 2,
)

View File

@@ -260,9 +260,15 @@ class ContentFilter:
continue
if pattern.category == ContentCategory.SECRETS and not enable_secret_filter:
continue
if pattern.category == ContentCategory.CREDENTIALS and not enable_secret_filter:
if (
pattern.category == ContentCategory.CREDENTIALS
and not enable_secret_filter
):
continue
if pattern.category == ContentCategory.INJECTION and not enable_injection_filter:
if (
pattern.category == ContentCategory.INJECTION
and not enable_injection_filter
):
continue
self._patterns.append(replace(pattern))
@@ -343,7 +349,10 @@ class ContentFilter:
filtered_content = content
for match in all_matches:
matched_pattern = self._get_pattern(match.pattern_name)
if matched_pattern and matched_pattern.action in (FilterAction.REDACT, FilterAction.BLOCK):
if matched_pattern and matched_pattern.action in (
FilterAction.REDACT,
FilterAction.BLOCK,
):
filtered_content = (
filtered_content[: match.start_pos]
+ (match.redacted_text or "[REDACTED]")
@@ -371,8 +380,12 @@ class ContentFilter:
if raise_on_block:
raise ContentFilterError(
block_reason or "Content blocked",
filter_type=all_matches[0].category.value if all_matches else "unknown",
detected_patterns=[m.pattern_name for m in all_matches] if all_matches else [],
filter_type=all_matches[0].category.value
if all_matches
else "unknown",
detected_patterns=[m.pattern_name for m in all_matches]
if all_matches
else [],
)
elif all_matches:
logger.debug(
@@ -480,9 +493,13 @@ class ContentFilter:
matches = pattern.find_matches(content)
for match in matches:
if pattern.action == FilterAction.BLOCK:
issues.append(f"Blocked: {pattern.name} at position {match.start_pos}")
issues.append(
f"Blocked: {pattern.name} at position {match.start_pos}"
)
elif pattern.action == FilterAction.WARN and not allow_warnings:
issues.append(f"Warning: {pattern.name} at position {match.start_pos}")
issues.append(
f"Warning: {pattern.name} at position {match.start_pos}"
)
return len(issues) == 0, issues

View File

@@ -69,7 +69,9 @@ class BudgetTracker:
else 0
)
is_warning = max(token_usage_ratio, cost_usage_ratio) >= self.warning_threshold
is_warning = (
max(token_usage_ratio, cost_usage_ratio) >= self.warning_threshold
)
is_exceeded = (
self._tokens_used >= self.tokens_limit
or self._cost_used_usd >= self.cost_limit_usd
@@ -94,12 +96,16 @@ class BudgetTracker:
reset_at=reset_at,
)
async def check_budget(self, estimated_tokens: int, estimated_cost_usd: float) -> bool:
async def check_budget(
self, estimated_tokens: int, estimated_cost_usd: float
) -> bool:
"""Check if there's enough budget for an operation."""
async with self._lock:
self._check_reset()
would_exceed_tokens = (self._tokens_used + estimated_tokens) > self.tokens_limit
would_exceed_tokens = (
self._tokens_used + estimated_tokens
) > self.tokens_limit
would_exceed_cost = (
self._cost_used_usd + estimated_cost_usd
) > self.cost_limit_usd
@@ -241,13 +247,13 @@ class CostController:
session_tracker = await self.get_or_create_tracker(
BudgetScope.SESSION, session_id
)
if not await session_tracker.check_budget(estimated_tokens, estimated_cost_usd):
if not await session_tracker.check_budget(
estimated_tokens, estimated_cost_usd
):
return False
# Check agent daily budget
agent_tracker = await self.get_or_create_tracker(
BudgetScope.DAILY, agent_id
)
agent_tracker = await self.get_or_create_tracker(BudgetScope.DAILY, agent_id)
if not await agent_tracker.check_budget(estimated_tokens, estimated_cost_usd):
return False

View File

@@ -253,7 +253,9 @@ class EmergencyControls:
self._on_resume_callbacks,
{"scope": scope, "resumed_by": resumed_by},
)
await self._notify_handlers("resume", {"scope": scope, "resumed_by": resumed_by})
await self._notify_handlers(
"resume", {"scope": scope, "resumed_by": resumed_by}
)
return True

View File

@@ -266,9 +266,7 @@ class SafetyGuardian:
except SafetyError as e:
# Known safety error
return await self._create_denial_result(
action, [str(e)], audit_events
)
return await self._create_denial_result(action, [str(e)], audit_events)
except Exception as e:
# Unknown error - fail closed in strict mode
logger.error("Unexpected error in safety validation: %s", e)
@@ -391,7 +389,9 @@ class SafetyGuardian:
if action.tool_name:
for pattern in policy.denied_tools:
if self._matches_pattern(action.tool_name, pattern):
reasons.append(f"Tool '{action.tool_name}' denied by pattern '{pattern}'")
reasons.append(
f"Tool '{action.tool_name}' denied by pattern '{pattern}'"
)
return GuardianResult(
action_id=action.id,
allowed=False,
@@ -419,7 +419,9 @@ class SafetyGuardian:
if action.resource:
for pattern in policy.denied_file_patterns:
if self._matches_pattern(action.resource, pattern):
reasons.append(f"Resource '{action.resource}' denied by pattern '{pattern}'")
reasons.append(
f"Resource '{action.resource}' denied by pattern '{pattern}'"
)
return GuardianResult(
action_id=action.id,
allowed=False,

View File

@@ -134,7 +134,9 @@ class LoopDetector:
raise LoopDetectedError(
f"Loop detected: {loop_type}",
loop_type=loop_type or "unknown",
repetition_count=self._max_exact if loop_type == "exact" else self._max_semantic,
repetition_count=self._max_exact
if loop_type == "exact"
else self._max_semantic,
action_pattern=[signature.semantic_key()],
agent_id=action.metadata.agent_id,
action_id=action.id,

View File

@@ -293,7 +293,9 @@ class MCPSafetyWrapper:
action_type=action_type,
tool_name=tool_call.tool_name,
arguments=tool_call.arguments,
resource=tool_call.arguments.get("path", tool_call.arguments.get("resource")),
resource=tool_call.arguments.get(
"path", tool_call.arguments.get("resource")
),
metadata=metadata,
)
@@ -302,7 +304,9 @@ class MCPSafetyWrapper:
tool_lower = tool_name.lower()
# Check destructive patterns
if any(d in tool_lower for d in ["write", "create", "delete", "remove", "update"]):
if any(
d in tool_lower for d in ["write", "create", "delete", "remove", "update"]
):
if "file" in tool_lower:
if "delete" in tool_lower or "remove" in tool_lower:
return ActionType.FILE_DELETE

View File

@@ -69,7 +69,18 @@ class SafetyMetrics:
def _init_histogram_buckets(self) -> None:
"""Initialize histogram buckets for latency metrics."""
latency_buckets = [0.01, 0.05, 0.1, 0.25, 0.5, 1.0, 2.5, 5.0, 10.0, float("inf")]
latency_buckets = [
0.01,
0.05,
0.1,
0.25,
0.5,
1.0,
2.5,
5.0,
10.0,
float("inf"),
]
for name in [
"validation_latency_seconds",
@@ -321,7 +332,8 @@ class SafetyMetrics:
async with self._lock:
total_validations = sum(self._counters["safety_validations_total"].values())
denied_validations = sum(
v for k, v in self._counters["safety_validations_total"].items()
v
for k, v in self._counters["safety_validations_total"].items()
if "decision=deny" in k
)
@@ -358,11 +370,13 @@ class SafetyMetrics:
"rollbacks_executed": sum(
self._counters["safety_rollbacks_total"].values()
),
"mcp_calls": sum(
self._counters["safety_mcp_calls_total"].values()
),
"pending_approvals": self._gauges.get("safety_pending_approvals", {}).get("", 0),
"active_checkpoints": self._gauges.get("safety_active_checkpoints", {}).get("", 0),
"mcp_calls": sum(self._counters["safety_mcp_calls_total"].values()),
"pending_approvals": self._gauges.get(
"safety_pending_approvals", {}
).get("", 0),
"active_checkpoints": self._gauges.get(
"safety_active_checkpoints", {}
).get("", 0),
}
async def reset(self) -> None:

View File

@@ -212,9 +212,7 @@ class ValidationResult(BaseModel):
applied_rules: list[str] = Field(
default_factory=list, description="IDs of applied rules"
)
reasons: list[str] = Field(
default_factory=list, description="Reasons for decision"
)
reasons: list[str] = Field(default_factory=list, description="Reasons for decision")
approval_id: str | None = Field(None, description="Approval request ID if needed")
retry_after_seconds: float | None = Field(
None, description="Retry delay if rate limited"
@@ -267,9 +265,7 @@ class RateLimitConfig(BaseModel):
limit: int = Field(..., description="Maximum allowed in window")
window_seconds: int = Field(60, description="Time window in seconds")
burst_limit: int | None = Field(None, description="Burst allowance")
slowdown_threshold: float = Field(
0.8, description="Start slowing at this fraction"
)
slowdown_threshold: float = Field(0.8, description="Start slowing at this fraction")
class RateLimitStatus(BaseModel):

View File

@@ -66,9 +66,7 @@ class RollbackManager:
"""
config = get_safety_config()
self._checkpoint_dir = Path(
checkpoint_dir or config.checkpoint_dir
)
self._checkpoint_dir = Path(checkpoint_dir or config.checkpoint_dir)
self._retention_hours = retention_hours or config.checkpoint_retention_hours
self._checkpoints: dict[str, Checkpoint] = {}
@@ -231,7 +229,9 @@ class RollbackManager:
success=success,
actions_rolled_back=actions_rolled_back,
failed_actions=failed_actions,
error=None if success else f"Failed to rollback {len(failed_actions)} items",
error=None
if success
else f"Failed to rollback {len(failed_actions)} items",
)
if success:
@@ -294,8 +294,7 @@ class RollbackManager:
if not include_expired:
checkpoints = [
c for c in checkpoints
if c.expires_at is None or c.expires_at > now
c for c in checkpoints if c.expires_at is None or c.expires_at > now
]
return checkpoints

View File

@@ -113,7 +113,9 @@ class ActionValidator:
self._rules.append(rule)
# Re-sort by priority (higher first)
self._rules.sort(key=lambda r: r.priority, reverse=True)
logger.debug("Added validation rule: %s (priority %d)", rule.name, rule.priority)
logger.debug(
"Added validation rule: %s (priority %d)", rule.name, rule.priority
)
def remove_rule(self, rule_id: str) -> bool:
"""

View File

@@ -44,8 +44,8 @@ def mock_superuser():
@pytest.fixture
def client(mock_mcp_client, mock_superuser):
"""Create a FastAPI test client with mocked dependencies."""
from app.api.routes.mcp import get_mcp_client
from app.api.dependencies.permissions import require_superuser
from app.api.routes.mcp import get_mcp_client
# Override dependencies
async def override_get_mcp_client():

View File

@@ -14,7 +14,6 @@ from app.services.mcp.client_manager import (
shutdown_mcp_client,
)
from app.services.mcp.config import MCPConfig, MCPServerConfig
from app.services.mcp.connection import ConnectionState
from app.services.mcp.exceptions import MCPServerNotFoundError
from app.services.mcp.registry import MCPServerRegistry
from app.services.mcp.routing import ToolInfo, ToolResult

View File

@@ -4,10 +4,8 @@ Tests for MCP Configuration System
import os
import tempfile
from pathlib import Path
import pytest
import yaml
from app.services.mcp.config import (
MCPConfig,
@@ -217,9 +215,7 @@ mcp_servers:
default_timeout: 60
connection_pool_size: 20
"""
with tempfile.NamedTemporaryFile(
mode="w", suffix=".yaml", delete=False
) as f:
with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f:
f.write(yaml_content)
f.flush()
@@ -248,9 +244,7 @@ mcp_servers:
explicit-server:
url: http://explicit:8000
"""
with tempfile.NamedTemporaryFile(
mode="w", suffix=".yaml", delete=False
) as f:
with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f:
f.write(yaml_content)
f.flush()
@@ -267,9 +261,7 @@ mcp_servers:
env-server:
url: http://env:8000
"""
with tempfile.NamedTemporaryFile(
mode="w", suffix=".yaml", delete=False
) as f:
with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f:
f.write(yaml_content)
f.flush()

View File

@@ -220,9 +220,7 @@ class TestMCPConnection:
MockClient.return_value = mock_client
await conn.connect()
result = await conn.execute_request(
"POST", "/mcp", data={"method": "test"}
)
result = await conn.execute_request("POST", "/mcp", data={"method": "test"})
assert result == {"result": "success"}

View File

@@ -160,11 +160,21 @@ class TestMCPToolNotFoundError:
"""Test tool not found with available tools listed."""
error = MCPToolNotFoundError(
"unknown-tool",
available_tools=["tool-1", "tool-2", "tool-3", "tool-4", "tool-5", "tool-6"],
available_tools=[
"tool-1",
"tool-2",
"tool-3",
"tool-4",
"tool-5",
"tool-6",
],
)
assert len(error.available_tools) == 6
# Should show first 5 tools with ellipsis
assert "available_tools=['tool-1', 'tool-2', 'tool-3', 'tool-4', 'tool-5']..." in str(error)
assert (
"available_tools=['tool-1', 'tool-2', 'tool-3', 'tool-4', 'tool-5']..."
in str(error)
)
class TestMCPCircuitOpenError:

View File

@@ -4,7 +4,7 @@ Tests for MCP Server Registry
import pytest
from app.services.mcp.config import MCPConfig, MCPServerConfig, TransportType
from app.services.mcp.config import MCPConfig, MCPServerConfig
from app.services.mcp.exceptions import MCPServerNotFoundError
from app.services.mcp.registry import (
MCPServerRegistry,

View File

@@ -220,7 +220,9 @@ class TestScan:
filter_all: ContentFilter,
) -> None:
"""Test scanning for specific categories only."""
content = "Email: test@example.com, token: ghp_abc123456789012345678901234567890123"
content = (
"Email: test@example.com, token: ghp_abc123456789012345678901234567890123"
)
# Scan only for secrets
matches = await filter_all.scan(

View File

@@ -321,8 +321,7 @@ class TestLoadRulesFromPolicy:
validator.load_rules_from_policy(policy)
approval_rules = [
r for r in validator._rules
if r.decision == SafetyDecision.REQUIRE_APPROVAL
r for r in validator._rules if r.decision == SafetyDecision.REQUIRE_APPROVAL
]
assert len(approval_rules) == 1

View File

@@ -162,7 +162,9 @@ export default function ProjectSettingsPage({ params }: ProjectSettingsPageProps
<Card>
<CardHeader>
<CardTitle>Autonomy Level</CardTitle>
<CardDescription>Control how much oversight you want over agent actions</CardDescription>
<CardDescription>
Control how much oversight you want over agent actions
</CardDescription>
</CardHeader>
<CardContent className="space-y-4">
<div className="space-y-2">