forked from cardosofelipe/pragma-stack
Compare commits
2 Commits
da85a8aba8
...
35aea2d73a
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
35aea2d73a | ||
|
|
d0f32d04f7 |
@@ -18,7 +18,10 @@ from sqlalchemy import (
|
|||||||
Text,
|
Text,
|
||||||
text,
|
text,
|
||||||
)
|
)
|
||||||
from sqlalchemy.dialects.postgresql import JSONB, UUID as PGUUID
|
from sqlalchemy.dialects.postgresql import (
|
||||||
|
JSONB,
|
||||||
|
UUID as PGUUID,
|
||||||
|
)
|
||||||
from sqlalchemy.orm import relationship
|
from sqlalchemy.orm import relationship
|
||||||
|
|
||||||
from app.models.base import Base, TimestampMixin, UUIDMixin
|
from app.models.base import Base, TimestampMixin, UUIDMixin
|
||||||
|
|||||||
@@ -122,16 +122,24 @@ class MCPClientManager:
|
|||||||
)
|
)
|
||||||
|
|
||||||
async def _connect_all_servers(self) -> None:
|
async def _connect_all_servers(self) -> None:
|
||||||
"""Connect to all enabled MCP servers."""
|
"""Connect to all enabled MCP servers concurrently."""
|
||||||
|
import asyncio
|
||||||
|
|
||||||
enabled_servers = self._registry.get_enabled_configs()
|
enabled_servers = self._registry.get_enabled_configs()
|
||||||
|
|
||||||
for name, config in enabled_servers.items():
|
async def connect_server(name: str, config: "MCPServerConfig") -> None:
|
||||||
try:
|
try:
|
||||||
await self._pool.get_connection(name, config)
|
await self._pool.get_connection(name, config)
|
||||||
logger.info("Connected to MCP server: %s", name)
|
logger.info("Connected to MCP server: %s", name)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error("Failed to connect to MCP server %s: %s", name, e)
|
logger.error("Failed to connect to MCP server %s: %s", name, e)
|
||||||
|
|
||||||
|
# Connect to all servers concurrently for faster startup
|
||||||
|
await asyncio.gather(
|
||||||
|
*(connect_server(name, config) for name, config in enabled_servers.items()),
|
||||||
|
return_exceptions=True,
|
||||||
|
)
|
||||||
|
|
||||||
async def shutdown(self) -> None:
|
async def shutdown(self) -> None:
|
||||||
"""
|
"""
|
||||||
Shutdown the MCP client manager.
|
Shutdown the MCP client manager.
|
||||||
|
|||||||
@@ -179,6 +179,8 @@ def load_mcp_config(path: str | Path | None = None) -> MCPConfig:
|
|||||||
2. MCP_CONFIG_PATH environment variable
|
2. MCP_CONFIG_PATH environment variable
|
||||||
3. Default path (backend/mcp_servers.yaml)
|
3. Default path (backend/mcp_servers.yaml)
|
||||||
4. Empty config if no file exists
|
4. Empty config if no file exists
|
||||||
|
|
||||||
|
In test mode (IS_TEST=True), retry settings are reduced for faster tests.
|
||||||
"""
|
"""
|
||||||
if path is None:
|
if path is None:
|
||||||
path = os.environ.get("MCP_CONFIG_PATH", str(DEFAULT_CONFIG_PATH))
|
path = os.environ.get("MCP_CONFIG_PATH", str(DEFAULT_CONFIG_PATH))
|
||||||
@@ -189,7 +191,18 @@ def load_mcp_config(path: str | Path | None = None) -> MCPConfig:
|
|||||||
# Return empty config if no file exists (allows runtime registration)
|
# Return empty config if no file exists (allows runtime registration)
|
||||||
return MCPConfig()
|
return MCPConfig()
|
||||||
|
|
||||||
return MCPConfig.from_yaml(path)
|
config = MCPConfig.from_yaml(path)
|
||||||
|
|
||||||
|
# In test mode, reduce retry settings to speed up tests
|
||||||
|
is_test = os.environ.get("IS_TEST", "").lower() in ("true", "1", "yes")
|
||||||
|
if is_test:
|
||||||
|
for server_config in config.mcp_servers.values():
|
||||||
|
server_config.retry_attempts = 1 # Single attempt
|
||||||
|
server_config.retry_delay = 0.1 # 100ms instead of 1s
|
||||||
|
server_config.retry_max_delay = 0.5 # 500ms max
|
||||||
|
server_config.timeout = 2 # 2s timeout instead of 30-120s
|
||||||
|
|
||||||
|
return config
|
||||||
|
|
||||||
|
|
||||||
def create_default_config() -> MCPConfig:
|
def create_default_config() -> MCPConfig:
|
||||||
|
|||||||
@@ -188,13 +188,14 @@ class TestPasswordResetConfirm:
|
|||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_password_reset_confirm_expired_token(self, client, async_test_user):
|
async def test_password_reset_confirm_expired_token(self, client, async_test_user):
|
||||||
"""Test password reset confirmation with expired token."""
|
"""Test password reset confirmation with expired token."""
|
||||||
import time as time_module
|
import asyncio
|
||||||
|
|
||||||
# Create token that expires immediately
|
# Create token that expires at current second (expires_in=0)
|
||||||
token = create_password_reset_token(async_test_user.email, expires_in=1)
|
# Token expires when exp < current_time, so we need to cross a second boundary
|
||||||
|
token = create_password_reset_token(async_test_user.email, expires_in=0)
|
||||||
|
|
||||||
# Wait for token to expire
|
# Wait for token to expire (need to cross second boundary)
|
||||||
time_module.sleep(2)
|
await asyncio.sleep(1.1)
|
||||||
|
|
||||||
response = await client.post(
|
response = await client.post(
|
||||||
"/api/v1/auth/password-reset/confirm",
|
"/api/v1/auth/password-reset/confirm",
|
||||||
|
|||||||
@@ -160,11 +160,11 @@ class TestEmbeddingCache:
|
|||||||
|
|
||||||
async def test_ttl_expiration(self) -> None:
|
async def test_ttl_expiration(self) -> None:
|
||||||
"""Should expire entries after TTL."""
|
"""Should expire entries after TTL."""
|
||||||
cache = EmbeddingCache(max_size=100, default_ttl_seconds=0.1)
|
cache = EmbeddingCache(max_size=100, default_ttl_seconds=0.05)
|
||||||
|
|
||||||
await cache.put("content", [0.1, 0.2])
|
await cache.put("content", [0.1, 0.2])
|
||||||
|
|
||||||
time.sleep(0.2)
|
time.sleep(0.06)
|
||||||
|
|
||||||
result = await cache.get("content")
|
result = await cache.get("content")
|
||||||
|
|
||||||
@@ -226,13 +226,13 @@ class TestEmbeddingCache:
|
|||||||
|
|
||||||
def test_cleanup_expired(self) -> None:
|
def test_cleanup_expired(self) -> None:
|
||||||
"""Should remove expired entries."""
|
"""Should remove expired entries."""
|
||||||
cache = EmbeddingCache(max_size=100, default_ttl_seconds=0.1)
|
cache = EmbeddingCache(max_size=100, default_ttl_seconds=0.05)
|
||||||
|
|
||||||
# Use synchronous put for setup
|
# Use synchronous put for setup
|
||||||
cache._put_memory("hash1", "default", [0.1])
|
cache._put_memory("hash1", "default", [0.1])
|
||||||
cache._put_memory("hash2", "default", [0.2], ttl_seconds=10)
|
cache._put_memory("hash2", "default", [0.2], ttl_seconds=10)
|
||||||
|
|
||||||
time.sleep(0.2)
|
time.sleep(0.06)
|
||||||
|
|
||||||
count = cache.cleanup_expired()
|
count = cache.cleanup_expired()
|
||||||
|
|
||||||
|
|||||||
@@ -212,12 +212,12 @@ class TestHotMemoryCache:
|
|||||||
|
|
||||||
def test_ttl_expiration(self) -> None:
|
def test_ttl_expiration(self) -> None:
|
||||||
"""Should expire entries after TTL."""
|
"""Should expire entries after TTL."""
|
||||||
cache = HotMemoryCache[str](max_size=100, default_ttl_seconds=0.1)
|
cache = HotMemoryCache[str](max_size=100, default_ttl_seconds=0.05)
|
||||||
|
|
||||||
cache.put_by_id("test", "1", "value")
|
cache.put_by_id("test", "1", "value")
|
||||||
|
|
||||||
# Wait for expiration
|
# Wait for expiration
|
||||||
time.sleep(0.2)
|
time.sleep(0.06)
|
||||||
|
|
||||||
result = cache.get_by_id("test", "1")
|
result = cache.get_by_id("test", "1")
|
||||||
|
|
||||||
@@ -289,12 +289,12 @@ class TestHotMemoryCache:
|
|||||||
|
|
||||||
def test_cleanup_expired(self) -> None:
|
def test_cleanup_expired(self) -> None:
|
||||||
"""Should remove expired entries."""
|
"""Should remove expired entries."""
|
||||||
cache = HotMemoryCache[str](max_size=100, default_ttl_seconds=0.1)
|
cache = HotMemoryCache[str](max_size=100, default_ttl_seconds=0.05)
|
||||||
|
|
||||||
cache.put_by_id("test", "1", "value1")
|
cache.put_by_id("test", "1", "value1")
|
||||||
cache.put_by_id("test", "2", "value2", ttl_seconds=10)
|
cache.put_by_id("test", "2", "value2", ttl_seconds=10)
|
||||||
|
|
||||||
time.sleep(0.2)
|
time.sleep(0.06)
|
||||||
|
|
||||||
count = cache.cleanup_expired()
|
count = cache.cleanup_expired()
|
||||||
|
|
||||||
|
|||||||
@@ -78,13 +78,13 @@ class TestInMemoryStorageTTL:
|
|||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_ttl_expiration(self, storage: InMemoryStorage) -> None:
|
async def test_ttl_expiration(self, storage: InMemoryStorage) -> None:
|
||||||
"""Test that expired keys return None."""
|
"""Test that expired keys return None."""
|
||||||
await storage.set("key1", "value1", ttl_seconds=1)
|
await storage.set("key1", "value1", ttl_seconds=0.1)
|
||||||
|
|
||||||
# Key exists initially
|
# Key exists initially
|
||||||
assert await storage.get("key1") == "value1"
|
assert await storage.get("key1") == "value1"
|
||||||
|
|
||||||
# Wait for expiration
|
# Wait for expiration
|
||||||
await asyncio.sleep(1.1)
|
await asyncio.sleep(0.15)
|
||||||
|
|
||||||
# Key should be expired
|
# Key should be expired
|
||||||
assert await storage.get("key1") is None
|
assert await storage.get("key1") is None
|
||||||
@@ -93,10 +93,10 @@ class TestInMemoryStorageTTL:
|
|||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_remove_ttl_on_update(self, storage: InMemoryStorage) -> None:
|
async def test_remove_ttl_on_update(self, storage: InMemoryStorage) -> None:
|
||||||
"""Test that updating without TTL removes expiration."""
|
"""Test that updating without TTL removes expiration."""
|
||||||
await storage.set("key1", "value1", ttl_seconds=1)
|
await storage.set("key1", "value1", ttl_seconds=0.1)
|
||||||
await storage.set("key1", "value2") # No TTL
|
await storage.set("key1", "value2") # No TTL
|
||||||
|
|
||||||
await asyncio.sleep(1.1)
|
await asyncio.sleep(0.15)
|
||||||
|
|
||||||
# Key should still exist (TTL removed)
|
# Key should still exist (TTL removed)
|
||||||
assert await storage.get("key1") == "value2"
|
assert await storage.get("key1") == "value2"
|
||||||
@@ -180,10 +180,10 @@ class TestInMemoryStorageCapacity:
|
|||||||
"""Test that expired keys are cleaned up for capacity."""
|
"""Test that expired keys are cleaned up for capacity."""
|
||||||
storage = InMemoryStorage(max_keys=2)
|
storage = InMemoryStorage(max_keys=2)
|
||||||
|
|
||||||
await storage.set("key1", "value1", ttl_seconds=1)
|
await storage.set("key1", "value1", ttl_seconds=0.1)
|
||||||
await storage.set("key2", "value2")
|
await storage.set("key2", "value2")
|
||||||
|
|
||||||
await asyncio.sleep(1.1)
|
await asyncio.sleep(0.15)
|
||||||
|
|
||||||
# Should succeed because key1 is expired and will be cleaned
|
# Should succeed because key1 is expired and will be cleaned
|
||||||
await storage.set("key3", "value3")
|
await storage.set("key3", "value3")
|
||||||
|
|||||||
Reference in New Issue
Block a user