forked from cardosofelipe/fast-next-template
Improved code readability and uniformity by standardizing line breaks, indentation, and inline conditions across safety-related services, models, and tests, including content filters, validation rules, and emergency controls.
270 lines
8.3 KiB
Python
270 lines
8.3 KiB
Python
"""
|
|
Loop Detector
|
|
|
|
Detects and prevents action loops in agent behavior.
|
|
"""
|
|
|
|
import asyncio
|
|
import hashlib
|
|
import json
|
|
import logging
|
|
from collections import Counter, deque
|
|
from typing import Any
|
|
|
|
from ..config import get_safety_config
|
|
from ..exceptions import LoopDetectedError
|
|
from ..models import ActionRequest
|
|
|
|
# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class ActionSignature:
    """
    Comparable fingerprint of an action.

    Stores the action type value, tool name, resource and a short digest of
    the arguments, and exposes three string keys of decreasing specificity
    (exact, semantic, type) for use in repetition checks.
    """

    def __init__(self, action: ActionRequest) -> None:
        """
        Build the signature from an action request.

        Args:
            action: The action to fingerprint. Its ``action_type.value``,
                ``tool_name``, ``resource`` and ``arguments`` are read.
        """
        self.action_type = action.action_type.value
        self.tool_name = action.tool_name
        self.resource = action.resource
        self.args_hash = self._hash_args(action.arguments)

    def __repr__(self) -> str:
        """Return a debugging representation based on the exact key."""
        return f"{type(self).__name__}({self.exact_key()!r})"

    def _hash_args(self, args: dict[str, Any]) -> str:
        """
        Create a short, stable digest of the arguments.

        Args:
            args: The action arguments.

        Returns:
            The first 8 hex characters of a SHA-256 digest. The digest is
            taken over the key-sorted JSON form when the arguments can be
            serialized, otherwise over ``repr(args)``. An empty string is
            returned only if neither form can be produced.
        """
        try:
            serialized = json.dumps(args, sort_keys=True, default=str)
        except Exception:
            # Best effort by design: hashing must never fail an action check.
            # json.dumps can still fail despite default=str (e.g. keys that
            # cannot be sorted against each other, circular references).
            # Fall back to repr() so that different un-serializable argument
            # sets do not all collapse onto the same empty hash and get
            # counted as exact repeats. repr() is insertion-order dependent,
            # which is acceptable for this fallback.
            try:
                serialized = repr(args)
            except Exception:
                return ""

        digest = hashlib.sha256(serialized.encode("utf-8", "backslashreplace"))
        return digest.hexdigest()[:8]

    def exact_key(self) -> str:
        """Key for exact match detection (type, tool, resource and args digest)."""
        return f"{self.action_type}:{self.tool_name}:{self.resource}:{self.args_hash}"

    def semantic_key(self) -> str:
        """Key for semantic (similar) match detection; ignores the arguments."""
        return f"{self.action_type}:{self.tool_name}:{self.resource}"

    def type_key(self) -> str:
        """Key for action type only."""
        return f"{self.action_type}"
