feat(safety): enhance rate limiting and cost control with alert deduplication and usage tracking
- Added `record_action` to `RateLimiter` for precise tracking of slot consumption after validation.
- Introduced a deduplication mechanism for warning alerts in `CostController` to prevent alert spam.
- Refactored `CostController`'s session and daily budget alert handling for improved clarity.
- Implemented test suites for `CostController` and `SafetyGuardian` to validate the changes.
- Expanded integration testing to cover deduplication, validation, and loop-detection edge cases.
This commit is contained in:
@@ -411,7 +411,20 @@ async def shutdown_mcp_client() -> None:
|
||||
_manager_instance = None
|
||||
|
||||
|
||||
def reset_mcp_client() -> None:
|
||||
"""Reset the global MCP client manager (for testing)."""
|
||||
async def reset_mcp_client() -> None:
|
||||
"""
|
||||
Reset the global MCP client manager (for testing).
|
||||
|
||||
This is an async function to properly acquire the manager lock
|
||||
and avoid race conditions with get_mcp_client().
|
||||
"""
|
||||
global _manager_instance
|
||||
_manager_instance = None
|
||||
|
||||
async with _manager_lock:
|
||||
if _manager_instance is not None:
|
||||
# Shutdown gracefully before resetting
|
||||
try:
|
||||
await _manager_instance.shutdown()
|
||||
except Exception: # noqa: S110
|
||||
pass # Ignore errors during test cleanup
|
||||
_manager_instance = None
|
||||
|
||||
@@ -161,7 +161,7 @@ class MCPConnection:
|
||||
server_name=self.server_name,
|
||||
url=self.config.url,
|
||||
cause=e,
|
||||
)
|
||||
) from e
|
||||
else:
|
||||
# For STDIO and SSE transports, we'll implement later
|
||||
raise NotImplementedError(
|
||||
@@ -297,13 +297,13 @@ class MCPConnection:
|
||||
server_name=self.server_name,
|
||||
url=f"{self.config.url}{path}",
|
||||
cause=e,
|
||||
)
|
||||
) from e
|
||||
except Exception as e:
|
||||
raise MCPConnectionError(
|
||||
f"Request failed: {e}",
|
||||
server_name=self.server_name,
|
||||
cause=e,
|
||||
)
|
||||
) from e
|
||||
|
||||
|
||||
class ConnectionPool:
|
||||
@@ -322,8 +322,19 @@ class ConnectionPool:
|
||||
"""
|
||||
self._connections: dict[str, MCPConnection] = {}
|
||||
self._lock = asyncio.Lock()
|
||||
self._per_server_locks: dict[str, asyncio.Lock] = {}
|
||||
self._max_per_server = max_connections_per_server
|
||||
|
||||
def _get_server_lock(self, server_name: str) -> asyncio.Lock:
|
||||
"""Get or create a lock for a specific server.
|
||||
|
||||
Uses setdefault for atomic dict access to prevent race conditions
|
||||
where two coroutines could create different locks for the same server.
|
||||
"""
|
||||
# setdefault is atomic - if key exists, returns existing value
|
||||
# if key doesn't exist, inserts new value and returns it
|
||||
return self._per_server_locks.setdefault(server_name, asyncio.Lock())
|
||||
|
||||
async def get_connection(
|
||||
self,
|
||||
server_name: str,
|
||||
@@ -332,6 +343,9 @@ class ConnectionPool:
|
||||
"""
|
||||
Get or create a connection to a server.
|
||||
|
||||
Uses per-server locking to avoid blocking all connections
|
||||
when establishing a new connection.
|
||||
|
||||
Args:
|
||||
server_name: Name of the server
|
||||
config: Server configuration
|
||||
@@ -339,17 +353,33 @@ class ConnectionPool:
|
||||
Returns:
|
||||
Active connection
|
||||
"""
|
||||
async with self._lock:
|
||||
if server_name not in self._connections:
|
||||
connection = MCPConnection(server_name, config)
|
||||
await connection.connect()
|
||||
self._connections[server_name] = connection
|
||||
|
||||
# Quick check without lock - if connection exists and is connected, return it
|
||||
if server_name in self._connections:
|
||||
connection = self._connections[server_name]
|
||||
if connection.is_connected:
|
||||
return connection
|
||||
|
||||
# Reconnect if not connected
|
||||
if not connection.is_connected:
|
||||
# Need to create or reconnect - use per-server lock to avoid blocking others
|
||||
async with self._lock:
|
||||
server_lock = self._get_server_lock(server_name)
|
||||
|
||||
async with server_lock:
|
||||
# Double-check after acquiring per-server lock
|
||||
if server_name in self._connections:
|
||||
connection = self._connections[server_name]
|
||||
if connection.is_connected:
|
||||
return connection
|
||||
# Connection exists but not connected - reconnect
|
||||
await connection.connect()
|
||||
return connection
|
||||
|
||||
# Create new connection (outside global lock, under per-server lock)
|
||||
connection = MCPConnection(server_name, config)
|
||||
await connection.connect()
|
||||
|
||||
# Store connection under global lock
|
||||
async with self._lock:
|
||||
self._connections[server_name] = connection
|
||||
|
||||
return connection
|
||||
|
||||
@@ -374,6 +404,9 @@ class ConnectionPool:
|
||||
if server_name in self._connections:
|
||||
await self._connections[server_name].disconnect()
|
||||
del self._connections[server_name]
|
||||
# Clean up per-server lock
|
||||
if server_name in self._per_server_locks:
|
||||
del self._per_server_locks[server_name]
|
||||
|
||||
async def close_all(self) -> None:
|
||||
"""Close all connections in the pool."""
|
||||
@@ -385,6 +418,7 @@ class ConnectionPool:
|
||||
logger.warning("Error closing connection: %s", e)
|
||||
|
||||
self._connections.clear()
|
||||
self._per_server_locks.clear()
|
||||
logger.info("Closed all MCP connections")
|
||||
|
||||
async def health_check_all(self) -> dict[str, bool]:
|
||||
@@ -394,8 +428,12 @@ class ConnectionPool:
|
||||
Returns:
|
||||
Dict mapping server names to health status
|
||||
"""
|
||||
# Copy connections under lock to prevent modification during iteration
|
||||
async with self._lock:
|
||||
connections_snapshot = dict(self._connections)
|
||||
|
||||
results = {}
|
||||
for name, connection in self._connections.items():
|
||||
for name, connection in connections_snapshot.items():
|
||||
results[name] = await connection.health_check()
|
||||
return results
|
||||
|
||||
|
||||
Reference in New Issue
Block a user