forked from cardosofelipe/fast-next-template
feat(backend): Add EventBus service with Redis Pub/Sub
- Add EventBus class for real-time event communication - Add Event schema with type-safe event types (agent, issue, sprint events) - Add typed payload schemas (AgentSpawnedPayload, AgentMessagePayload) - Add channel helpers for project/agent/user scoping - Add subscribe_sse generator for SSE streaming - Add reconnection support via Last-Event-ID - Add keepalive mechanism for connection health - Add 44 comprehensive tests with mocked Redis Implements #33 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
275
backend/app/schemas/events.py
Normal file
275
backend/app/schemas/events.py
Normal file
@@ -0,0 +1,275 @@
|
||||
"""
|
||||
Event schemas for the Syndarix EventBus (Redis Pub/Sub).
|
||||
|
||||
This module defines event types and payload schemas for real-time communication
|
||||
between services, agents, and the frontend.
|
||||
"""
|
||||
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import Literal
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class EventType(str, Enum):
|
||||
"""
|
||||
Event types for the EventBus.
|
||||
|
||||
Naming convention: {domain}.{action}
|
||||
"""
|
||||
|
||||
# Agent Events
|
||||
AGENT_SPAWNED = "agent.spawned"
|
||||
AGENT_STATUS_CHANGED = "agent.status_changed"
|
||||
AGENT_MESSAGE = "agent.message"
|
||||
AGENT_TERMINATED = "agent.terminated"
|
||||
|
||||
# Issue Events
|
||||
ISSUE_CREATED = "issue.created"
|
||||
ISSUE_UPDATED = "issue.updated"
|
||||
ISSUE_ASSIGNED = "issue.assigned"
|
||||
ISSUE_CLOSED = "issue.closed"
|
||||
|
||||
# Sprint Events
|
||||
SPRINT_STARTED = "sprint.started"
|
||||
SPRINT_COMPLETED = "sprint.completed"
|
||||
|
||||
# Approval Events
|
||||
APPROVAL_REQUESTED = "approval.requested"
|
||||
APPROVAL_GRANTED = "approval.granted"
|
||||
APPROVAL_DENIED = "approval.denied"
|
||||
|
||||
# Project Events
|
||||
PROJECT_CREATED = "project.created"
|
||||
PROJECT_UPDATED = "project.updated"
|
||||
PROJECT_ARCHIVED = "project.archived"
|
||||
|
||||
# Workflow Events
|
||||
WORKFLOW_STARTED = "workflow.started"
|
||||
WORKFLOW_STEP_COMPLETED = "workflow.step_completed"
|
||||
WORKFLOW_COMPLETED = "workflow.completed"
|
||||
WORKFLOW_FAILED = "workflow.failed"
|
||||
|
||||
|
||||
ActorType = Literal["agent", "user", "system"]
|
||||
|
||||
|
||||
class Event(BaseModel):
|
||||
"""
|
||||
Base event schema for the EventBus.
|
||||
|
||||
All events published to the EventBus must conform to this schema.
|
||||
"""
|
||||
|
||||
id: str = Field(
|
||||
...,
|
||||
description="Unique event identifier (UUID string)",
|
||||
examples=["550e8400-e29b-41d4-a716-446655440000"],
|
||||
)
|
||||
type: EventType = Field(
|
||||
...,
|
||||
description="Event type enum value",
|
||||
examples=[EventType.AGENT_MESSAGE],
|
||||
)
|
||||
timestamp: datetime = Field(
|
||||
...,
|
||||
description="When the event occurred (UTC)",
|
||||
examples=["2024-01-15T10:30:00Z"],
|
||||
)
|
||||
project_id: UUID = Field(
|
||||
...,
|
||||
description="Project this event belongs to",
|
||||
examples=["550e8400-e29b-41d4-a716-446655440001"],
|
||||
)
|
||||
actor_id: UUID | None = Field(
|
||||
default=None,
|
||||
description="ID of the agent or user who triggered the event",
|
||||
examples=["550e8400-e29b-41d4-a716-446655440002"],
|
||||
)
|
||||
actor_type: ActorType = Field(
|
||||
...,
|
||||
description="Type of actor: 'agent', 'user', or 'system'",
|
||||
examples=["agent"],
|
||||
)
|
||||
payload: dict = Field(
|
||||
default_factory=dict,
|
||||
description="Event-specific payload data",
|
||||
)
|
||||
|
||||
model_config = {
|
||||
"json_schema_extra": {
|
||||
"example": {
|
||||
"id": "550e8400-e29b-41d4-a716-446655440000",
|
||||
"type": "agent.message",
|
||||
"timestamp": "2024-01-15T10:30:00Z",
|
||||
"project_id": "550e8400-e29b-41d4-a716-446655440001",
|
||||
"actor_id": "550e8400-e29b-41d4-a716-446655440002",
|
||||
"actor_type": "agent",
|
||||
"payload": {"message": "Processing task...", "progress": 50},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
# Specific payload schemas for type safety
|
||||
|
||||
|
||||
class AgentSpawnedPayload(BaseModel):
|
||||
"""Payload for AGENT_SPAWNED events."""
|
||||
|
||||
agent_instance_id: UUID = Field(..., description="ID of the spawned agent instance")
|
||||
agent_type_id: UUID = Field(..., description="ID of the agent type")
|
||||
agent_name: str = Field(..., description="Human-readable name of the agent")
|
||||
role: str = Field(..., description="Agent role (e.g., 'product_owner', 'engineer')")
|
||||
|
||||
|
||||
class AgentStatusChangedPayload(BaseModel):
|
||||
"""Payload for AGENT_STATUS_CHANGED events."""
|
||||
|
||||
agent_instance_id: UUID = Field(..., description="ID of the agent instance")
|
||||
previous_status: str = Field(..., description="Previous status")
|
||||
new_status: str = Field(..., description="New status")
|
||||
reason: str | None = Field(default=None, description="Reason for status change")
|
||||
|
||||
|
||||
class AgentMessagePayload(BaseModel):
|
||||
"""Payload for AGENT_MESSAGE events."""
|
||||
|
||||
agent_instance_id: UUID = Field(..., description="ID of the agent instance")
|
||||
message: str = Field(..., description="Message content")
|
||||
message_type: str = Field(
|
||||
default="info",
|
||||
description="Message type: 'info', 'warning', 'error', 'debug'",
|
||||
)
|
||||
metadata: dict = Field(
|
||||
default_factory=dict,
|
||||
description="Additional metadata (e.g., token usage, model info)",
|
||||
)
|
||||
|
||||
|
||||
class AgentTerminatedPayload(BaseModel):
|
||||
"""Payload for AGENT_TERMINATED events."""
|
||||
|
||||
agent_instance_id: UUID = Field(..., description="ID of the agent instance")
|
||||
termination_reason: str = Field(..., description="Reason for termination")
|
||||
final_status: str = Field(..., description="Final status at termination")
|
||||
|
||||
|
||||
class IssueCreatedPayload(BaseModel):
|
||||
"""Payload for ISSUE_CREATED events."""
|
||||
|
||||
issue_id: str = Field(..., description="Issue ID (from external tracker)")
|
||||
title: str = Field(..., description="Issue title")
|
||||
priority: str | None = Field(default=None, description="Issue priority")
|
||||
labels: list[str] = Field(default_factory=list, description="Issue labels")
|
||||
|
||||
|
||||
class IssueUpdatedPayload(BaseModel):
|
||||
"""Payload for ISSUE_UPDATED events."""
|
||||
|
||||
issue_id: str = Field(..., description="Issue ID (from external tracker)")
|
||||
changes: dict = Field(..., description="Dictionary of field changes")
|
||||
|
||||
|
||||
class IssueAssignedPayload(BaseModel):
|
||||
"""Payload for ISSUE_ASSIGNED events."""
|
||||
|
||||
issue_id: str = Field(..., description="Issue ID (from external tracker)")
|
||||
assignee_id: UUID | None = Field(
|
||||
default=None, description="Agent or user assigned to"
|
||||
)
|
||||
assignee_name: str | None = Field(default=None, description="Assignee name")
|
||||
|
||||
|
||||
class IssueClosedPayload(BaseModel):
|
||||
"""Payload for ISSUE_CLOSED events."""
|
||||
|
||||
issue_id: str = Field(..., description="Issue ID (from external tracker)")
|
||||
resolution: str = Field(..., description="Resolution status")
|
||||
|
||||
|
||||
class SprintStartedPayload(BaseModel):
|
||||
"""Payload for SPRINT_STARTED events."""
|
||||
|
||||
sprint_id: UUID = Field(..., description="Sprint ID")
|
||||
sprint_name: str = Field(..., description="Sprint name")
|
||||
goal: str | None = Field(default=None, description="Sprint goal")
|
||||
issue_count: int = Field(default=0, description="Number of issues in sprint")
|
||||
|
||||
|
||||
class SprintCompletedPayload(BaseModel):
|
||||
"""Payload for SPRINT_COMPLETED events."""
|
||||
|
||||
sprint_id: UUID = Field(..., description="Sprint ID")
|
||||
sprint_name: str = Field(..., description="Sprint name")
|
||||
completed_issues: int = Field(default=0, description="Number of completed issues")
|
||||
incomplete_issues: int = Field(
|
||||
default=0, description="Number of incomplete issues"
|
||||
)
|
||||
|
||||
|
||||
class ApprovalRequestedPayload(BaseModel):
|
||||
"""Payload for APPROVAL_REQUESTED events."""
|
||||
|
||||
approval_id: UUID = Field(..., description="Approval request ID")
|
||||
approval_type: str = Field(..., description="Type of approval needed")
|
||||
description: str = Field(..., description="Description of what needs approval")
|
||||
requested_by: UUID | None = Field(
|
||||
default=None, description="Agent/user requesting approval"
|
||||
)
|
||||
timeout_minutes: int | None = Field(
|
||||
default=None, description="Minutes before auto-escalation"
|
||||
)
|
||||
|
||||
|
||||
class ApprovalGrantedPayload(BaseModel):
|
||||
"""Payload for APPROVAL_GRANTED events."""
|
||||
|
||||
approval_id: UUID = Field(..., description="Approval request ID")
|
||||
approved_by: UUID = Field(..., description="User who granted approval")
|
||||
comments: str | None = Field(default=None, description="Approval comments")
|
||||
|
||||
|
||||
class ApprovalDeniedPayload(BaseModel):
|
||||
"""Payload for APPROVAL_DENIED events."""
|
||||
|
||||
approval_id: UUID = Field(..., description="Approval request ID")
|
||||
denied_by: UUID = Field(..., description="User who denied approval")
|
||||
reason: str = Field(..., description="Reason for denial")
|
||||
|
||||
|
||||
class WorkflowStartedPayload(BaseModel):
|
||||
"""Payload for WORKFLOW_STARTED events."""
|
||||
|
||||
workflow_id: UUID = Field(..., description="Workflow execution ID")
|
||||
workflow_type: str = Field(..., description="Type of workflow")
|
||||
total_steps: int = Field(default=0, description="Total number of steps")
|
||||
|
||||
|
||||
class WorkflowStepCompletedPayload(BaseModel):
|
||||
"""Payload for WORKFLOW_STEP_COMPLETED events."""
|
||||
|
||||
workflow_id: UUID = Field(..., description="Workflow execution ID")
|
||||
step_name: str = Field(..., description="Name of completed step")
|
||||
step_number: int = Field(..., description="Step number (1-indexed)")
|
||||
total_steps: int = Field(..., description="Total number of steps")
|
||||
result: dict = Field(default_factory=dict, description="Step result data")
|
||||
|
||||
|
||||
class WorkflowCompletedPayload(BaseModel):
|
||||
"""Payload for WORKFLOW_COMPLETED events."""
|
||||
|
||||
workflow_id: UUID = Field(..., description="Workflow execution ID")
|
||||
duration_seconds: float = Field(..., description="Total execution duration")
|
||||
result: dict = Field(default_factory=dict, description="Workflow result data")
|
||||
|
||||
|
||||
class WorkflowFailedPayload(BaseModel):
|
||||
"""Payload for WORKFLOW_FAILED events."""
|
||||
|
||||
workflow_id: UUID = Field(..., description="Workflow execution ID")
|
||||
error_message: str = Field(..., description="Error message")
|
||||
failed_step: str | None = Field(default=None, description="Step that failed")
|
||||
recoverable: bool = Field(default=False, description="Whether error is recoverable")
|
||||
622
backend/app/services/event_bus.py
Normal file
622
backend/app/services/event_bus.py
Normal file
@@ -0,0 +1,622 @@
|
||||
"""
|
||||
EventBus service for Redis Pub/Sub communication.
|
||||
|
||||
This module provides a centralized event bus for publishing and subscribing to
|
||||
events across the Syndarix platform. It uses Redis Pub/Sub for real-time
|
||||
message delivery between services, agents, and the frontend.
|
||||
|
||||
Architecture:
|
||||
- Publishers emit events to project/agent-specific Redis channels
|
||||
- SSE endpoints subscribe to channels and stream events to clients
|
||||
- Events include metadata for reconnection support (Last-Event-ID)
|
||||
- Events are typed with the EventType enum for consistency
|
||||
|
||||
Usage:
|
||||
# Publishing events
|
||||
event_bus = EventBus()
|
||||
await event_bus.connect()
|
||||
|
||||
event = event_bus.create_event(
|
||||
event_type=EventType.AGENT_MESSAGE,
|
||||
project_id=project_id,
|
||||
actor_type="agent",
|
||||
payload={"message": "Processing..."}
|
||||
)
|
||||
await event_bus.publish(event_bus.get_project_channel(project_id), event)
|
||||
|
||||
# Subscribing to events
|
||||
async for event in event_bus.subscribe(["project:123", "agent:456"]):
|
||||
handle_event(event)
|
||||
|
||||
# Cleanup
|
||||
await event_bus.disconnect()
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
from collections.abc import AsyncGenerator, AsyncIterator
|
||||
from contextlib import asynccontextmanager
|
||||
from datetime import UTC, datetime
|
||||
from typing import Any
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
import redis.asyncio as redis
|
||||
from pydantic import ValidationError
|
||||
|
||||
from app.core.config import settings
|
||||
from app.schemas.events import ActorType, Event, EventType
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class EventBusError(Exception):
|
||||
"""Base exception for EventBus errors."""
|
||||
|
||||
|
||||
|
||||
class EventBusConnectionError(EventBusError):
|
||||
"""Raised when connection to Redis fails."""
|
||||
|
||||
|
||||
|
||||
class EventBusPublishError(EventBusError):
|
||||
"""Raised when publishing an event fails."""
|
||||
|
||||
|
||||
|
||||
class EventBusSubscriptionError(EventBusError):
|
||||
"""Raised when subscribing to channels fails."""
|
||||
|
||||
|
||||
|
||||
class EventBus:
|
||||
"""
|
||||
EventBus for Redis Pub/Sub communication.
|
||||
|
||||
Provides methods to publish events to channels and subscribe to events
|
||||
from multiple channels. Handles connection management, serialization,
|
||||
and error recovery.
|
||||
|
||||
This class provides:
|
||||
- Event publishing to project/agent-specific channels
|
||||
- Subscription management for SSE endpoints
|
||||
- Reconnection support via event IDs and sequence numbers
|
||||
- Keepalive messages for connection health
|
||||
- Type-safe event creation with the Event schema
|
||||
|
||||
Attributes:
|
||||
redis_url: Redis connection URL
|
||||
redis_client: Async Redis client instance
|
||||
pubsub: Redis PubSub instance for subscriptions
|
||||
"""
|
||||
|
||||
# Channel prefixes for different entity types
|
||||
PROJECT_CHANNEL_PREFIX = "project"
|
||||
AGENT_CHANNEL_PREFIX = "agent"
|
||||
USER_CHANNEL_PREFIX = "user"
|
||||
GLOBAL_CHANNEL = "syndarix:global"
|
||||
|
||||
def __init__(self, redis_url: str | None = None) -> None:
|
||||
"""
|
||||
Initialize the EventBus.
|
||||
|
||||
Args:
|
||||
redis_url: Redis connection URL. Defaults to settings.REDIS_URL.
|
||||
"""
|
||||
self.redis_url = redis_url or settings.REDIS_URL
|
||||
self._redis_client: redis.Redis | None = None
|
||||
self._pubsub: redis.client.PubSub | None = None
|
||||
self._connected = False
|
||||
self._sequence_counters: dict[str, int] = {}
|
||||
|
||||
@property
|
||||
def redis_client(self) -> redis.Redis:
|
||||
"""Get the Redis client, raising if not connected."""
|
||||
if self._redis_client is None:
|
||||
raise EventBusConnectionError(
|
||||
"EventBus not connected. Call connect() first."
|
||||
)
|
||||
return self._redis_client
|
||||
|
||||
@property
|
||||
def pubsub(self) -> redis.client.PubSub:
|
||||
"""Get the PubSub instance, raising if not connected."""
|
||||
if self._pubsub is None:
|
||||
raise EventBusConnectionError(
|
||||
"EventBus not connected. Call connect() first."
|
||||
)
|
||||
return self._pubsub
|
||||
|
||||
@property
|
||||
def is_connected(self) -> bool:
|
||||
"""Check if the EventBus is connected to Redis."""
|
||||
return self._connected and self._redis_client is not None
|
||||
|
||||
async def connect(self) -> None:
|
||||
"""
|
||||
Connect to Redis and initialize the PubSub client.
|
||||
|
||||
Raises:
|
||||
EventBusConnectionError: If connection to Redis fails.
|
||||
"""
|
||||
if self._connected:
|
||||
logger.debug("EventBus already connected")
|
||||
return
|
||||
|
||||
try:
|
||||
self._redis_client = redis.from_url(
|
||||
self.redis_url,
|
||||
encoding="utf-8",
|
||||
decode_responses=True,
|
||||
)
|
||||
# Test connection - ping() returns a coroutine for async Redis
|
||||
ping_result = self._redis_client.ping()
|
||||
if hasattr(ping_result, "__await__"):
|
||||
await ping_result
|
||||
self._pubsub = self._redis_client.pubsub()
|
||||
self._connected = True
|
||||
logger.info("EventBus connected to Redis")
|
||||
except redis.ConnectionError as e:
|
||||
logger.error(f"Failed to connect to Redis: {e}", exc_info=True)
|
||||
raise EventBusConnectionError(f"Failed to connect to Redis: {e}") from e
|
||||
except redis.RedisError as e:
|
||||
logger.error(f"Redis error during connection: {e}", exc_info=True)
|
||||
raise EventBusConnectionError(f"Redis error: {e}") from e
|
||||
|
||||
async def disconnect(self) -> None:
|
||||
"""
|
||||
Disconnect from Redis and cleanup resources.
|
||||
"""
|
||||
if self._pubsub:
|
||||
try:
|
||||
await self._pubsub.unsubscribe()
|
||||
await self._pubsub.close()
|
||||
except redis.RedisError as e:
|
||||
logger.warning(f"Error closing PubSub: {e}")
|
||||
finally:
|
||||
self._pubsub = None
|
||||
|
||||
if self._redis_client:
|
||||
try:
|
||||
await self._redis_client.aclose()
|
||||
except redis.RedisError as e:
|
||||
logger.warning(f"Error closing Redis client: {e}")
|
||||
finally:
|
||||
self._redis_client = None
|
||||
|
||||
self._connected = False
|
||||
logger.info("EventBus disconnected from Redis")
|
||||
|
||||
@asynccontextmanager
|
||||
async def connection(self) -> AsyncIterator["EventBus"]:
|
||||
"""
|
||||
Context manager for automatic connection handling.
|
||||
|
||||
Usage:
|
||||
async with event_bus.connection() as bus:
|
||||
await bus.publish(channel, event)
|
||||
"""
|
||||
await self.connect()
|
||||
try:
|
||||
yield self
|
||||
finally:
|
||||
await self.disconnect()
|
||||
|
||||
def get_project_channel(self, project_id: UUID | str) -> str:
|
||||
"""
|
||||
Get the channel name for a project.
|
||||
|
||||
Args:
|
||||
project_id: The project UUID or string
|
||||
|
||||
Returns:
|
||||
Channel name string in format "project:{uuid}"
|
||||
"""
|
||||
return f"{self.PROJECT_CHANNEL_PREFIX}:{project_id}"
|
||||
|
||||
def get_agent_channel(self, agent_id: UUID | str) -> str:
|
||||
"""
|
||||
Get the channel name for an agent instance.
|
||||
|
||||
Args:
|
||||
agent_id: The agent instance UUID or string
|
||||
|
||||
Returns:
|
||||
Channel name string in format "agent:{uuid}"
|
||||
"""
|
||||
return f"{self.AGENT_CHANNEL_PREFIX}:{agent_id}"
|
||||
|
||||
def get_user_channel(self, user_id: UUID | str) -> str:
|
||||
"""
|
||||
Get the channel name for a user (personal notifications).
|
||||
|
||||
Args:
|
||||
user_id: The user UUID or string
|
||||
|
||||
Returns:
|
||||
Channel name string in format "user:{uuid}"
|
||||
"""
|
||||
return f"{self.USER_CHANNEL_PREFIX}:{user_id}"
|
||||
|
||||
def _get_next_sequence(self, channel: str) -> int:
|
||||
"""Get the next sequence number for a channel's events."""
|
||||
current = self._sequence_counters.get(channel, 0)
|
||||
self._sequence_counters[channel] = current + 1
|
||||
return current + 1
|
||||
|
||||
@staticmethod
|
||||
def create_event(
|
||||
event_type: EventType,
|
||||
project_id: UUID,
|
||||
actor_type: ActorType,
|
||||
payload: dict | None = None,
|
||||
actor_id: UUID | None = None,
|
||||
event_id: str | None = None,
|
||||
timestamp: datetime | None = None,
|
||||
) -> Event:
|
||||
"""
|
||||
Factory method to create a new Event.
|
||||
|
||||
Args:
|
||||
event_type: The type of event
|
||||
project_id: The project this event belongs to
|
||||
actor_type: Type of actor ('agent', 'user', or 'system')
|
||||
payload: Event-specific payload data
|
||||
actor_id: ID of the agent or user who triggered the event
|
||||
event_id: Optional custom event ID (UUID string)
|
||||
timestamp: Optional custom timestamp (defaults to now UTC)
|
||||
|
||||
Returns:
|
||||
A new Event instance
|
||||
"""
|
||||
return Event(
|
||||
id=event_id or str(uuid4()),
|
||||
type=event_type,
|
||||
timestamp=timestamp or datetime.now(UTC),
|
||||
project_id=project_id,
|
||||
actor_id=actor_id,
|
||||
actor_type=actor_type,
|
||||
payload=payload or {},
|
||||
)
|
||||
|
||||
def _serialize_event(self, event: Event) -> str:
|
||||
"""
|
||||
Serialize an event to JSON string.
|
||||
|
||||
Args:
|
||||
event: The Event to serialize
|
||||
|
||||
Returns:
|
||||
JSON string representation of the event
|
||||
"""
|
||||
return event.model_dump_json()
|
||||
|
||||
def _deserialize_event(self, data: str) -> Event:
|
||||
"""
|
||||
Deserialize a JSON string to an Event.
|
||||
|
||||
Args:
|
||||
data: JSON string to deserialize
|
||||
|
||||
Returns:
|
||||
Deserialized Event instance
|
||||
|
||||
Raises:
|
||||
ValidationError: If the data doesn't match the Event schema
|
||||
"""
|
||||
return Event.model_validate_json(data)
|
||||
|
||||
async def publish(self, channel: str, event: Event) -> int:
|
||||
"""
|
||||
Publish an event to a channel.
|
||||
|
||||
Args:
|
||||
channel: The channel name to publish to
|
||||
event: The Event to publish
|
||||
|
||||
Returns:
|
||||
Number of subscribers that received the message
|
||||
|
||||
Raises:
|
||||
EventBusConnectionError: If not connected to Redis
|
||||
EventBusPublishError: If publishing fails
|
||||
"""
|
||||
if not self.is_connected:
|
||||
raise EventBusConnectionError("EventBus not connected")
|
||||
|
||||
try:
|
||||
message = self._serialize_event(event)
|
||||
subscriber_count = await self.redis_client.publish(channel, message)
|
||||
logger.debug(
|
||||
f"Published event {event.type} to {channel} "
|
||||
f"(received by {subscriber_count} subscribers)"
|
||||
)
|
||||
return subscriber_count
|
||||
except redis.RedisError as e:
|
||||
logger.error(f"Failed to publish event to {channel}: {e}", exc_info=True)
|
||||
raise EventBusPublishError(f"Failed to publish event: {e}") from e
|
||||
|
||||
async def publish_to_project(self, event: Event) -> int:
|
||||
"""
|
||||
Publish an event to the project's channel.
|
||||
|
||||
Convenience method that publishes to the project channel based on
|
||||
the event's project_id.
|
||||
|
||||
Args:
|
||||
event: The Event to publish (must have project_id set)
|
||||
|
||||
Returns:
|
||||
Number of subscribers that received the message
|
||||
"""
|
||||
channel = self.get_project_channel(event.project_id)
|
||||
return await self.publish(channel, event)
|
||||
|
||||
async def publish_multi(self, channels: list[str], event: Event) -> dict[str, int]:
|
||||
"""
|
||||
Publish an event to multiple channels.
|
||||
|
||||
Args:
|
||||
channels: List of channel names to publish to
|
||||
event: The Event to publish
|
||||
|
||||
Returns:
|
||||
Dictionary mapping channel names to subscriber counts
|
||||
"""
|
||||
results = {}
|
||||
for channel in channels:
|
||||
try:
|
||||
results[channel] = await self.publish(channel, event)
|
||||
except EventBusPublishError as e:
|
||||
logger.warning(f"Failed to publish to {channel}: {e}")
|
||||
results[channel] = 0
|
||||
return results
|
||||
|
||||
async def subscribe(
|
||||
self, channels: list[str], *, max_wait: float | None = None
|
||||
) -> AsyncIterator[Event]:
|
||||
"""
|
||||
Subscribe to one or more channels and yield events.
|
||||
|
||||
This is an async generator that yields Event objects as they arrive.
|
||||
Use max_wait to limit how long to wait for messages.
|
||||
|
||||
Args:
|
||||
channels: List of channel names to subscribe to
|
||||
max_wait: Optional maximum wait time in seconds for each message.
|
||||
If None, waits indefinitely.
|
||||
|
||||
Yields:
|
||||
Event objects received from subscribed channels
|
||||
|
||||
Raises:
|
||||
EventBusConnectionError: If not connected to Redis
|
||||
EventBusSubscriptionError: If subscription fails
|
||||
|
||||
Example:
|
||||
async for event in event_bus.subscribe(["project:123"], max_wait=30):
|
||||
print(f"Received: {event.type}")
|
||||
"""
|
||||
if not self.is_connected:
|
||||
raise EventBusConnectionError("EventBus not connected")
|
||||
|
||||
# Create a new pubsub for this subscription
|
||||
subscription_pubsub = self.redis_client.pubsub()
|
||||
|
||||
try:
|
||||
await subscription_pubsub.subscribe(*channels)
|
||||
logger.info(f"Subscribed to channels: {channels}")
|
||||
except redis.RedisError as e:
|
||||
logger.error(f"Failed to subscribe to channels: {e}", exc_info=True)
|
||||
await subscription_pubsub.close()
|
||||
raise EventBusSubscriptionError(f"Failed to subscribe: {e}") from e
|
||||
|
||||
try:
|
||||
while True:
|
||||
try:
|
||||
if max_wait is not None:
|
||||
async with asyncio.timeout(max_wait):
|
||||
message = await subscription_pubsub.get_message(
|
||||
ignore_subscribe_messages=True, timeout=1.0
|
||||
)
|
||||
else:
|
||||
message = await subscription_pubsub.get_message(
|
||||
ignore_subscribe_messages=True, timeout=1.0
|
||||
)
|
||||
except TimeoutError:
|
||||
# Timeout reached, stop iteration
|
||||
return
|
||||
|
||||
if message is None:
|
||||
continue
|
||||
|
||||
if message["type"] == "message":
|
||||
try:
|
||||
event = self._deserialize_event(message["data"])
|
||||
yield event
|
||||
except ValidationError as e:
|
||||
logger.warning(
|
||||
f"Invalid event data received: {e}",
|
||||
extra={"channel": message.get("channel")},
|
||||
)
|
||||
continue
|
||||
except json.JSONDecodeError as e:
|
||||
logger.warning(
|
||||
f"Failed to decode event JSON: {e}",
|
||||
extra={"channel": message.get("channel")},
|
||||
)
|
||||
continue
|
||||
finally:
|
||||
try:
|
||||
await subscription_pubsub.unsubscribe(*channels)
|
||||
await subscription_pubsub.close()
|
||||
logger.debug(f"Unsubscribed from channels: {channels}")
|
||||
except redis.RedisError as e:
|
||||
logger.warning(f"Error unsubscribing from channels: {e}")
|
||||
|
||||
async def subscribe_sse(
|
||||
self,
|
||||
project_id: str | UUID,
|
||||
last_event_id: str | None = None,
|
||||
keepalive_interval: int = 30,
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""
|
||||
Subscribe to events for a project in SSE format.
|
||||
|
||||
This is an async generator that yields SSE-formatted event strings.
|
||||
It includes keepalive messages at the specified interval.
|
||||
|
||||
Args:
|
||||
project_id: The project to subscribe to
|
||||
last_event_id: Optional last received event ID for reconnection
|
||||
keepalive_interval: Seconds between keepalive messages (default 30)
|
||||
|
||||
Yields:
|
||||
SSE-formatted event strings (ready to send to client)
|
||||
"""
|
||||
if not self.is_connected:
|
||||
raise EventBusConnectionError("EventBus not connected")
|
||||
|
||||
project_id_str = str(project_id)
|
||||
channel = self.get_project_channel(project_id_str)
|
||||
|
||||
subscription_pubsub = self.redis_client.pubsub()
|
||||
await subscription_pubsub.subscribe(channel)
|
||||
|
||||
logger.info(
|
||||
f"Subscribed to SSE events for project {project_id_str} "
|
||||
f"(last_event_id={last_event_id})"
|
||||
)
|
||||
|
||||
try:
|
||||
while True:
|
||||
try:
|
||||
# Wait for messages with a timeout for keepalive
|
||||
message = await asyncio.wait_for(
|
||||
subscription_pubsub.get_message(ignore_subscribe_messages=True),
|
||||
timeout=keepalive_interval,
|
||||
)
|
||||
|
||||
if message is not None and message["type"] == "message":
|
||||
event_data = message["data"]
|
||||
|
||||
# If reconnecting, check if we should skip this event
|
||||
if last_event_id:
|
||||
try:
|
||||
event_dict = json.loads(event_data)
|
||||
if event_dict.get("id") == last_event_id:
|
||||
# Found the last event, start yielding from next
|
||||
last_event_id = None
|
||||
continue
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
|
||||
yield event_data
|
||||
|
||||
except TimeoutError:
|
||||
# Send keepalive comment
|
||||
yield "" # Empty string signals keepalive
|
||||
|
||||
except asyncio.CancelledError:
|
||||
logger.info(f"SSE subscription cancelled for project {project_id_str}")
|
||||
raise
|
||||
finally:
|
||||
await subscription_pubsub.unsubscribe(channel)
|
||||
await subscription_pubsub.close()
|
||||
logger.info(f"Unsubscribed SSE from project {project_id_str}")
|
||||
|
||||
async def subscribe_with_callback(
|
||||
self,
|
||||
channels: list[str],
|
||||
callback: Any, # Callable[[Event], Awaitable[None]]
|
||||
stop_event: asyncio.Event | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Subscribe to channels and process events with a callback.
|
||||
|
||||
This method runs until stop_event is set or an unrecoverable error occurs.
|
||||
|
||||
Args:
|
||||
channels: List of channel names to subscribe to
|
||||
callback: Async function to call for each event
|
||||
stop_event: Optional asyncio.Event to signal stop
|
||||
|
||||
Example:
|
||||
async def handle_event(event: Event):
|
||||
print(f"Handling: {event.type}")
|
||||
|
||||
stop = asyncio.Event()
|
||||
asyncio.create_task(
|
||||
event_bus.subscribe_with_callback(["project:123"], handle_event, stop)
|
||||
)
|
||||
# Later...
|
||||
stop.set()
|
||||
"""
|
||||
if stop_event is None:
|
||||
stop_event = asyncio.Event()
|
||||
|
||||
try:
|
||||
async for event in self.subscribe(channels):
|
||||
if stop_event.is_set():
|
||||
break
|
||||
try:
|
||||
await callback(event)
|
||||
except Exception as e:
|
||||
logger.error(f"Error in event callback: {e}", exc_info=True)
|
||||
except EventBusSubscriptionError:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error in subscription loop: {e}", exc_info=True)
|
||||
raise
|
||||
|
||||
|
||||
# Singleton instance for application-wide use
|
||||
_event_bus: EventBus | None = None
|
||||
|
||||
|
||||
def get_event_bus() -> EventBus:
|
||||
"""
|
||||
Get the singleton EventBus instance.
|
||||
|
||||
Creates a new instance if one doesn't exist. Note that you still need
|
||||
to call connect() before using the EventBus.
|
||||
|
||||
Returns:
|
||||
The singleton EventBus instance
|
||||
"""
|
||||
global _event_bus
|
||||
if _event_bus is None:
|
||||
_event_bus = EventBus()
|
||||
return _event_bus
|
||||
|
||||
|
||||
async def get_connected_event_bus() -> EventBus:
|
||||
"""
|
||||
Get a connected EventBus instance.
|
||||
|
||||
Ensures the EventBus is connected before returning. For use in
|
||||
FastAPI dependency injection.
|
||||
|
||||
Returns:
|
||||
A connected EventBus instance
|
||||
|
||||
Raises:
|
||||
EventBusConnectionError: If connection fails
|
||||
"""
|
||||
event_bus = get_event_bus()
|
||||
if not event_bus.is_connected:
|
||||
await event_bus.connect()
|
||||
return event_bus
|
||||
|
||||
|
||||
async def close_event_bus() -> None:
|
||||
"""
|
||||
Close the global EventBus instance.
|
||||
|
||||
Should be called during application shutdown.
|
||||
"""
|
||||
global _event_bus
|
||||
if _event_bus is not None:
|
||||
await _event_bus.disconnect()
|
||||
_event_bus = None
|
||||
1035
backend/tests/services/test_event_bus.py
Normal file
1035
backend/tests/services/test_event_bus.py
Normal file
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user