feat(safety): enhance rate limiting and cost control with alert deduplication and usage tracking

- Added `record_action` in `RateLimiter` for precise tracking of slot consumption post-validation.
- Introduced deduplication mechanism for warning alerts in `CostController` to prevent spamming.
- Refactored `CostController`'s session and daily budget alert handling for improved clarity.
- Implemented test suites for `CostController` and `SafetyGuardian` to validate changes.
- Expanded integration testing to cover deduplication, validation, and loop detection edge cases.
This commit is contained in:
2026-01-03 17:55:34 +01:00
parent 520c06175e
commit caf283bed2
9 changed files with 1782 additions and 92 deletions

View File

@@ -411,7 +411,20 @@ async def shutdown_mcp_client() -> None:
_manager_instance = None
def reset_mcp_client() -> None:
"""Reset the global MCP client manager (for testing)."""
async def reset_mcp_client() -> None:
"""
Reset the global MCP client manager (for testing).
This is an async function to properly acquire the manager lock
and avoid race conditions with get_mcp_client().
"""
global _manager_instance
_manager_instance = None
async with _manager_lock:
if _manager_instance is not None:
# Shutdown gracefully before resetting
try:
await _manager_instance.shutdown()
except Exception: # noqa: S110
pass # Ignore errors during test cleanup
_manager_instance = None

View File

@@ -161,7 +161,7 @@ class MCPConnection:
server_name=self.server_name,
url=self.config.url,
cause=e,
)
) from e
else:
# For STDIO and SSE transports, we'll implement later
raise NotImplementedError(
@@ -297,13 +297,13 @@ class MCPConnection:
server_name=self.server_name,
url=f"{self.config.url}{path}",
cause=e,
)
) from e
except Exception as e:
raise MCPConnectionError(
f"Request failed: {e}",
server_name=self.server_name,
cause=e,
)
) from e
class ConnectionPool:
@@ -322,8 +322,19 @@ class ConnectionPool:
"""
self._connections: dict[str, MCPConnection] = {}
self._lock = asyncio.Lock()
self._per_server_locks: dict[str, asyncio.Lock] = {}
self._max_per_server = max_connections_per_server
def _get_server_lock(self, server_name: str) -> asyncio.Lock:
"""Get or create a lock for a specific server.
Uses setdefault for atomic dict access to prevent race conditions
where two coroutines could create different locks for the same server.
"""
# setdefault is atomic - if key exists, returns existing value
# if key doesn't exist, inserts new value and returns it
return self._per_server_locks.setdefault(server_name, asyncio.Lock())
async def get_connection(
self,
server_name: str,
@@ -332,6 +343,9 @@ class ConnectionPool:
"""
Get or create a connection to a server.
Uses per-server locking to avoid blocking all connections
when establishing a new connection.
Args:
server_name: Name of the server
config: Server configuration
@@ -339,17 +353,33 @@ class ConnectionPool:
Returns:
Active connection
"""
async with self._lock:
if server_name not in self._connections:
connection = MCPConnection(server_name, config)
await connection.connect()
self._connections[server_name] = connection
# Quick check without lock - if connection exists and is connected, return it
if server_name in self._connections:
connection = self._connections[server_name]
if connection.is_connected:
return connection
# Reconnect if not connected
if not connection.is_connected:
# Need to create or reconnect - use per-server lock to avoid blocking others
async with self._lock:
server_lock = self._get_server_lock(server_name)
async with server_lock:
# Double-check after acquiring per-server lock
if server_name in self._connections:
connection = self._connections[server_name]
if connection.is_connected:
return connection
# Connection exists but not connected - reconnect
await connection.connect()
return connection
# Create new connection (outside global lock, under per-server lock)
connection = MCPConnection(server_name, config)
await connection.connect()
# Store connection under global lock
async with self._lock:
self._connections[server_name] = connection
return connection
@@ -374,6 +404,9 @@ class ConnectionPool:
if server_name in self._connections:
await self._connections[server_name].disconnect()
del self._connections[server_name]
# Clean up per-server lock
if server_name in self._per_server_locks:
del self._per_server_locks[server_name]
async def close_all(self) -> None:
"""Close all connections in the pool."""
@@ -385,6 +418,7 @@ class ConnectionPool:
logger.warning("Error closing connection: %s", e)
self._connections.clear()
self._per_server_locks.clear()
logger.info("Closed all MCP connections")
async def health_check_all(self) -> dict[str, bool]:
@@ -394,8 +428,12 @@ class ConnectionPool:
Returns:
Dict mapping server names to health status
"""
# Copy connections under lock to prevent modification during iteration
async with self._lock:
connections_snapshot = dict(self._connections)
results = {}
for name, connection in self._connections.items():
for name, connection in connections_snapshot.items():
results[name] = await connection.health_check()
return results