"""
|
|
EventBus service for Redis Pub/Sub communication.
|
|
|
|
This module provides a centralized event bus for publishing and subscribing to
|
|
events across the Syndarix platform. It uses Redis Pub/Sub for real-time
|
|
message delivery between services, agents, and the frontend.
|
|
|
|
Architecture:
|
|
- Publishers emit events to project/agent-specific Redis channels
|
|
- SSE endpoints subscribe to channels and stream events to clients
|
|
- Events include metadata for reconnection support (Last-Event-ID)
|
|
- Events are typed with the EventType enum for consistency
|
|
|
|
Usage:
|
|
# Publishing events
|
|
event_bus = EventBus()
|
|
await event_bus.connect()
|
|
|
|
event = event_bus.create_event(
|
|
event_type=EventType.AGENT_MESSAGE,
|
|
project_id=project_id,
|
|
actor_type="agent",
|
|
payload={"message": "Processing..."}
|
|
)
|
|
await event_bus.publish(event_bus.get_project_channel(project_id), event)
|
|
|
|
# Subscribing to events
|
|
async for event in event_bus.subscribe(["project:123", "agent:456"]):
|
|
handle_event(event)
|
|
|
|
# Cleanup
|
|
await event_bus.disconnect()
|
|
"""

import asyncio
import json
import logging
from collections.abc import AsyncGenerator, AsyncIterator, Awaitable, Callable
from contextlib import asynccontextmanager
from datetime import UTC, datetime
from uuid import UUID, uuid4

import redis.asyncio as redis
from pydantic import ValidationError

from app.core.config import settings
from app.schemas.events import ActorType, Event, EventType

logger = logging.getLogger(__name__)


class EventBusError(Exception):
    """Base exception for EventBus errors."""


class EventBusConnectionError(EventBusError):
    """Raised when connection to Redis fails."""


class EventBusPublishError(EventBusError):
    """Raised when publishing an event fails."""


class EventBusSubscriptionError(EventBusError):
    """Raised when subscribing to channels fails."""


class EventBus:
    """
    EventBus for Redis Pub/Sub communication.

    Provides methods to publish events to channels and subscribe to events
    from multiple channels. Handles connection management, serialization,
    and error recovery.

    This class provides:
        - Event publishing to project/agent-specific channels
        - Subscription management for SSE endpoints
        - Reconnection support via event IDs (Last-Event-ID)
        - Keepalive messages for connection health
        - Type-safe event creation with the Event schema

    Attributes:
        redis_url: Redis connection URL
        redis_client: Async Redis client instance
        pubsub: Redis PubSub instance for subscriptions
    """

    # Channel prefixes for different entity types
    PROJECT_CHANNEL_PREFIX = "project"
    AGENT_CHANNEL_PREFIX = "agent"
    USER_CHANNEL_PREFIX = "user"
    GLOBAL_CHANNEL = "syndarix:global"

    def __init__(self, redis_url: str | None = None) -> None:
        """
        Initialize the EventBus.

        Args:
            redis_url: Redis connection URL. Defaults to settings.REDIS_URL.
        """
        self.redis_url = redis_url or settings.REDIS_URL
        self._redis_client: redis.Redis | None = None
        self._pubsub: redis.client.PubSub | None = None
        self._connected = False

    @property
    def redis_client(self) -> redis.Redis:
        """Get the Redis client, raising if not connected."""
        if self._redis_client is None:
            raise EventBusConnectionError(
                "EventBus not connected. Call connect() first."
            )
        return self._redis_client

    @property
    def pubsub(self) -> redis.client.PubSub:
        """Get the PubSub instance, raising if not connected."""
        if self._pubsub is None:
            raise EventBusConnectionError(
                "EventBus not connected. Call connect() first."
            )
        return self._pubsub

    @property
    def is_connected(self) -> bool:
        """Check if the EventBus is connected to Redis."""
        return self._connected and self._redis_client is not None

    async def connect(self) -> None:
        """
        Connect to Redis and initialize the PubSub client.

        Raises:
            EventBusConnectionError: If connection to Redis fails.
        """
        if self._connected:
            logger.debug("EventBus already connected")
            return

        try:
            self._redis_client = redis.from_url(
                self.redis_url,
                encoding="utf-8",
                decode_responses=True,
            )
            # Test connection - ping() returns a coroutine for async Redis
            ping_result = self._redis_client.ping()
            if hasattr(ping_result, "__await__"):
                await ping_result
            self._pubsub = self._redis_client.pubsub()
            self._connected = True
            logger.info("EventBus connected to Redis")
        except redis.ConnectionError as e:
            logger.error(f"Failed to connect to Redis: {e}", exc_info=True)
            raise EventBusConnectionError(f"Failed to connect to Redis: {e}") from e
        except redis.RedisError as e:
            logger.error(f"Redis error during connection: {e}", exc_info=True)
            raise EventBusConnectionError(f"Redis error: {e}") from e

    async def disconnect(self) -> None:
        """
        Disconnect from Redis and cleanup resources.
        """
        if self._pubsub:
            try:
                await self._pubsub.unsubscribe()
                await self._pubsub.close()
            except redis.RedisError as e:
                logger.warning(f"Error closing PubSub: {e}")
            finally:
                self._pubsub = None

        if self._redis_client:
            try:
                await self._redis_client.aclose()
            except redis.RedisError as e:
                logger.warning(f"Error closing Redis client: {e}")
            finally:
                self._redis_client = None

        self._connected = False
        logger.info("EventBus disconnected from Redis")

    @asynccontextmanager
    async def connection(self) -> AsyncIterator["EventBus"]:
        """
        Context manager for automatic connection handling.

        Usage:
            async with event_bus.connection() as bus:
                await bus.publish(channel, event)
        """
        await self.connect()
        try:
            yield self
        finally:
            await self.disconnect()

    def get_project_channel(self, project_id: UUID | str) -> str:
        """
        Get the channel name for a project.

        Args:
            project_id: The project UUID or string

        Returns:
            Channel name string in format "project:{uuid}"
        """
        return f"{self.PROJECT_CHANNEL_PREFIX}:{project_id}"

    def get_agent_channel(self, agent_id: UUID | str) -> str:
        """
        Get the channel name for an agent instance.

        Args:
            agent_id: The agent instance UUID or string

        Returns:
            Channel name string in format "agent:{uuid}"
        """
        return f"{self.AGENT_CHANNEL_PREFIX}:{agent_id}"

    def get_user_channel(self, user_id: UUID | str) -> str:
        """
        Get the channel name for a user (personal notifications).

        Args:
            user_id: The user UUID or string

        Returns:
            Channel name string in format "user:{uuid}"
        """
        return f"{self.USER_CHANNEL_PREFIX}:{user_id}"

    @staticmethod
    def create_event(
        event_type: EventType,
        project_id: UUID,
        actor_type: ActorType,
        payload: dict | None = None,
        actor_id: UUID | None = None,
        event_id: str | None = None,
        timestamp: datetime | None = None,
    ) -> Event:
        """
        Factory method to create a new Event.

        Args:
            event_type: The type of event
            project_id: The project this event belongs to
            actor_type: Type of actor ('agent', 'user', or 'system')
            payload: Event-specific payload data
            actor_id: ID of the agent or user who triggered the event
            event_id: Optional custom event ID (UUID string)
            timestamp: Optional custom timestamp (defaults to now UTC)

        Returns:
            A new Event instance
        """
        return Event(
            id=event_id or str(uuid4()),
            type=event_type,
            timestamp=timestamp or datetime.now(UTC),
            project_id=project_id,
            actor_id=actor_id,
            actor_type=actor_type,
            payload=payload or {},
        )

    def _serialize_event(self, event: Event) -> str:
        """
        Serialize an event to JSON string.

        Args:
            event: The Event to serialize

        Returns:
            JSON string representation of the event
        """
        return event.model_dump_json()

    def _deserialize_event(self, data: str) -> Event:
        """
        Deserialize a JSON string to an Event.

        Args:
            data: JSON string to deserialize

        Returns:
            Deserialized Event instance

        Raises:
            ValidationError: If the data doesn't match the Event schema
        """
        return Event.model_validate_json(data)

    async def publish(self, channel: str, event: Event) -> int:
        """
        Publish an event to a channel.

        Args:
            channel: The channel name to publish to
            event: The Event to publish

        Returns:
            Number of subscribers that received the message

        Raises:
            EventBusConnectionError: If not connected to Redis
            EventBusPublishError: If publishing fails
        """
        if not self.is_connected:
            raise EventBusConnectionError("EventBus not connected")

        try:
            message = self._serialize_event(event)
            subscriber_count = await self.redis_client.publish(channel, message)
            logger.debug(
                f"Published event {event.type} to {channel} "
                f"(received by {subscriber_count} subscribers)"
            )
            return subscriber_count
        except redis.RedisError as e:
            logger.error(f"Failed to publish event to {channel}: {e}", exc_info=True)
            raise EventBusPublishError(f"Failed to publish event: {e}") from e

    async def publish_to_project(self, event: Event) -> int:
        """
        Publish an event to the project's channel.

        Convenience method that publishes to the project channel based on
        the event's project_id.

        Args:
            event: The Event to publish (must have project_id set)

        Returns:
            Number of subscribers that received the message
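
        Example (illustrative sketch; assumes a connected bus and an existing
        project_id):
            event = event_bus.create_event(
                event_type=EventType.AGENT_MESSAGE,
                project_id=project_id,
                actor_type="agent",
                payload={"message": "Processing..."},
            )
            delivered = await event_bus.publish_to_project(event)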
        """
        channel = self.get_project_channel(event.project_id)
        return await self.publish(channel, event)

    async def publish_multi(self, channels: list[str], event: Event) -> dict[str, int]:
        """
        Publish an event to multiple channels.

        Args:
            channels: List of channel names to publish to
            event: The Event to publish

        Returns:
            Dictionary mapping channel names to subscriber counts
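
        Example (illustrative sketch; the channels and IDs are placeholders):
            counts = await event_bus.publish_multi(
                [
                    event_bus.get_project_channel(project_id),
                    event_bus.get_agent_channel(agent_id),
                ],
                event,
            )
            # e.g. {"project:<uuid>": 1, "agent:<uuid>": 0}; a count of 0 means
            # the publish failed or nobody was subscribed to that channel.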
        """
        results = {}
        for channel in channels:
            try:
                results[channel] = await self.publish(channel, event)
            except EventBusPublishError as e:
                logger.warning(f"Failed to publish to {channel}: {e}")
                results[channel] = 0
        return results

    async def subscribe(
        self, channels: list[str], *, max_wait: float | None = None
    ) -> AsyncIterator[Event]:
        """
        Subscribe to one or more channels and yield events.

        This is an async generator that yields Event objects as they arrive.
        Use max_wait to limit how long to wait for messages.

        Args:
            channels: List of channel names to subscribe to
            max_wait: Optional maximum wait time in seconds for each message.
                If None, waits indefinitely.

        Yields:
            Event objects received from subscribed channels

        Raises:
            EventBusConnectionError: If not connected to Redis
            EventBusSubscriptionError: If subscription fails

        Example:
            async for event in event_bus.subscribe(["project:123"], max_wait=30):
                print(f"Received: {event.type}")
        """
        if not self.is_connected:
            raise EventBusConnectionError("EventBus not connected")

        # Create a new pubsub for this subscription
        subscription_pubsub = self.redis_client.pubsub()

        try:
            await subscription_pubsub.subscribe(*channels)
            logger.info(f"Subscribed to channels: {channels}")
        except redis.RedisError as e:
            logger.error(f"Failed to subscribe to channels: {e}", exc_info=True)
            await subscription_pubsub.close()
            raise EventBusSubscriptionError(f"Failed to subscribe: {e}") from e

        try:
            while True:
                try:
                    if max_wait is not None:
                        # get_message() returns None after its internal 1-second poll,
                        # so keep polling under an overall max_wait timeout; the
                        # TimeoutError handler below ends the iteration.
                        async with asyncio.timeout(max_wait):
                            message = None
                            while message is None:
                                message = await subscription_pubsub.get_message(
                                    ignore_subscribe_messages=True, timeout=1.0
                                )
                    else:
                        message = await subscription_pubsub.get_message(
                            ignore_subscribe_messages=True, timeout=1.0
                        )
                except TimeoutError:
                    # No message within max_wait, stop iteration
                    return

                if message is None:
                    continue

                if message["type"] == "message":
                    try:
                        event = self._deserialize_event(message["data"])
                        yield event
                    except ValidationError as e:
                        logger.warning(
                            f"Invalid event data received: {e}",
                            extra={"channel": message.get("channel")},
                        )
                        continue
                    except json.JSONDecodeError as e:
                        logger.warning(
                            f"Failed to decode event JSON: {e}",
                            extra={"channel": message.get("channel")},
                        )
                        continue
        finally:
            try:
                await subscription_pubsub.unsubscribe(*channels)
                await subscription_pubsub.close()
                logger.debug(f"Unsubscribed from channels: {channels}")
            except redis.RedisError as e:
                logger.warning(f"Error unsubscribing from channels: {e}")

    async def subscribe_sse(
        self,
        project_id: str | UUID,
        last_event_id: str | None = None,
        keepalive_interval: int = 30,
    ) -> AsyncGenerator[str, None]:
        """
        Subscribe to events for a project for SSE streaming.

        This is an async generator that yields the JSON payload of each event,
        plus an empty string whenever the keepalive interval elapses without a
        message. The SSE endpoint is responsible for wrapping these values in
        the "data: ..." / comment framing of the SSE protocol.

        Args:
            project_id: The project to subscribe to
            last_event_id: Optional last received event ID for reconnection
            keepalive_interval: Seconds between keepalive messages (default 30)

        Yields:
            Event JSON strings, or "" as a keepalive signal
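
        Example (illustrative sketch, not an endpoint defined in this module;
        assumes a FastAPI route using StreamingResponse and the Last-Event-ID
        header):
            async def event_stream():
                async for data in event_bus.subscribe_sse(
                    project_id, last_event_id=last_event_id
                ):
                    if data == "":
                        yield ": keepalive\n\n"
                    else:
                        yield f"data: {data}\n\n"

            return StreamingResponse(event_stream(), media_type="text/event-stream")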
        """
        if not self.is_connected:
            raise EventBusConnectionError("EventBus not connected")

        project_id_str = str(project_id)
        channel = self.get_project_channel(project_id_str)

        subscription_pubsub = self.redis_client.pubsub()
        await subscription_pubsub.subscribe(channel)

        logger.info(
            f"Subscribed to SSE events for project {project_id_str} "
            f"(last_event_id={last_event_id})"
        )

        try:
            while True:
                try:
                    # Block until a message arrives (timeout=None), bounded by the
                    # keepalive interval so the TimeoutError branch below can fire.
                    message = await asyncio.wait_for(
                        subscription_pubsub.get_message(
                            ignore_subscribe_messages=True, timeout=None
                        ),
                        timeout=keepalive_interval,
                    )

                    if message is not None and message["type"] == "message":
                        event_data = message["data"]

                        # If reconnecting, check if we should skip this event
                        if last_event_id:
                            try:
                                event_dict = json.loads(event_data)
                                if event_dict.get("id") == last_event_id:
                                    # Found the last event, start yielding from next
                                    last_event_id = None
                                    continue
                            except json.JSONDecodeError:
                                pass

                        yield event_data

                except TimeoutError:
                    # Send keepalive comment
                    yield ""  # Empty string signals keepalive

        except asyncio.CancelledError:
            logger.info(f"SSE subscription cancelled for project {project_id_str}")
            raise
        finally:
            await subscription_pubsub.unsubscribe(channel)
            await subscription_pubsub.close()
            logger.info(f"Unsubscribed SSE from project {project_id_str}")

    async def subscribe_with_callback(
        self,
        channels: list[str],
        callback: Callable[[Event], Awaitable[None]],
        stop_event: asyncio.Event | None = None,
    ) -> None:
        """
        Subscribe to channels and process events with a callback.

        This method runs until stop_event is set or an unrecoverable error occurs.

        Args:
            channels: List of channel names to subscribe to
            callback: Async function to call for each event
            stop_event: Optional asyncio.Event to signal stop

        Example:
            async def handle_event(event: Event):
                print(f"Handling: {event.type}")

            stop = asyncio.Event()
            asyncio.create_task(
                event_bus.subscribe_with_callback(["project:123"], handle_event, stop)
            )
            # Later...
            stop.set()
        """
        if stop_event is None:
            stop_event = asyncio.Event()

        try:
            async for event in self.subscribe(channels):
                if stop_event.is_set():
                    break
                try:
                    await callback(event)
                except Exception as e:
                    logger.error(f"Error in event callback: {e}", exc_info=True)
        except EventBusSubscriptionError:
            raise
        except Exception as e:
            logger.error(f"Unexpected error in subscription loop: {e}", exc_info=True)
            raise


# Singleton instance for application-wide use
_event_bus: EventBus | None = None


def get_event_bus() -> EventBus:
    """
    Get the singleton EventBus instance.

    Creates a new instance if one doesn't exist. Note that you still need
    to call connect() before using the EventBus.

    Returns:
        The singleton EventBus instance
    """
    global _event_bus
    if _event_bus is None:
        _event_bus = EventBus()
    return _event_bus


async def get_connected_event_bus() -> EventBus:
    """
    Get a connected EventBus instance.

    Ensures the EventBus is connected before returning. For use in
    FastAPI dependency injection.

    Returns:
        A connected EventBus instance

    Raises:
        EventBusConnectionError: If connection fails
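
    Example (illustrative sketch; the route path and response shape are
    assumptions, not part of this module):
        @router.get("/projects/{project_id}/events/status")
        async def events_status(
            event_bus: EventBus = Depends(get_connected_event_bus),
        ) -> dict:
            return {"connected": event_bus.is_connected}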
    """
    event_bus = get_event_bus()
    if not event_bus.is_connected:
        await event_bus.connect()
    return event_bus


async def close_event_bus() -> None:
    """
    Close the global EventBus instance.

    Should be called during application shutdown.
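
    Example (illustrative sketch; assumes a FastAPI lifespan hook, which is
    not defined in this module):
        @asynccontextmanager
        async def lifespan(app: FastAPI):
            await get_connected_event_bus()
            yield
            await close_event_bus()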
    """
    global _event_bus
    if _event_bus is not None:
        await _event_bus.disconnect()
        _event_bus = None