test(safety): add comprehensive tests for safety framework modules

Add tests to improve backend coverage from 85% to 93%:

- test_audit.py: 60 tests for AuditLogger (20% -> 99%)
  - Hash chain integrity, sanitization, retention, handlers
  - Fixed bug: hash chain modification after event creation
  - Fixed bug: verification not using correct prev_hash

- test_hitl.py: Tests for HITL manager (0% -> 100%)
- test_permissions.py: Tests for permissions manager (0% -> 99%)
- test_rollback.py: Tests for rollback manager (0% -> 100%)
- test_metrics.py: Tests for metrics collector (0% -> 100%)
- test_mcp_integration.py: Tests for MCP safety wrapper (0% -> 100%)
- test_validation.py: Additional cache and edge case tests (76% -> 100%)
- test_scoring.py: Lock cleanup and edge case tests (78% -> 91%)

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
2026-01-04 19:41:54 +01:00
parent 758052dcff
commit 60ebeaa582
10 changed files with 6025 additions and 9 deletions

View File

@@ -123,7 +123,9 @@ class ClaudeAdapter(ModelAdapter):
if score:
# Escape score to prevent XML injection via metadata
escaped_score = self._escape_xml(str(score))
parts.append(f'<document source="{source}" relevance="{escaped_score}">')
parts.append(
f'<document source="{source}" relevance="{escaped_score}">'
)
else:
parts.append(f'<document source="{source}">')

View File

@@ -24,6 +24,9 @@ from ..models import (
logger = logging.getLogger(__name__)
# Sentinel for distinguishing "no argument passed" from "explicitly passing None"
_UNSET = object()
class AuditLogger:
"""
@@ -142,8 +145,10 @@ class AuditLogger:
# Add hash chain for tamper detection
if self._enable_hash_chain:
event_hash = self._compute_hash(event)
sanitized_details["_hash"] = event_hash
sanitized_details["_prev_hash"] = self._last_hash
# Modify event.details directly (not sanitized_details)
# to ensure the hash is stored on the actual event
event.details["_hash"] = event_hash
event.details["_prev_hash"] = self._last_hash
self._last_hash = event_hash
self._buffer.append(event)
@@ -415,7 +420,8 @@ class AuditLogger:
)
if stored_hash:
computed = self._compute_hash(event)
# Pass prev_hash to compute hash with correct chain position
computed = self._compute_hash(event, prev_hash=prev_hash)
if computed != stored_hash:
issues.append(
f"Hash mismatch at event {event.id}: "
@@ -462,9 +468,23 @@ class AuditLogger:
return sanitized
def _compute_hash(self, event: AuditEvent) -> str:
"""Compute hash for an event (excluding hash fields)."""
data = {
def _compute_hash(
self, event: AuditEvent, prev_hash: str | None | object = _UNSET
) -> str:
"""Compute hash for an event (excluding hash fields).
Args:
event: The audit event to hash.
prev_hash: Optional previous hash to use instead of self._last_hash.
Pass this during verification to use the correct chain.
Use None explicitly to indicate no previous hash.
"""
# Use passed prev_hash if explicitly provided, otherwise use instance state
effective_prev: str | None = (
self._last_hash if prev_hash is _UNSET else prev_hash # type: ignore[assignment]
)
data: dict[str, str | dict[str, str] | None] = {
"id": event.id,
"event_type": event.event_type.value,
"timestamp": event.timestamp.isoformat(),
@@ -480,8 +500,8 @@ class AuditLogger:
"correlation_id": event.correlation_id,
}
if self._last_hash:
data["_prev_hash"] = self._last_hash
if effective_prev:
data["_prev_hash"] = effective_prev
serialized = json.dumps(data, sort_keys=True, default=str)
return hashlib.sha256(serialized.encode()).hexdigest()