forked from cardosofelipe/fast-next-template
fix(mcp-gateway): address critical issues from deep review
Frontend: - Fix debounce race condition in UserListTable search handler - Use useRef to properly track and cleanup timeout between keystrokes Backend (LLM Gateway): - Add thread-safe double-checked locking for global singletons (providers, circuit registry, cost tracker) - Fix Redis URL parsing with proper urlparse validation - Add explicit error handling for malformed Redis URLs - Document circuit breaker state transition safety 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
@@ -5,7 +5,7 @@
|
||||
|
||||
'use client';
|
||||
|
||||
import { useState, useCallback } from 'react';
|
||||
import { useState, useCallback, useRef, useEffect } from 'react';
|
||||
import { format } from 'date-fns';
|
||||
import { Check, X } from 'lucide-react';
|
||||
import {
|
||||
@@ -61,15 +61,28 @@ export function UserListTable({
|
||||
currentUserId,
|
||||
}: UserListTableProps) {
|
||||
const [searchValue, setSearchValue] = useState('');
|
||||
const searchTimeoutRef = useRef<ReturnType<typeof setTimeout> | null>(null);
|
||||
|
||||
// Debounce search
|
||||
// Cleanup timeout on unmount
|
||||
useEffect(() => {
|
||||
return () => {
|
||||
if (searchTimeoutRef.current) {
|
||||
clearTimeout(searchTimeoutRef.current);
|
||||
}
|
||||
};
|
||||
}, []);
|
||||
|
||||
// Debounce search with proper cleanup
|
||||
const handleSearchChange = useCallback(
|
||||
(value: string) => {
|
||||
setSearchValue(value);
|
||||
const timeoutId = setTimeout(() => {
|
||||
// Clear previous timeout to prevent stale searches
|
||||
if (searchTimeoutRef.current) {
|
||||
clearTimeout(searchTimeoutRef.current);
|
||||
}
|
||||
searchTimeoutRef.current = setTimeout(() => {
|
||||
onSearch(value);
|
||||
}, 300);
|
||||
return () => clearTimeout(timeoutId);
|
||||
},
|
||||
[onSearch]
|
||||
);
|
||||
|
||||
@@ -6,6 +6,7 @@ Provides aggregation by hour, day, and month with TTL-based expiry.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import threading
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from typing import Any
|
||||
|
||||
@@ -441,27 +442,37 @@ def calculate_cost(
|
||||
return round(input_cost + output_cost, 6)
|
||||
|
||||
|
||||
# Global tracker instance (lazy initialization)
|
||||
# Global tracker instance with thread-safe lazy initialization
|
||||
_tracker: CostTracker | None = None
|
||||
_tracker_lock = threading.Lock()
|
||||
|
||||
|
||||
def get_cost_tracker() -> CostTracker:
|
||||
"""Get the global cost tracker instance."""
|
||||
"""
|
||||
Get the global cost tracker instance.
|
||||
|
||||
Thread-safe with double-checked locking pattern.
|
||||
"""
|
||||
global _tracker
|
||||
if _tracker is None:
|
||||
_tracker = CostTracker()
|
||||
with _tracker_lock:
|
||||
# Double-check after acquiring lock
|
||||
if _tracker is None:
|
||||
_tracker = CostTracker()
|
||||
return _tracker
|
||||
|
||||
|
||||
async def close_cost_tracker() -> None:
|
||||
"""Close the global cost tracker."""
|
||||
global _tracker
|
||||
if _tracker:
|
||||
await _tracker.close()
|
||||
_tracker = None
|
||||
with _tracker_lock:
|
||||
if _tracker:
|
||||
await _tracker.close()
|
||||
_tracker = None
|
||||
|
||||
|
||||
def reset_cost_tracker() -> None:
|
||||
"""Reset the global tracker (for testing)."""
|
||||
global _tracker
|
||||
_tracker = None
|
||||
with _tracker_lock:
|
||||
_tracker = None
|
||||
|
||||
@@ -7,6 +7,7 @@ temporarily disabling providers that are experiencing issues.
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import threading
|
||||
import time
|
||||
from collections.abc import Awaitable, Callable
|
||||
from dataclasses import dataclass, field
|
||||
@@ -85,6 +86,9 @@ class CircuitBreaker:
|
||||
@property
|
||||
def state(self) -> CircuitState:
|
||||
"""Get current circuit state (may trigger state transition)."""
|
||||
# Check transition outside lock since it only reads/writes _state
|
||||
# which is atomic in Python, and we use the lock inside _check_state_transition
|
||||
# if a state change is needed
|
||||
self._check_state_transition()
|
||||
return self._state
|
||||
|
||||
@@ -94,15 +98,26 @@ class CircuitBreaker:
|
||||
return self._stats
|
||||
|
||||
def _check_state_transition(self) -> None:
|
||||
"""Check if state should transition based on time."""
|
||||
"""
|
||||
Check if state should transition based on time.
|
||||
|
||||
Note: This method is intentionally not async and doesn't acquire the lock
|
||||
because it's called frequently from the state property. The transition
|
||||
is safe because:
|
||||
1. We check the current state first (atomic read)
|
||||
2. _transition_to only modifies state if we're still in OPEN state
|
||||
3. Multiple concurrent transitions to HALF_OPEN are idempotent
|
||||
"""
|
||||
if self._state == CircuitState.OPEN:
|
||||
time_in_open = time.time() - self._stats.state_changed_at
|
||||
if time_in_open >= self.recovery_timeout:
|
||||
self._transition_to(CircuitState.HALF_OPEN)
|
||||
logger.info(
|
||||
f"Circuit {self.name} transitioned to HALF_OPEN "
|
||||
f"after {time_in_open:.1f}s"
|
||||
)
|
||||
# Only transition if still in OPEN state (double-check)
|
||||
if self._state == CircuitState.OPEN:
|
||||
self._transition_to(CircuitState.HALF_OPEN)
|
||||
logger.info(
|
||||
f"Circuit {self.name} transitioned to HALF_OPEN "
|
||||
f"after {time_in_open:.1f}s"
|
||||
)
|
||||
|
||||
def _transition_to(self, new_state: CircuitState) -> None:
|
||||
"""Transition to a new state."""
|
||||
@@ -339,19 +354,28 @@ class CircuitBreakerRegistry:
|
||||
]
|
||||
|
||||
|
||||
# Global registry instance (lazy initialization)
|
||||
# Global registry instance with thread-safe lazy initialization
|
||||
_registry: CircuitBreakerRegistry | None = None
|
||||
_registry_lock = threading.Lock()
|
||||
|
||||
|
||||
def get_circuit_registry() -> CircuitBreakerRegistry:
|
||||
"""Get the global circuit breaker registry."""
|
||||
"""
|
||||
Get the global circuit breaker registry.
|
||||
|
||||
Thread-safe with double-checked locking pattern.
|
||||
"""
|
||||
global _registry
|
||||
if _registry is None:
|
||||
_registry = CircuitBreakerRegistry()
|
||||
with _registry_lock:
|
||||
# Double-check after acquiring lock
|
||||
if _registry is None:
|
||||
_registry = CircuitBreakerRegistry()
|
||||
return _registry
|
||||
|
||||
|
||||
def reset_circuit_registry() -> None:
|
||||
"""Reset the global registry (for testing)."""
|
||||
global _registry
|
||||
_registry = None
|
||||
with _registry_lock:
|
||||
_registry = None
|
||||
|
||||
@@ -6,7 +6,9 @@ Configures the LiteLLM Router with model lists and failover chains.
|
||||
|
||||
import logging
|
||||
import os
|
||||
import threading
|
||||
from typing import Any
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import litellm
|
||||
from litellm import Router
|
||||
@@ -57,19 +59,47 @@ def configure_litellm(settings: Settings) -> None:
|
||||
|
||||
|
||||
def _parse_redis_host(redis_url: str) -> str:
|
||||
"""Extract host from Redis URL."""
|
||||
# redis://host:port/db
|
||||
url = redis_url.replace("redis://", "")
|
||||
host_port = url.split("/")[0]
|
||||
return host_port.split(":")[0]
|
||||
"""
|
||||
Extract host from Redis URL.
|
||||
|
||||
Args:
|
||||
redis_url: Redis connection URL (e.g., redis://localhost:6379/0)
|
||||
|
||||
Returns:
|
||||
Hostname extracted from URL
|
||||
|
||||
Raises:
|
||||
ValueError: If URL is malformed or missing host
|
||||
"""
|
||||
try:
|
||||
parsed = urlparse(redis_url)
|
||||
if not parsed.hostname:
|
||||
raise ValueError(f"Invalid Redis URL: missing hostname in '{redis_url}'")
|
||||
return parsed.hostname
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to parse Redis URL: {e}")
|
||||
raise ValueError(f"Invalid Redis URL: {redis_url}") from e
|
||||
|
||||
|
||||
def _parse_redis_port(redis_url: str) -> int:
|
||||
"""Extract port from Redis URL."""
|
||||
url = redis_url.replace("redis://", "")
|
||||
host_port = url.split("/")[0]
|
||||
parts = host_port.split(":")
|
||||
return int(parts[1]) if len(parts) > 1 else 6379
|
||||
"""
|
||||
Extract port from Redis URL.
|
||||
|
||||
Args:
|
||||
redis_url: Redis connection URL (e.g., redis://localhost:6379/0)
|
||||
|
||||
Returns:
|
||||
Port number (defaults to 6379 if not specified)
|
||||
|
||||
Raises:
|
||||
ValueError: If URL is malformed
|
||||
"""
|
||||
try:
|
||||
parsed = urlparse(redis_url)
|
||||
return parsed.port or 6379
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to parse Redis URL: {e}")
|
||||
raise ValueError(f"Invalid Redis URL: {redis_url}") from e
|
||||
|
||||
|
||||
def _is_provider_available(provider: Provider, settings: Settings) -> bool:
|
||||
@@ -310,19 +340,28 @@ class LLMProvider:
|
||||
return _is_provider_available(model_config.provider, self._settings)
|
||||
|
||||
|
||||
# Global provider instance (lazy initialization)
|
||||
# Global provider instance with thread-safe lazy initialization
|
||||
_provider: LLMProvider | None = None
|
||||
_provider_lock = threading.Lock()
|
||||
|
||||
|
||||
def get_provider() -> LLMProvider:
|
||||
"""Get the global LLM Provider instance."""
|
||||
"""
|
||||
Get the global LLM Provider instance.
|
||||
|
||||
Thread-safe with double-checked locking pattern.
|
||||
"""
|
||||
global _provider
|
||||
if _provider is None:
|
||||
_provider = LLMProvider()
|
||||
with _provider_lock:
|
||||
# Double-check after acquiring lock
|
||||
if _provider is None:
|
||||
_provider = LLMProvider()
|
||||
return _provider
|
||||
|
||||
|
||||
def reset_provider() -> None:
|
||||
"""Reset the global provider (for testing)."""
|
||||
global _provider
|
||||
_provider = None
|
||||
with _provider_lock:
|
||||
_provider = None
|
||||
|
||||
Reference in New Issue
Block a user