|
|
|
|
|
|
class LoopDetector:
    """
    Detects action loops and repetitive behavior.

    A bounded history of action signatures is kept per agent id. ``check``
    compares a candidate action against that history; ``record`` appends to
    it. All access to the histories is serialized with an asyncio lock.

    Loop Types:
    - Exact: Same action with same arguments
    - Semantic: Similar actions (same type/tool/resource, different args)
    - Oscillation: A→B→A→B patterns
    """

    # An A→B→A→B window contains the two-action cycle exactly twice.
    _OSCILLATION_CYCLES = 2

    def __init__(
        self,
        history_size: int | None = None,
        max_exact_repetitions: int | None = None,
        max_semantic_repetitions: int | None = None,
    ) -> None:
        """
        Initialize the LoopDetector.

        Args:
            history_size: Size of action history to track
            max_exact_repetitions: Max allowed exact repetitions
            max_semantic_repetitions: Max allowed semantic repetitions

        Any argument left as ``None`` is taken from the safety config.
        """
        config = get_safety_config()

        # NOTE: `or` treats every falsy value (None and 0) as "not provided",
        # so passing 0 also falls back to the configured value.
        self._history_size = history_size or config.loop_history_size
        self._max_exact = max_exact_repetitions or config.max_repeated_actions
        self._max_semantic = max_semantic_repetitions or config.max_similar_actions

        # Per-agent history, bounded to _history_size entries each
        self._histories: dict[str, deque[ActionSignature]] = {}
        self._lock = asyncio.Lock()

    async def check(self, action: ActionRequest) -> tuple[bool, str | None]:
        """
        Check if an action would create a loop.

        The action is only compared against the agent's history; it is not
        recorded here (see ``record``).

        Args:
            action: The action to check

        Returns:
            Tuple of (is_loop, loop_type) where loop_type is "exact",
            "semantic", "oscillation", or None when no loop was found.
        """
        agent_id = action.metadata.agent_id
        signature = ActionSignature(action)
        exact_key = signature.exact_key()
        semantic_key = signature.semantic_key()

        async with self._lock:
            history = self._get_history(agent_id)

            # Exact repetition is tested first: it is the most specific kind.
            exact_count = sum(1 for past in history if past.exact_key() == exact_key)
            if exact_count >= self._max_exact:
                return True, "exact"

            # Semantic repetition (exact repeats count towards this as well,
            # since they share the same semantic key).
            semantic_count = sum(
                1 for past in history if past.semantic_key() == semantic_key
            )
            if semantic_count >= self._max_semantic:
                return True, "semantic"

            # Oscillation (A→B→A→B); the helper handles short histories.
            if self._detect_oscillation(history, signature):
                return True, "oscillation"

        return False, None

    async def check_and_raise(self, action: ActionRequest) -> None:
        """
        Check for loops and raise if detected.

        Args:
            action: The action to check

        Raises:
            LoopDetectedError: If loop is detected
        """
        is_loop, loop_type = await self.check(action)
        if not is_loop:
            return

        if loop_type == "exact":
            repetitions = self._max_exact
        elif loop_type == "oscillation":
            # Report the cycles actually present in the matched window
            # instead of the unrelated semantic threshold.
            repetitions = self._OSCILLATION_CYCLES
        else:
            repetitions = self._max_semantic

        signature = ActionSignature(action)
        raise LoopDetectedError(
            f"Loop detected: {loop_type}",
            loop_type=loop_type or "unknown",
            repetition_count=repetitions,
            action_pattern=[signature.semantic_key()],
            agent_id=action.metadata.agent_id,
            action_id=action.id,
        )

    async def record(self, action: ActionRequest) -> None:
        """
        Record an action in history.

        Args:
            action: The action to record
        """
        agent_id = action.metadata.agent_id
        signature = ActionSignature(action)

        async with self._lock:
            self._get_history(agent_id).append(signature)

    async def clear_history(self, agent_id: str) -> None:
        """
        Clear history for an agent.

        Does nothing when no history exists for the agent.

        Args:
            agent_id: ID of the agent
        """
        async with self._lock:
            history = self._histories.get(agent_id)
            if history is not None:
                history.clear()

    async def get_stats(self, agent_id: str) -> dict[str, Any]:
        """
        Get loop detection stats for an agent.

        Args:
            agent_id: ID of the agent

        Returns:
            Stats dictionary with the keys "history_size", "max_history",
            "action_type_counts" and "top_semantic_patterns" (the five most
            common semantic keys with their counts).
        """
        async with self._lock:
            # Read-only lookup: asking for stats of an unknown agent must not
            # allocate and retain an empty history for it.
            history = self._histories.get(agent_id, ())

            type_counts = Counter(entry.type_key() for entry in history)
            semantic_counts = Counter(entry.semantic_key() for entry in history)

            return {
                "history_size": len(history),
                "max_history": self._history_size,
                "action_type_counts": dict(type_counts),
                "top_semantic_patterns": semantic_counts.most_common(5),
            }

    def _get_history(self, agent_id: str) -> deque[ActionSignature]:
        """Get or create history for an agent."""
        if agent_id not in self._histories:
            self._histories[agent_id] = deque(maxlen=self._history_size)
        return self._histories[agent_id]

    def _detect_oscillation(
        self,
        history: deque[ActionSignature],
        current: ActionSignature,
    ) -> bool:
        """
        Detect A→B→A→B oscillation pattern.

        Looks at the last three recorded actions plus the current one and
        compares their semantic keys.

        Args:
            history: Recorded signatures for the agent, oldest first
            current: Signature of the action being checked

        Returns:
            True when the four keys form A, B, A, B with A != B.
        """
        if len(history) < 3:
            return False

        # Index the deque ends directly rather than copying the whole history.
        first = history[-3].semantic_key()
        second = history[-2].semantic_key()
        third = history[-1].semantic_key()
        fourth = current.semantic_key()

        return first == third and second == fourth and first != second
|
|
|
|
|
|
class LoopBreaker:
|
|
"""
|
|
Strategies for breaking detected loops.
|
|
"""
|
|
|
|
@staticmethod
|
|
async def suggest_alternatives(
|
|
action: ActionRequest,
|
|
loop_type: str,
|
|
) -> list[str]:
|
|
"""
|
|
Suggest alternative actions when loop is detected.
|
|
|
|
Args:
|
|
action: The looping action
|
|
loop_type: Type of loop detected
|
|
|
|
Returns:
|
|
List of suggestions
|
|
"""
|
|
suggestions = []
|
|
|
|
if loop_type == "exact":
|
|
suggestions.append(
|
|
"The same action with identical arguments has been repeated too many times. "
|
|
"Consider: (1) Verify the action succeeded, (2) Try a different approach, "
|
|
"(3) Escalate for human review"
|
|
)
|
|
elif loop_type == "semantic":
|
|
suggestions.append(
|
|
"Similar actions have been repeated too many times. "
|
|
"Consider: (1) Review if the approach is working, (2) Try an alternative method, "
|
|
"(3) Request clarification on the goal"
|
|
)
|
|
elif loop_type == "oscillation":
|
|
suggestions.append(
|
|
"An oscillating pattern was detected (A→B→A→B). "
|
|
"This usually indicates conflicting goals or a stuck state. "
|
|
"Consider: (1) Step back and reassess, (2) Request human guidance"
|
|
)
|
|
|
|
return suggestions
|