refactor(safety): apply consistent formatting across services and tests
Improved code readability and uniformity by standardizing line breaks, indentation, and inline conditions across safety-related services, models, and tests, including content filters, validation rules, and emergency controls.
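For illustration only (this snippet is not part of the diff below): the kind of rewrite applied throughout, assuming a Black-style 88-character line limit; describe_tool and its arguments are hypothetical names invented for this example.

# A hypothetical helper, used only to illustrate the line-wrapping rule applied in this commit.
def describe_tool(name: str, server: str, description: str) -> str:
    return f"{name} ({server}): {description}"

# Before: the call sits on one line and exceeds the assumed 88-character limit.
summary = describe_tool("filesystem.read", "local-files", "Reads a file from the workspace")

# After: the arguments move to an indented continuation line and the closing parenthesis gets its own line.
summary = describe_tool(
    "filesystem.read", "local-files", "Reads a file from the workspace"
)
print(summary)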
@@ -74,7 +74,9 @@ class ToolInfoResponse(BaseModel):
     name: str = Field(..., description="Tool name")
     description: str | None = Field(None, description="Tool description")
     server_name: str | None = Field(None, description="Server providing the tool")
-    input_schema: dict[str, Any] | None = Field(None, description="JSON schema for input")
+    input_schema: dict[str, Any] | None = Field(
+        None, description="JSON schema for input"
+    )


 class ToolListResponse(BaseModel):
@@ -158,9 +158,7 @@ class MCPConfig(BaseModel):
     def get_enabled_servers(self) -> dict[str, MCPServerConfig]:
         """Get all enabled server configurations."""
         return {
-            name: config
-            for name, config in self.mcp_servers.items()
-            if config.enabled
+            name: config for name, config in self.mcp_servers.items() if config.enabled
         }

     def list_server_names(self) -> list[str]:
@@ -196,9 +196,7 @@ class AuditLogger:
     ) -> AuditEvent:
         """Log an action execution result."""
         event_type = (
-            AuditEventType.ACTION_EXECUTED
-            if success
-            else AuditEventType.ACTION_FAILED
+            AuditEventType.ACTION_EXECUTED if success else AuditEventType.ACTION_FAILED
         )

         return await self.log(
@@ -477,9 +475,7 @@ class AuditLogger:
            "user_id": event.user_id,
            "decision": event.decision.value if event.decision else None,
            "details": {
-                k: v
-                for k, v in event.details.items()
-                if not k.startswith("_")
+                k: v for k, v in event.details.items() if not k.startswith("_")
            },
            "correlation_id": event.correlation_id,
        }
@@ -31,9 +31,7 @@ class SafetyConfig(BaseSettings):

     # General settings
     enabled: bool = Field(True, description="Enable safety framework")
-    strict_mode: bool = Field(
-        True, description="Strict mode (fail closed on errors)"
-    )
+    strict_mode: bool = Field(True, description="Strict mode (fail closed on errors)")
     log_level: str = Field("INFO", description="Logging level")

     # Default autonomy level
@@ -255,7 +253,8 @@ def get_policy_for_autonomy_level(level: AutonomyLevel) -> SafetyPolicy:
             max_tokens_per_day=base_policy.max_tokens_per_day // 10,
             max_actions_per_minute=base_policy.max_actions_per_minute // 2,
             max_llm_calls_per_minute=base_policy.max_llm_calls_per_minute // 2,
-            max_file_operations_per_minute=base_policy.max_file_operations_per_minute // 2,
+            max_file_operations_per_minute=base_policy.max_file_operations_per_minute
+            // 2,
             denied_tools=["delete_*", "destroy_*", "drop_*"],
         )

@@ -294,7 +293,8 @@ def get_policy_for_autonomy_level(level: AutonomyLevel) -> SafetyPolicy:
             max_tokens_per_day=base_policy.max_tokens_per_day * 5,
             max_actions_per_minute=base_policy.max_actions_per_minute * 2,
             max_llm_calls_per_minute=base_policy.max_llm_calls_per_minute * 2,
-            max_file_operations_per_minute=base_policy.max_file_operations_per_minute * 2,
+            max_file_operations_per_minute=base_policy.max_file_operations_per_minute
+            * 2,
         )


@@ -260,9 +260,15 @@ class ContentFilter:
                 continue
             if pattern.category == ContentCategory.SECRETS and not enable_secret_filter:
                 continue
-            if pattern.category == ContentCategory.CREDENTIALS and not enable_secret_filter:
+            if (
+                pattern.category == ContentCategory.CREDENTIALS
+                and not enable_secret_filter
+            ):
                 continue
-            if pattern.category == ContentCategory.INJECTION and not enable_injection_filter:
+            if (
+                pattern.category == ContentCategory.INJECTION
+                and not enable_injection_filter
+            ):
                 continue
             self._patterns.append(replace(pattern))

@@ -343,7 +349,10 @@ class ContentFilter:
         filtered_content = content
         for match in all_matches:
             matched_pattern = self._get_pattern(match.pattern_name)
-            if matched_pattern and matched_pattern.action in (FilterAction.REDACT, FilterAction.BLOCK):
+            if matched_pattern and matched_pattern.action in (
+                FilterAction.REDACT,
+                FilterAction.BLOCK,
+            ):
                 filtered_content = (
                     filtered_content[: match.start_pos]
                     + (match.redacted_text or "[REDACTED]")
@@ -371,8 +380,12 @@ class ContentFilter:
         if raise_on_block:
             raise ContentFilterError(
                 block_reason or "Content blocked",
-                filter_type=all_matches[0].category.value if all_matches else "unknown",
-                detected_patterns=[m.pattern_name for m in all_matches] if all_matches else [],
+                filter_type=all_matches[0].category.value
+                if all_matches
+                else "unknown",
+                detected_patterns=[m.pattern_name for m in all_matches]
+                if all_matches
+                else [],
             )
         elif all_matches:
             logger.debug(
@@ -480,9 +493,13 @@ class ContentFilter:
             matches = pattern.find_matches(content)
             for match in matches:
                 if pattern.action == FilterAction.BLOCK:
-                    issues.append(f"Blocked: {pattern.name} at position {match.start_pos}")
+                    issues.append(
+                        f"Blocked: {pattern.name} at position {match.start_pos}"
+                    )
                 elif pattern.action == FilterAction.WARN and not allow_warnings:
-                    issues.append(f"Warning: {pattern.name} at position {match.start_pos}")
+                    issues.append(
+                        f"Warning: {pattern.name} at position {match.start_pos}"
+                    )

         return len(issues) == 0, issues

@@ -69,7 +69,9 @@ class BudgetTracker:
                 else 0
             )

-            is_warning = max(token_usage_ratio, cost_usage_ratio) >= self.warning_threshold
+            is_warning = (
+                max(token_usage_ratio, cost_usage_ratio) >= self.warning_threshold
+            )
             is_exceeded = (
                 self._tokens_used >= self.tokens_limit
                 or self._cost_used_usd >= self.cost_limit_usd
@@ -94,12 +96,16 @@ class BudgetTracker:
             reset_at=reset_at,
         )

-    async def check_budget(self, estimated_tokens: int, estimated_cost_usd: float) -> bool:
+    async def check_budget(
+        self, estimated_tokens: int, estimated_cost_usd: float
+    ) -> bool:
         """Check if there's enough budget for an operation."""
         async with self._lock:
             self._check_reset()

-            would_exceed_tokens = (self._tokens_used + estimated_tokens) > self.tokens_limit
+            would_exceed_tokens = (
+                self._tokens_used + estimated_tokens
+            ) > self.tokens_limit
             would_exceed_cost = (
                 self._cost_used_usd + estimated_cost_usd
             ) > self.cost_limit_usd
@@ -241,13 +247,13 @@ class CostController:
             session_tracker = await self.get_or_create_tracker(
                 BudgetScope.SESSION, session_id
             )
-            if not await session_tracker.check_budget(estimated_tokens, estimated_cost_usd):
+            if not await session_tracker.check_budget(
+                estimated_tokens, estimated_cost_usd
+            ):
                 return False

         # Check agent daily budget
-        agent_tracker = await self.get_or_create_tracker(
-            BudgetScope.DAILY, agent_id
-        )
+        agent_tracker = await self.get_or_create_tracker(BudgetScope.DAILY, agent_id)
         if not await agent_tracker.check_budget(estimated_tokens, estimated_cost_usd):
             return False

@@ -253,7 +253,9 @@ class EmergencyControls:
             self._on_resume_callbacks,
             {"scope": scope, "resumed_by": resumed_by},
         )
-        await self._notify_handlers("resume", {"scope": scope, "resumed_by": resumed_by})
+        await self._notify_handlers(
+            "resume", {"scope": scope, "resumed_by": resumed_by}
+        )

         return True

@@ -266,9 +266,7 @@ class SafetyGuardian:

         except SafetyError as e:
             # Known safety error
-            return await self._create_denial_result(
-                action, [str(e)], audit_events
-            )
+            return await self._create_denial_result(action, [str(e)], audit_events)
         except Exception as e:
             # Unknown error - fail closed in strict mode
             logger.error("Unexpected error in safety validation: %s", e)
@@ -391,7 +389,9 @@ class SafetyGuardian:
         if action.tool_name:
             for pattern in policy.denied_tools:
                 if self._matches_pattern(action.tool_name, pattern):
-                    reasons.append(f"Tool '{action.tool_name}' denied by pattern '{pattern}'")
+                    reasons.append(
+                        f"Tool '{action.tool_name}' denied by pattern '{pattern}'"
+                    )
                     return GuardianResult(
                         action_id=action.id,
                         allowed=False,
@@ -419,7 +419,9 @@ class SafetyGuardian:
         if action.resource:
             for pattern in policy.denied_file_patterns:
                 if self._matches_pattern(action.resource, pattern):
-                    reasons.append(f"Resource '{action.resource}' denied by pattern '{pattern}'")
+                    reasons.append(
+                        f"Resource '{action.resource}' denied by pattern '{pattern}'"
+                    )
                     return GuardianResult(
                         action_id=action.id,
                         allowed=False,
@@ -134,7 +134,9 @@ class LoopDetector:
             raise LoopDetectedError(
                 f"Loop detected: {loop_type}",
                 loop_type=loop_type or "unknown",
-                repetition_count=self._max_exact if loop_type == "exact" else self._max_semantic,
+                repetition_count=self._max_exact
+                if loop_type == "exact"
+                else self._max_semantic,
                 action_pattern=[signature.semantic_key()],
                 agent_id=action.metadata.agent_id,
                 action_id=action.id,
@@ -293,7 +293,9 @@ class MCPSafetyWrapper:
             action_type=action_type,
             tool_name=tool_call.tool_name,
             arguments=tool_call.arguments,
-            resource=tool_call.arguments.get("path", tool_call.arguments.get("resource")),
+            resource=tool_call.arguments.get(
+                "path", tool_call.arguments.get("resource")
+            ),
             metadata=metadata,
         )

@@ -302,7 +304,9 @@ class MCPSafetyWrapper:
         tool_lower = tool_name.lower()

         # Check destructive patterns
-        if any(d in tool_lower for d in ["write", "create", "delete", "remove", "update"]):
+        if any(
+            d in tool_lower for d in ["write", "create", "delete", "remove", "update"]
+        ):
             if "file" in tool_lower:
                 if "delete" in tool_lower or "remove" in tool_lower:
                     return ActionType.FILE_DELETE

@@ -69,7 +69,18 @@ class SafetyMetrics:

     def _init_histogram_buckets(self) -> None:
         """Initialize histogram buckets for latency metrics."""
-        latency_buckets = [0.01, 0.05, 0.1, 0.25, 0.5, 1.0, 2.5, 5.0, 10.0, float("inf")]
+        latency_buckets = [
+            0.01,
+            0.05,
+            0.1,
+            0.25,
+            0.5,
+            1.0,
+            2.5,
+            5.0,
+            10.0,
+            float("inf"),
+        ]

         for name in [
             "validation_latency_seconds",
@@ -321,7 +332,8 @@ class SafetyMetrics:
         async with self._lock:
             total_validations = sum(self._counters["safety_validations_total"].values())
             denied_validations = sum(
-                v for k, v in self._counters["safety_validations_total"].items()
+                v
+                for k, v in self._counters["safety_validations_total"].items()
                 if "decision=deny" in k
             )

@@ -358,11 +370,13 @@ class SafetyMetrics:
                 "rollbacks_executed": sum(
                     self._counters["safety_rollbacks_total"].values()
                 ),
-                "mcp_calls": sum(
-                    self._counters["safety_mcp_calls_total"].values()
-                ),
-                "pending_approvals": self._gauges.get("safety_pending_approvals", {}).get("", 0),
-                "active_checkpoints": self._gauges.get("safety_active_checkpoints", {}).get("", 0),
+                "mcp_calls": sum(self._counters["safety_mcp_calls_total"].values()),
+                "pending_approvals": self._gauges.get(
+                    "safety_pending_approvals", {}
+                ).get("", 0),
+                "active_checkpoints": self._gauges.get(
+                    "safety_active_checkpoints", {}
+                ).get("", 0),
             }

     async def reset(self) -> None:
@@ -212,9 +212,7 @@ class ValidationResult(BaseModel):
     applied_rules: list[str] = Field(
         default_factory=list, description="IDs of applied rules"
     )
-    reasons: list[str] = Field(
-        default_factory=list, description="Reasons for decision"
-    )
+    reasons: list[str] = Field(default_factory=list, description="Reasons for decision")
     approval_id: str | None = Field(None, description="Approval request ID if needed")
     retry_after_seconds: float | None = Field(
         None, description="Retry delay if rate limited"
@@ -267,9 +265,7 @@ class RateLimitConfig(BaseModel):
     limit: int = Field(..., description="Maximum allowed in window")
     window_seconds: int = Field(60, description="Time window in seconds")
     burst_limit: int | None = Field(None, description="Burst allowance")
-    slowdown_threshold: float = Field(
-        0.8, description="Start slowing at this fraction"
-    )
+    slowdown_threshold: float = Field(0.8, description="Start slowing at this fraction")


 class RateLimitStatus(BaseModel):
@@ -66,9 +66,7 @@ class RollbackManager:
         """
         config = get_safety_config()

-        self._checkpoint_dir = Path(
-            checkpoint_dir or config.checkpoint_dir
-        )
+        self._checkpoint_dir = Path(checkpoint_dir or config.checkpoint_dir)
         self._retention_hours = retention_hours or config.checkpoint_retention_hours

         self._checkpoints: dict[str, Checkpoint] = {}
@@ -231,7 +229,9 @@ class RollbackManager:
             success=success,
             actions_rolled_back=actions_rolled_back,
             failed_actions=failed_actions,
-            error=None if success else f"Failed to rollback {len(failed_actions)} items",
+            error=None
+            if success
+            else f"Failed to rollback {len(failed_actions)} items",
         )

         if success:
@@ -294,8 +294,7 @@ class RollbackManager:

         if not include_expired:
             checkpoints = [
-                c for c in checkpoints
-                if c.expires_at is None or c.expires_at > now
+                c for c in checkpoints if c.expires_at is None or c.expires_at > now
             ]

         return checkpoints
@@ -113,7 +113,9 @@ class ActionValidator:
         self._rules.append(rule)
         # Re-sort by priority (higher first)
         self._rules.sort(key=lambda r: r.priority, reverse=True)
-        logger.debug("Added validation rule: %s (priority %d)", rule.name, rule.priority)
+        logger.debug(
+            "Added validation rule: %s (priority %d)", rule.name, rule.priority
+        )

     def remove_rule(self, rule_id: str) -> bool:
         """
@@ -44,8 +44,8 @@ def mock_superuser():
 @pytest.fixture
 def client(mock_mcp_client, mock_superuser):
     """Create a FastAPI test client with mocked dependencies."""
-    from app.api.routes.mcp import get_mcp_client
     from app.api.dependencies.permissions import require_superuser
+    from app.api.routes.mcp import get_mcp_client

     # Override dependencies
     async def override_get_mcp_client():
@@ -14,7 +14,6 @@ from app.services.mcp.client_manager import (
     shutdown_mcp_client,
 )
 from app.services.mcp.config import MCPConfig, MCPServerConfig
-from app.services.mcp.connection import ConnectionState
 from app.services.mcp.exceptions import MCPServerNotFoundError
 from app.services.mcp.registry import MCPServerRegistry
 from app.services.mcp.routing import ToolInfo, ToolResult
@@ -4,10 +4,8 @@ Tests for MCP Configuration System

 import os
 import tempfile
-from pathlib import Path

 import pytest
-import yaml

 from app.services.mcp.config import (
     MCPConfig,
@@ -217,9 +215,7 @@ mcp_servers:
 default_timeout: 60
 connection_pool_size: 20
 """
-        with tempfile.NamedTemporaryFile(
-            mode="w", suffix=".yaml", delete=False
-        ) as f:
+        with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f:
             f.write(yaml_content)
             f.flush()

@@ -248,9 +244,7 @@ mcp_servers:
   explicit-server:
     url: http://explicit:8000
 """
-        with tempfile.NamedTemporaryFile(
-            mode="w", suffix=".yaml", delete=False
-        ) as f:
+        with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f:
             f.write(yaml_content)
             f.flush()

@@ -267,9 +261,7 @@ mcp_servers:
   env-server:
     url: http://env:8000
 """
-        with tempfile.NamedTemporaryFile(
-            mode="w", suffix=".yaml", delete=False
-        ) as f:
+        with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f:
             f.write(yaml_content)
             f.flush()

@@ -220,9 +220,7 @@ class TestMCPConnection:
             MockClient.return_value = mock_client

             await conn.connect()
-            result = await conn.execute_request(
-                "POST", "/mcp", data={"method": "test"}
-            )
+            result = await conn.execute_request("POST", "/mcp", data={"method": "test"})

             assert result == {"result": "success"}

@@ -160,11 +160,21 @@ class TestMCPToolNotFoundError:
         """Test tool not found with available tools listed."""
         error = MCPToolNotFoundError(
             "unknown-tool",
-            available_tools=["tool-1", "tool-2", "tool-3", "tool-4", "tool-5", "tool-6"],
+            available_tools=[
+                "tool-1",
+                "tool-2",
+                "tool-3",
+                "tool-4",
+                "tool-5",
+                "tool-6",
+            ],
         )
         assert len(error.available_tools) == 6
         # Should show first 5 tools with ellipsis
-        assert "available_tools=['tool-1', 'tool-2', 'tool-3', 'tool-4', 'tool-5']..." in str(error)
+        assert (
+            "available_tools=['tool-1', 'tool-2', 'tool-3', 'tool-4', 'tool-5']..."
+            in str(error)
+        )


 class TestMCPCircuitOpenError:
@@ -4,7 +4,7 @@ Tests for MCP Server Registry

 import pytest

-from app.services.mcp.config import MCPConfig, MCPServerConfig, TransportType
+from app.services.mcp.config import MCPConfig, MCPServerConfig
 from app.services.mcp.exceptions import MCPServerNotFoundError
 from app.services.mcp.registry import (
     MCPServerRegistry,
@@ -220,7 +220,9 @@ class TestScan:
         filter_all: ContentFilter,
     ) -> None:
         """Test scanning for specific categories only."""
-        content = "Email: test@example.com, token: ghp_abc123456789012345678901234567890123"
+        content = (
+            "Email: test@example.com, token: ghp_abc123456789012345678901234567890123"
+        )

         # Scan only for secrets
         matches = await filter_all.scan(
@@ -321,8 +321,7 @@ class TestLoadRulesFromPolicy:
         validator.load_rules_from_policy(policy)

         approval_rules = [
-            r for r in validator._rules
-            if r.decision == SafetyDecision.REQUIRE_APPROVAL
+            r for r in validator._rules if r.decision == SafetyDecision.REQUIRE_APPROVAL
         ]
         assert len(approval_rules) == 1

@@ -162,7 +162,9 @@ export default function ProjectSettingsPage({ params }: ProjectSettingsPageProps
         <Card>
           <CardHeader>
             <CardTitle>Autonomy Level</CardTitle>
-            <CardDescription>Control how much oversight you want over agent actions</CardDescription>
+            <CardDescription>
+              Control how much oversight you want over agent actions
+            </CardDescription>
           </CardHeader>
           <CardContent className="space-y-4">
             <div className="space-y-2">