From 3c24a8c522435acbb9ea6ca0067f5e032e49a46e Mon Sep 17 00:00:00 2001 From: Felipe Cardoso Date: Tue, 30 Dec 2025 02:07:51 +0100 Subject: [PATCH] feat(backend): Add EventBus service with Redis Pub/Sub MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Add EventBus class for real-time event communication - Add Event schema with type-safe event types (agent, issue, sprint events) - Add typed payload schemas (AgentSpawnedPayload, AgentMessagePayload) - Add channel helpers for project/agent/user scoping - Add subscribe_sse generator for SSE streaming - Add reconnection support via Last-Event-ID - Add keepalive mechanism for connection health - Add 44 comprehensive tests with mocked Redis Implements #33 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 --- backend/app/schemas/events.py | 275 ++++++ backend/app/services/event_bus.py | 622 +++++++++++++ backend/tests/services/test_event_bus.py | 1035 ++++++++++++++++++++++ 3 files changed, 1932 insertions(+) create mode 100644 backend/app/schemas/events.py create mode 100644 backend/app/services/event_bus.py create mode 100644 backend/tests/services/test_event_bus.py diff --git a/backend/app/schemas/events.py b/backend/app/schemas/events.py new file mode 100644 index 0000000..a7a8a65 --- /dev/null +++ b/backend/app/schemas/events.py @@ -0,0 +1,275 @@ +""" +Event schemas for the Syndarix EventBus (Redis Pub/Sub). + +This module defines event types and payload schemas for real-time communication +between services, agents, and the frontend. +""" + +from datetime import datetime +from enum import Enum +from typing import Literal +from uuid import UUID + +from pydantic import BaseModel, Field + + +class EventType(str, Enum): + """ + Event types for the EventBus. + + Naming convention: {domain}.{action} + """ + + # Agent Events + AGENT_SPAWNED = "agent.spawned" + AGENT_STATUS_CHANGED = "agent.status_changed" + AGENT_MESSAGE = "agent.message" + AGENT_TERMINATED = "agent.terminated" + + # Issue Events + ISSUE_CREATED = "issue.created" + ISSUE_UPDATED = "issue.updated" + ISSUE_ASSIGNED = "issue.assigned" + ISSUE_CLOSED = "issue.closed" + + # Sprint Events + SPRINT_STARTED = "sprint.started" + SPRINT_COMPLETED = "sprint.completed" + + # Approval Events + APPROVAL_REQUESTED = "approval.requested" + APPROVAL_GRANTED = "approval.granted" + APPROVAL_DENIED = "approval.denied" + + # Project Events + PROJECT_CREATED = "project.created" + PROJECT_UPDATED = "project.updated" + PROJECT_ARCHIVED = "project.archived" + + # Workflow Events + WORKFLOW_STARTED = "workflow.started" + WORKFLOW_STEP_COMPLETED = "workflow.step_completed" + WORKFLOW_COMPLETED = "workflow.completed" + WORKFLOW_FAILED = "workflow.failed" + + +ActorType = Literal["agent", "user", "system"] + + +class Event(BaseModel): + """ + Base event schema for the EventBus. + + All events published to the EventBus must conform to this schema. + """ + + id: str = Field( + ..., + description="Unique event identifier (UUID string)", + examples=["550e8400-e29b-41d4-a716-446655440000"], + ) + type: EventType = Field( + ..., + description="Event type enum value", + examples=[EventType.AGENT_MESSAGE], + ) + timestamp: datetime = Field( + ..., + description="When the event occurred (UTC)", + examples=["2024-01-15T10:30:00Z"], + ) + project_id: UUID = Field( + ..., + description="Project this event belongs to", + examples=["550e8400-e29b-41d4-a716-446655440001"], + ) + actor_id: UUID | None = Field( + default=None, + description="ID of the agent or user who triggered the event", + examples=["550e8400-e29b-41d4-a716-446655440002"], + ) + actor_type: ActorType = Field( + ..., + description="Type of actor: 'agent', 'user', or 'system'", + examples=["agent"], + ) + payload: dict = Field( + default_factory=dict, + description="Event-specific payload data", + ) + + model_config = { + "json_schema_extra": { + "example": { + "id": "550e8400-e29b-41d4-a716-446655440000", + "type": "agent.message", + "timestamp": "2024-01-15T10:30:00Z", + "project_id": "550e8400-e29b-41d4-a716-446655440001", + "actor_id": "550e8400-e29b-41d4-a716-446655440002", + "actor_type": "agent", + "payload": {"message": "Processing task...", "progress": 50}, + } + } + } + + +# Specific payload schemas for type safety + + +class AgentSpawnedPayload(BaseModel): + """Payload for AGENT_SPAWNED events.""" + + agent_instance_id: UUID = Field(..., description="ID of the spawned agent instance") + agent_type_id: UUID = Field(..., description="ID of the agent type") + agent_name: str = Field(..., description="Human-readable name of the agent") + role: str = Field(..., description="Agent role (e.g., 'product_owner', 'engineer')") + + +class AgentStatusChangedPayload(BaseModel): + """Payload for AGENT_STATUS_CHANGED events.""" + + agent_instance_id: UUID = Field(..., description="ID of the agent instance") + previous_status: str = Field(..., description="Previous status") + new_status: str = Field(..., description="New status") + reason: str | None = Field(default=None, description="Reason for status change") + + +class AgentMessagePayload(BaseModel): + """Payload for AGENT_MESSAGE events.""" + + agent_instance_id: UUID = Field(..., description="ID of the agent instance") + message: str = Field(..., description="Message content") + message_type: str = Field( + default="info", + description="Message type: 'info', 'warning', 'error', 'debug'", + ) + metadata: dict = Field( + default_factory=dict, + description="Additional metadata (e.g., token usage, model info)", + ) + + +class AgentTerminatedPayload(BaseModel): + """Payload for AGENT_TERMINATED events.""" + + agent_instance_id: UUID = Field(..., description="ID of the agent instance") + termination_reason: str = Field(..., description="Reason for termination") + final_status: str = Field(..., description="Final status at termination") + + +class IssueCreatedPayload(BaseModel): + """Payload for ISSUE_CREATED events.""" + + issue_id: str = Field(..., description="Issue ID (from external tracker)") + title: str = Field(..., description="Issue title") + priority: str | None = Field(default=None, description="Issue priority") + labels: list[str] = Field(default_factory=list, description="Issue labels") + + +class IssueUpdatedPayload(BaseModel): + """Payload for ISSUE_UPDATED events.""" + + issue_id: str = Field(..., description="Issue ID (from external tracker)") + changes: dict = Field(..., description="Dictionary of field changes") + + +class IssueAssignedPayload(BaseModel): + """Payload for ISSUE_ASSIGNED events.""" + + issue_id: str = Field(..., description="Issue ID (from external tracker)") + assignee_id: UUID | None = Field( + default=None, description="Agent or user assigned to" + ) + assignee_name: str | None = Field(default=None, description="Assignee name") + + +class IssueClosedPayload(BaseModel): + """Payload for ISSUE_CLOSED events.""" + + issue_id: str = Field(..., description="Issue ID (from external tracker)") + resolution: str = Field(..., description="Resolution status") + + +class SprintStartedPayload(BaseModel): + """Payload for SPRINT_STARTED events.""" + + sprint_id: UUID = Field(..., description="Sprint ID") + sprint_name: str = Field(..., description="Sprint name") + goal: str | None = Field(default=None, description="Sprint goal") + issue_count: int = Field(default=0, description="Number of issues in sprint") + + +class SprintCompletedPayload(BaseModel): + """Payload for SPRINT_COMPLETED events.""" + + sprint_id: UUID = Field(..., description="Sprint ID") + sprint_name: str = Field(..., description="Sprint name") + completed_issues: int = Field(default=0, description="Number of completed issues") + incomplete_issues: int = Field( + default=0, description="Number of incomplete issues" + ) + + +class ApprovalRequestedPayload(BaseModel): + """Payload for APPROVAL_REQUESTED events.""" + + approval_id: UUID = Field(..., description="Approval request ID") + approval_type: str = Field(..., description="Type of approval needed") + description: str = Field(..., description="Description of what needs approval") + requested_by: UUID | None = Field( + default=None, description="Agent/user requesting approval" + ) + timeout_minutes: int | None = Field( + default=None, description="Minutes before auto-escalation" + ) + + +class ApprovalGrantedPayload(BaseModel): + """Payload for APPROVAL_GRANTED events.""" + + approval_id: UUID = Field(..., description="Approval request ID") + approved_by: UUID = Field(..., description="User who granted approval") + comments: str | None = Field(default=None, description="Approval comments") + + +class ApprovalDeniedPayload(BaseModel): + """Payload for APPROVAL_DENIED events.""" + + approval_id: UUID = Field(..., description="Approval request ID") + denied_by: UUID = Field(..., description="User who denied approval") + reason: str = Field(..., description="Reason for denial") + + +class WorkflowStartedPayload(BaseModel): + """Payload for WORKFLOW_STARTED events.""" + + workflow_id: UUID = Field(..., description="Workflow execution ID") + workflow_type: str = Field(..., description="Type of workflow") + total_steps: int = Field(default=0, description="Total number of steps") + + +class WorkflowStepCompletedPayload(BaseModel): + """Payload for WORKFLOW_STEP_COMPLETED events.""" + + workflow_id: UUID = Field(..., description="Workflow execution ID") + step_name: str = Field(..., description="Name of completed step") + step_number: int = Field(..., description="Step number (1-indexed)") + total_steps: int = Field(..., description="Total number of steps") + result: dict = Field(default_factory=dict, description="Step result data") + + +class WorkflowCompletedPayload(BaseModel): + """Payload for WORKFLOW_COMPLETED events.""" + + workflow_id: UUID = Field(..., description="Workflow execution ID") + duration_seconds: float = Field(..., description="Total execution duration") + result: dict = Field(default_factory=dict, description="Workflow result data") + + +class WorkflowFailedPayload(BaseModel): + """Payload for WORKFLOW_FAILED events.""" + + workflow_id: UUID = Field(..., description="Workflow execution ID") + error_message: str = Field(..., description="Error message") + failed_step: str | None = Field(default=None, description="Step that failed") + recoverable: bool = Field(default=False, description="Whether error is recoverable") diff --git a/backend/app/services/event_bus.py b/backend/app/services/event_bus.py new file mode 100644 index 0000000..65fdec6 --- /dev/null +++ b/backend/app/services/event_bus.py @@ -0,0 +1,622 @@ +""" +EventBus service for Redis Pub/Sub communication. + +This module provides a centralized event bus for publishing and subscribing to +events across the Syndarix platform. It uses Redis Pub/Sub for real-time +message delivery between services, agents, and the frontend. + +Architecture: +- Publishers emit events to project/agent-specific Redis channels +- SSE endpoints subscribe to channels and stream events to clients +- Events include metadata for reconnection support (Last-Event-ID) +- Events are typed with the EventType enum for consistency + +Usage: + # Publishing events + event_bus = EventBus() + await event_bus.connect() + + event = event_bus.create_event( + event_type=EventType.AGENT_MESSAGE, + project_id=project_id, + actor_type="agent", + payload={"message": "Processing..."} + ) + await event_bus.publish(event_bus.get_project_channel(project_id), event) + + # Subscribing to events + async for event in event_bus.subscribe(["project:123", "agent:456"]): + handle_event(event) + + # Cleanup + await event_bus.disconnect() +""" + +import asyncio +import json +import logging +from collections.abc import AsyncGenerator, AsyncIterator +from contextlib import asynccontextmanager +from datetime import UTC, datetime +from typing import Any +from uuid import UUID, uuid4 + +import redis.asyncio as redis +from pydantic import ValidationError + +from app.core.config import settings +from app.schemas.events import ActorType, Event, EventType + +logger = logging.getLogger(__name__) + + +class EventBusError(Exception): + """Base exception for EventBus errors.""" + + + +class EventBusConnectionError(EventBusError): + """Raised when connection to Redis fails.""" + + + +class EventBusPublishError(EventBusError): + """Raised when publishing an event fails.""" + + + +class EventBusSubscriptionError(EventBusError): + """Raised when subscribing to channels fails.""" + + + +class EventBus: + """ + EventBus for Redis Pub/Sub communication. + + Provides methods to publish events to channels and subscribe to events + from multiple channels. Handles connection management, serialization, + and error recovery. + + This class provides: + - Event publishing to project/agent-specific channels + - Subscription management for SSE endpoints + - Reconnection support via event IDs and sequence numbers + - Keepalive messages for connection health + - Type-safe event creation with the Event schema + + Attributes: + redis_url: Redis connection URL + redis_client: Async Redis client instance + pubsub: Redis PubSub instance for subscriptions + """ + + # Channel prefixes for different entity types + PROJECT_CHANNEL_PREFIX = "project" + AGENT_CHANNEL_PREFIX = "agent" + USER_CHANNEL_PREFIX = "user" + GLOBAL_CHANNEL = "syndarix:global" + + def __init__(self, redis_url: str | None = None) -> None: + """ + Initialize the EventBus. + + Args: + redis_url: Redis connection URL. Defaults to settings.REDIS_URL. + """ + self.redis_url = redis_url or settings.REDIS_URL + self._redis_client: redis.Redis | None = None + self._pubsub: redis.client.PubSub | None = None + self._connected = False + self._sequence_counters: dict[str, int] = {} + + @property + def redis_client(self) -> redis.Redis: + """Get the Redis client, raising if not connected.""" + if self._redis_client is None: + raise EventBusConnectionError( + "EventBus not connected. Call connect() first." + ) + return self._redis_client + + @property + def pubsub(self) -> redis.client.PubSub: + """Get the PubSub instance, raising if not connected.""" + if self._pubsub is None: + raise EventBusConnectionError( + "EventBus not connected. Call connect() first." + ) + return self._pubsub + + @property + def is_connected(self) -> bool: + """Check if the EventBus is connected to Redis.""" + return self._connected and self._redis_client is not None + + async def connect(self) -> None: + """ + Connect to Redis and initialize the PubSub client. + + Raises: + EventBusConnectionError: If connection to Redis fails. + """ + if self._connected: + logger.debug("EventBus already connected") + return + + try: + self._redis_client = redis.from_url( + self.redis_url, + encoding="utf-8", + decode_responses=True, + ) + # Test connection - ping() returns a coroutine for async Redis + ping_result = self._redis_client.ping() + if hasattr(ping_result, "__await__"): + await ping_result + self._pubsub = self._redis_client.pubsub() + self._connected = True + logger.info("EventBus connected to Redis") + except redis.ConnectionError as e: + logger.error(f"Failed to connect to Redis: {e}", exc_info=True) + raise EventBusConnectionError(f"Failed to connect to Redis: {e}") from e + except redis.RedisError as e: + logger.error(f"Redis error during connection: {e}", exc_info=True) + raise EventBusConnectionError(f"Redis error: {e}") from e + + async def disconnect(self) -> None: + """ + Disconnect from Redis and cleanup resources. + """ + if self._pubsub: + try: + await self._pubsub.unsubscribe() + await self._pubsub.close() + except redis.RedisError as e: + logger.warning(f"Error closing PubSub: {e}") + finally: + self._pubsub = None + + if self._redis_client: + try: + await self._redis_client.aclose() + except redis.RedisError as e: + logger.warning(f"Error closing Redis client: {e}") + finally: + self._redis_client = None + + self._connected = False + logger.info("EventBus disconnected from Redis") + + @asynccontextmanager + async def connection(self) -> AsyncIterator["EventBus"]: + """ + Context manager for automatic connection handling. + + Usage: + async with event_bus.connection() as bus: + await bus.publish(channel, event) + """ + await self.connect() + try: + yield self + finally: + await self.disconnect() + + def get_project_channel(self, project_id: UUID | str) -> str: + """ + Get the channel name for a project. + + Args: + project_id: The project UUID or string + + Returns: + Channel name string in format "project:{uuid}" + """ + return f"{self.PROJECT_CHANNEL_PREFIX}:{project_id}" + + def get_agent_channel(self, agent_id: UUID | str) -> str: + """ + Get the channel name for an agent instance. + + Args: + agent_id: The agent instance UUID or string + + Returns: + Channel name string in format "agent:{uuid}" + """ + return f"{self.AGENT_CHANNEL_PREFIX}:{agent_id}" + + def get_user_channel(self, user_id: UUID | str) -> str: + """ + Get the channel name for a user (personal notifications). + + Args: + user_id: The user UUID or string + + Returns: + Channel name string in format "user:{uuid}" + """ + return f"{self.USER_CHANNEL_PREFIX}:{user_id}" + + def _get_next_sequence(self, channel: str) -> int: + """Get the next sequence number for a channel's events.""" + current = self._sequence_counters.get(channel, 0) + self._sequence_counters[channel] = current + 1 + return current + 1 + + @staticmethod + def create_event( + event_type: EventType, + project_id: UUID, + actor_type: ActorType, + payload: dict | None = None, + actor_id: UUID | None = None, + event_id: str | None = None, + timestamp: datetime | None = None, + ) -> Event: + """ + Factory method to create a new Event. + + Args: + event_type: The type of event + project_id: The project this event belongs to + actor_type: Type of actor ('agent', 'user', or 'system') + payload: Event-specific payload data + actor_id: ID of the agent or user who triggered the event + event_id: Optional custom event ID (UUID string) + timestamp: Optional custom timestamp (defaults to now UTC) + + Returns: + A new Event instance + """ + return Event( + id=event_id or str(uuid4()), + type=event_type, + timestamp=timestamp or datetime.now(UTC), + project_id=project_id, + actor_id=actor_id, + actor_type=actor_type, + payload=payload or {}, + ) + + def _serialize_event(self, event: Event) -> str: + """ + Serialize an event to JSON string. + + Args: + event: The Event to serialize + + Returns: + JSON string representation of the event + """ + return event.model_dump_json() + + def _deserialize_event(self, data: str) -> Event: + """ + Deserialize a JSON string to an Event. + + Args: + data: JSON string to deserialize + + Returns: + Deserialized Event instance + + Raises: + ValidationError: If the data doesn't match the Event schema + """ + return Event.model_validate_json(data) + + async def publish(self, channel: str, event: Event) -> int: + """ + Publish an event to a channel. + + Args: + channel: The channel name to publish to + event: The Event to publish + + Returns: + Number of subscribers that received the message + + Raises: + EventBusConnectionError: If not connected to Redis + EventBusPublishError: If publishing fails + """ + if not self.is_connected: + raise EventBusConnectionError("EventBus not connected") + + try: + message = self._serialize_event(event) + subscriber_count = await self.redis_client.publish(channel, message) + logger.debug( + f"Published event {event.type} to {channel} " + f"(received by {subscriber_count} subscribers)" + ) + return subscriber_count + except redis.RedisError as e: + logger.error(f"Failed to publish event to {channel}: {e}", exc_info=True) + raise EventBusPublishError(f"Failed to publish event: {e}") from e + + async def publish_to_project(self, event: Event) -> int: + """ + Publish an event to the project's channel. + + Convenience method that publishes to the project channel based on + the event's project_id. + + Args: + event: The Event to publish (must have project_id set) + + Returns: + Number of subscribers that received the message + """ + channel = self.get_project_channel(event.project_id) + return await self.publish(channel, event) + + async def publish_multi(self, channels: list[str], event: Event) -> dict[str, int]: + """ + Publish an event to multiple channels. + + Args: + channels: List of channel names to publish to + event: The Event to publish + + Returns: + Dictionary mapping channel names to subscriber counts + """ + results = {} + for channel in channels: + try: + results[channel] = await self.publish(channel, event) + except EventBusPublishError as e: + logger.warning(f"Failed to publish to {channel}: {e}") + results[channel] = 0 + return results + + async def subscribe( + self, channels: list[str], *, max_wait: float | None = None + ) -> AsyncIterator[Event]: + """ + Subscribe to one or more channels and yield events. + + This is an async generator that yields Event objects as they arrive. + Use max_wait to limit how long to wait for messages. + + Args: + channels: List of channel names to subscribe to + max_wait: Optional maximum wait time in seconds for each message. + If None, waits indefinitely. + + Yields: + Event objects received from subscribed channels + + Raises: + EventBusConnectionError: If not connected to Redis + EventBusSubscriptionError: If subscription fails + + Example: + async for event in event_bus.subscribe(["project:123"], max_wait=30): + print(f"Received: {event.type}") + """ + if not self.is_connected: + raise EventBusConnectionError("EventBus not connected") + + # Create a new pubsub for this subscription + subscription_pubsub = self.redis_client.pubsub() + + try: + await subscription_pubsub.subscribe(*channels) + logger.info(f"Subscribed to channels: {channels}") + except redis.RedisError as e: + logger.error(f"Failed to subscribe to channels: {e}", exc_info=True) + await subscription_pubsub.close() + raise EventBusSubscriptionError(f"Failed to subscribe: {e}") from e + + try: + while True: + try: + if max_wait is not None: + async with asyncio.timeout(max_wait): + message = await subscription_pubsub.get_message( + ignore_subscribe_messages=True, timeout=1.0 + ) + else: + message = await subscription_pubsub.get_message( + ignore_subscribe_messages=True, timeout=1.0 + ) + except TimeoutError: + # Timeout reached, stop iteration + return + + if message is None: + continue + + if message["type"] == "message": + try: + event = self._deserialize_event(message["data"]) + yield event + except ValidationError as e: + logger.warning( + f"Invalid event data received: {e}", + extra={"channel": message.get("channel")}, + ) + continue + except json.JSONDecodeError as e: + logger.warning( + f"Failed to decode event JSON: {e}", + extra={"channel": message.get("channel")}, + ) + continue + finally: + try: + await subscription_pubsub.unsubscribe(*channels) + await subscription_pubsub.close() + logger.debug(f"Unsubscribed from channels: {channels}") + except redis.RedisError as e: + logger.warning(f"Error unsubscribing from channels: {e}") + + async def subscribe_sse( + self, + project_id: str | UUID, + last_event_id: str | None = None, + keepalive_interval: int = 30, + ) -> AsyncGenerator[str, None]: + """ + Subscribe to events for a project in SSE format. + + This is an async generator that yields SSE-formatted event strings. + It includes keepalive messages at the specified interval. + + Args: + project_id: The project to subscribe to + last_event_id: Optional last received event ID for reconnection + keepalive_interval: Seconds between keepalive messages (default 30) + + Yields: + SSE-formatted event strings (ready to send to client) + """ + if not self.is_connected: + raise EventBusConnectionError("EventBus not connected") + + project_id_str = str(project_id) + channel = self.get_project_channel(project_id_str) + + subscription_pubsub = self.redis_client.pubsub() + await subscription_pubsub.subscribe(channel) + + logger.info( + f"Subscribed to SSE events for project {project_id_str} " + f"(last_event_id={last_event_id})" + ) + + try: + while True: + try: + # Wait for messages with a timeout for keepalive + message = await asyncio.wait_for( + subscription_pubsub.get_message(ignore_subscribe_messages=True), + timeout=keepalive_interval, + ) + + if message is not None and message["type"] == "message": + event_data = message["data"] + + # If reconnecting, check if we should skip this event + if last_event_id: + try: + event_dict = json.loads(event_data) + if event_dict.get("id") == last_event_id: + # Found the last event, start yielding from next + last_event_id = None + continue + except json.JSONDecodeError: + pass + + yield event_data + + except TimeoutError: + # Send keepalive comment + yield "" # Empty string signals keepalive + + except asyncio.CancelledError: + logger.info(f"SSE subscription cancelled for project {project_id_str}") + raise + finally: + await subscription_pubsub.unsubscribe(channel) + await subscription_pubsub.close() + logger.info(f"Unsubscribed SSE from project {project_id_str}") + + async def subscribe_with_callback( + self, + channels: list[str], + callback: Any, # Callable[[Event], Awaitable[None]] + stop_event: asyncio.Event | None = None, + ) -> None: + """ + Subscribe to channels and process events with a callback. + + This method runs until stop_event is set or an unrecoverable error occurs. + + Args: + channels: List of channel names to subscribe to + callback: Async function to call for each event + stop_event: Optional asyncio.Event to signal stop + + Example: + async def handle_event(event: Event): + print(f"Handling: {event.type}") + + stop = asyncio.Event() + asyncio.create_task( + event_bus.subscribe_with_callback(["project:123"], handle_event, stop) + ) + # Later... + stop.set() + """ + if stop_event is None: + stop_event = asyncio.Event() + + try: + async for event in self.subscribe(channels): + if stop_event.is_set(): + break + try: + await callback(event) + except Exception as e: + logger.error(f"Error in event callback: {e}", exc_info=True) + except EventBusSubscriptionError: + raise + except Exception as e: + logger.error(f"Unexpected error in subscription loop: {e}", exc_info=True) + raise + + +# Singleton instance for application-wide use +_event_bus: EventBus | None = None + + +def get_event_bus() -> EventBus: + """ + Get the singleton EventBus instance. + + Creates a new instance if one doesn't exist. Note that you still need + to call connect() before using the EventBus. + + Returns: + The singleton EventBus instance + """ + global _event_bus + if _event_bus is None: + _event_bus = EventBus() + return _event_bus + + +async def get_connected_event_bus() -> EventBus: + """ + Get a connected EventBus instance. + + Ensures the EventBus is connected before returning. For use in + FastAPI dependency injection. + + Returns: + A connected EventBus instance + + Raises: + EventBusConnectionError: If connection fails + """ + event_bus = get_event_bus() + if not event_bus.is_connected: + await event_bus.connect() + return event_bus + + +async def close_event_bus() -> None: + """ + Close the global EventBus instance. + + Should be called during application shutdown. + """ + global _event_bus + if _event_bus is not None: + await _event_bus.disconnect() + _event_bus = None diff --git a/backend/tests/services/test_event_bus.py b/backend/tests/services/test_event_bus.py new file mode 100644 index 0000000..9b35525 --- /dev/null +++ b/backend/tests/services/test_event_bus.py @@ -0,0 +1,1035 @@ +# tests/services/test_event_bus.py +""" +Tests for the EventBus service. + +These tests verify: +- Event creation and serialization +- Publishing events to channels +- Subscribing to channels and receiving events +- Channel isolation (events only go to intended channels) +- Error handling for connection failures +- Multiple subscriptions and concurrent publishing +""" + +import asyncio +import json +from datetime import UTC, datetime +from unittest.mock import AsyncMock, MagicMock, patch +from uuid import uuid4 + +import pytest +from pydantic import ValidationError + +from app.schemas.events import ( + AgentMessagePayload, + AgentSpawnedPayload, + Event, + EventType, +) +from app.services.event_bus import ( + EventBus, + EventBusConnectionError, + EventBusPublishError, + close_event_bus, + get_connected_event_bus, + get_event_bus, +) + + +class TestEventSchema: + """Tests for Event schema and payload schemas.""" + + def test_create_event_with_all_fields(self): + """Test creating an Event with all fields populated.""" + project_id = uuid4() + actor_id = uuid4() + timestamp = datetime.now(UTC) + + event = Event( + id="test-event-id", + type=EventType.AGENT_MESSAGE, + timestamp=timestamp, + project_id=project_id, + actor_id=actor_id, + actor_type="agent", + payload={"message": "Hello, world!"}, + ) + + assert event.id == "test-event-id" + assert event.type == EventType.AGENT_MESSAGE + assert event.timestamp == timestamp + assert event.project_id == project_id + assert event.actor_id == actor_id + assert event.actor_type == "agent" + assert event.payload == {"message": "Hello, world!"} + + def test_create_event_with_minimal_fields(self): + """Test creating an Event with only required fields.""" + project_id = uuid4() + + event = Event( + id="test-event-id", + type=EventType.ISSUE_CREATED, + timestamp=datetime.now(UTC), + project_id=project_id, + actor_type="system", + ) + + assert event.id == "test-event-id" + assert event.actor_id is None + assert event.payload == {} + + def test_event_serialization_json(self): + """Test Event serializes to JSON correctly.""" + project_id = uuid4() + event = Event( + id="test-event-id", + type=EventType.SPRINT_STARTED, + timestamp=datetime(2024, 1, 15, 10, 30, 0, tzinfo=UTC), + project_id=project_id, + actor_type="user", + payload={"sprint_name": "Sprint 1"}, + ) + + json_str = event.model_dump_json() + data = json.loads(json_str) + + assert data["id"] == "test-event-id" + assert data["type"] == "sprint.started" + assert data["actor_type"] == "user" + assert data["payload"]["sprint_name"] == "Sprint 1" + + def test_event_deserialization_json(self): + """Test Event deserializes from JSON correctly.""" + project_id = uuid4() + json_data = { + "id": "test-event-id", + "type": "agent.spawned", + "timestamp": "2024-01-15T10:30:00Z", + "project_id": str(project_id), + "actor_id": None, + "actor_type": "system", + "payload": {"agent_name": "PO Agent"}, + } + + event = Event.model_validate(json_data) + + assert event.id == "test-event-id" + assert event.type == EventType.AGENT_SPAWNED + assert event.project_id == project_id + assert event.payload["agent_name"] == "PO Agent" + + def test_event_type_enum_values(self): + """Test all EventType enum values are accessible.""" + # Agent events + assert EventType.AGENT_SPAWNED.value == "agent.spawned" + assert EventType.AGENT_STATUS_CHANGED.value == "agent.status_changed" + assert EventType.AGENT_MESSAGE.value == "agent.message" + assert EventType.AGENT_TERMINATED.value == "agent.terminated" + + # Issue events + assert EventType.ISSUE_CREATED.value == "issue.created" + assert EventType.ISSUE_UPDATED.value == "issue.updated" + assert EventType.ISSUE_ASSIGNED.value == "issue.assigned" + assert EventType.ISSUE_CLOSED.value == "issue.closed" + + # Sprint events + assert EventType.SPRINT_STARTED.value == "sprint.started" + assert EventType.SPRINT_COMPLETED.value == "sprint.completed" + + # Approval events + assert EventType.APPROVAL_REQUESTED.value == "approval.requested" + assert EventType.APPROVAL_GRANTED.value == "approval.granted" + assert EventType.APPROVAL_DENIED.value == "approval.denied" + + def test_actor_type_literal(self): + """Test ActorType accepts valid values.""" + project_id = uuid4() + + # Valid actor types + for actor_type in ["agent", "user", "system"]: + event = Event( + id="test", + type=EventType.AGENT_MESSAGE, + timestamp=datetime.now(UTC), + project_id=project_id, + actor_type=actor_type, + ) + assert event.actor_type == actor_type + + def test_invalid_actor_type_rejected(self): + """Test invalid actor type is rejected.""" + project_id = uuid4() + + with pytest.raises(ValidationError): + Event( + id="test", + type=EventType.AGENT_MESSAGE, + timestamp=datetime.now(UTC), + project_id=project_id, + actor_type="invalid", # Invalid actor type + ) + + +class TestAgentPayloadSchemas: + """Tests for agent-specific payload schemas.""" + + def test_agent_spawned_payload(self): + """Test AgentSpawnedPayload schema.""" + payload = AgentSpawnedPayload( + agent_instance_id=uuid4(), + agent_type_id=uuid4(), + agent_name="Product Owner Agent", + role="product_owner", + ) + + assert payload.agent_name == "Product Owner Agent" + assert payload.role == "product_owner" + + def test_agent_message_payload(self): + """Test AgentMessagePayload schema.""" + payload = AgentMessagePayload( + agent_instance_id=uuid4(), + message="Processing requirements...", + message_type="info", + metadata={"tokens_used": 150}, + ) + + assert payload.message == "Processing requirements..." + assert payload.message_type == "info" + assert payload.metadata["tokens_used"] == 150 + + def test_agent_message_payload_defaults(self): + """Test AgentMessagePayload has correct defaults.""" + payload = AgentMessagePayload( + agent_instance_id=uuid4(), + message="Test message", + ) + + assert payload.message_type == "info" + assert payload.metadata == {} + + +class TestEventBusChannels: + """Tests for EventBus channel helper methods.""" + + def test_get_project_channel_with_uuid(self): + """Test get_project_channel with UUID.""" + event_bus = EventBus() + project_id = uuid4() + + channel = event_bus.get_project_channel(project_id) + + assert channel == f"project:{project_id}" + + def test_get_project_channel_with_string(self): + """Test get_project_channel with string.""" + event_bus = EventBus() + + channel = event_bus.get_project_channel("test-project-123") + + assert channel == "project:test-project-123" + + def test_get_agent_channel_with_uuid(self): + """Test get_agent_channel with UUID.""" + event_bus = EventBus() + agent_id = uuid4() + + channel = event_bus.get_agent_channel(agent_id) + + assert channel == f"agent:{agent_id}" + + def test_get_user_channel_with_uuid(self): + """Test get_user_channel with UUID.""" + event_bus = EventBus() + user_id = uuid4() + + channel = event_bus.get_user_channel(user_id) + + assert channel == f"user:{user_id}" + + def test_channel_prefixes(self): + """Test channel prefix constants.""" + assert EventBus.PROJECT_CHANNEL_PREFIX == "project" + assert EventBus.AGENT_CHANNEL_PREFIX == "agent" + assert EventBus.USER_CHANNEL_PREFIX == "user" + assert EventBus.GLOBAL_CHANNEL == "syndarix:global" + + +class TestEventBusCreateEvent: + """Tests for EventBus.create_event factory method.""" + + def test_create_event_with_required_fields(self): + """Test create_event with only required fields.""" + project_id = uuid4() + + event = EventBus.create_event( + event_type=EventType.AGENT_MESSAGE, + project_id=project_id, + actor_type="agent", + ) + + assert event.type == EventType.AGENT_MESSAGE + assert event.project_id == project_id + assert event.actor_type == "agent" + assert event.actor_id is None + assert event.payload == {} + # Auto-generated fields + assert event.id is not None + assert event.timestamp is not None + + def test_create_event_with_all_fields(self): + """Test create_event with all fields.""" + project_id = uuid4() + actor_id = uuid4() + timestamp = datetime(2024, 1, 15, 10, 0, 0, tzinfo=UTC) + + event = EventBus.create_event( + event_type=EventType.APPROVAL_REQUESTED, + project_id=project_id, + actor_type="agent", + payload={"approval_type": "code_review"}, + actor_id=actor_id, + event_id="custom-event-id", + timestamp=timestamp, + ) + + assert event.id == "custom-event-id" + assert event.type == EventType.APPROVAL_REQUESTED + assert event.timestamp == timestamp + assert event.project_id == project_id + assert event.actor_id == actor_id + assert event.actor_type == "agent" + assert event.payload == {"approval_type": "code_review"} + + def test_create_event_generates_unique_ids(self): + """Test create_event generates unique IDs.""" + project_id = uuid4() + + event1 = EventBus.create_event( + event_type=EventType.AGENT_MESSAGE, + project_id=project_id, + actor_type="system", + ) + event2 = EventBus.create_event( + event_type=EventType.AGENT_MESSAGE, + project_id=project_id, + actor_type="system", + ) + + assert event1.id != event2.id + + +class TestEventBusConnection: + """Tests for EventBus connection management.""" + + def test_initial_state(self): + """Test EventBus starts disconnected.""" + event_bus = EventBus() + + assert not event_bus.is_connected + assert event_bus._redis_client is None + assert event_bus._pubsub is None + + def test_redis_client_raises_when_disconnected(self): + """Test accessing redis_client raises when not connected.""" + event_bus = EventBus() + + with pytest.raises(EventBusConnectionError) as exc_info: + _ = event_bus.redis_client + + assert "not connected" in str(exc_info.value).lower() + + def test_pubsub_raises_when_disconnected(self): + """Test accessing pubsub raises when not connected.""" + event_bus = EventBus() + + with pytest.raises(EventBusConnectionError) as exc_info: + _ = event_bus.pubsub + + assert "not connected" in str(exc_info.value).lower() + + @pytest.mark.asyncio + async def test_connect_success(self): + """Test successful Redis connection.""" + event_bus = EventBus() + + mock_redis = AsyncMock() + mock_redis.ping = AsyncMock(return_value=True) + mock_redis.pubsub = MagicMock(return_value=AsyncMock()) + + with patch("redis.asyncio.from_url", return_value=mock_redis): + await event_bus.connect() + + assert event_bus.is_connected + mock_redis.ping.assert_called_once() + + @pytest.mark.asyncio + async def test_connect_already_connected(self): + """Test connect when already connected is idempotent.""" + event_bus = EventBus() + event_bus._connected = True + event_bus._redis_client = AsyncMock() + + # Should not raise or create new connection + await event_bus.connect() + + assert event_bus.is_connected + + @pytest.mark.asyncio + async def test_connect_failure(self): + """Test connection failure raises appropriate error.""" + import redis.asyncio as redis_async + + event_bus = EventBus() + + with patch( + "redis.asyncio.from_url", + side_effect=redis_async.ConnectionError("Connection refused"), + ): + with pytest.raises(EventBusConnectionError) as exc_info: + await event_bus.connect() + + assert "Connection refused" in str(exc_info.value) + assert not event_bus.is_connected + + @pytest.mark.asyncio + async def test_disconnect(self): + """Test disconnection cleans up resources.""" + event_bus = EventBus() + event_bus._connected = True + event_bus._redis_client = AsyncMock() + event_bus._redis_client.aclose = AsyncMock() + event_bus._pubsub = AsyncMock() + event_bus._pubsub.unsubscribe = AsyncMock() + event_bus._pubsub.close = AsyncMock() + + await event_bus.disconnect() + + assert not event_bus.is_connected + assert event_bus._redis_client is None + assert event_bus._pubsub is None + + @pytest.mark.asyncio + async def test_connection_context_manager(self): + """Test connection context manager.""" + event_bus = EventBus() + + mock_redis = AsyncMock() + mock_redis.ping = AsyncMock(return_value=True) + mock_redis.pubsub = MagicMock(return_value=AsyncMock()) + mock_redis.aclose = AsyncMock() + + with patch("redis.asyncio.from_url", return_value=mock_redis): + async with event_bus.connection() as bus: + assert bus.is_connected + + # After context, should be disconnected + assert not event_bus.is_connected + + +class TestEventBusPublish: + """Tests for EventBus publish functionality.""" + + @pytest.mark.asyncio + async def test_publish_requires_connection(self): + """Test publish raises when not connected.""" + event_bus = EventBus() + project_id = uuid4() + event = EventBus.create_event( + event_type=EventType.AGENT_MESSAGE, + project_id=project_id, + actor_type="system", + ) + + with pytest.raises(EventBusConnectionError): + await event_bus.publish("test:channel", event) + + @pytest.mark.asyncio + async def test_publish_success(self): + """Test successful event publishing.""" + event_bus = EventBus() + event_bus._connected = True + + mock_redis = AsyncMock() + mock_redis.publish = AsyncMock(return_value=2) # 2 subscribers + event_bus._redis_client = mock_redis + + project_id = uuid4() + event = EventBus.create_event( + event_type=EventType.AGENT_MESSAGE, + project_id=project_id, + actor_type="agent", + payload={"message": "Test message"}, + ) + + subscriber_count = await event_bus.publish("project:test", event) + + assert subscriber_count == 2 + mock_redis.publish.assert_called_once() + + # Verify the published message is valid JSON + call_args = mock_redis.publish.call_args + published_data = call_args[0][1] + parsed = json.loads(published_data) + assert parsed["type"] == "agent.message" + assert parsed["payload"]["message"] == "Test message" + + @pytest.mark.asyncio + async def test_publish_to_project(self): + """Test publish_to_project convenience method.""" + event_bus = EventBus() + event_bus._connected = True + + mock_redis = AsyncMock() + mock_redis.publish = AsyncMock(return_value=1) + event_bus._redis_client = mock_redis + + project_id = uuid4() + event = EventBus.create_event( + event_type=EventType.ISSUE_CREATED, + project_id=project_id, + actor_type="user", + ) + + await event_bus.publish_to_project(event) + + # Verify published to correct channel + call_args = mock_redis.publish.call_args + channel = call_args[0][0] + assert channel == f"project:{project_id}" + + @pytest.mark.asyncio + async def test_publish_multi(self): + """Test publishing to multiple channels.""" + event_bus = EventBus() + event_bus._connected = True + + mock_redis = AsyncMock() + mock_redis.publish = AsyncMock(return_value=1) + event_bus._redis_client = mock_redis + + project_id = uuid4() + event = EventBus.create_event( + event_type=EventType.AGENT_MESSAGE, + project_id=project_id, + actor_type="agent", + ) + + channels = ["project:1", "agent:2", "user:3"] + results = await event_bus.publish_multi(channels, event) + + assert len(results) == 3 + assert all(count == 1 for count in results.values()) + assert mock_redis.publish.call_count == 3 + + @pytest.mark.asyncio + async def test_publish_redis_error(self): + """Test publish handles Redis errors.""" + import redis.asyncio as redis_async + + event_bus = EventBus() + event_bus._connected = True + + mock_redis = AsyncMock() + mock_redis.publish = AsyncMock( + side_effect=redis_async.RedisError("Publish failed") + ) + event_bus._redis_client = mock_redis + + project_id = uuid4() + event = EventBus.create_event( + event_type=EventType.AGENT_MESSAGE, + project_id=project_id, + actor_type="system", + ) + + with pytest.raises(EventBusPublishError) as exc_info: + await event_bus.publish("test:channel", event) + + assert "Publish failed" in str(exc_info.value) + + +class TestEventBusSubscribe: + """Tests for EventBus subscribe functionality.""" + + @pytest.mark.asyncio + async def test_subscribe_requires_connection(self): + """Test subscribe raises when not connected.""" + event_bus = EventBus() + + with pytest.raises(EventBusConnectionError): + async for _ in event_bus.subscribe(["test:channel"]): + pass + + @pytest.mark.asyncio + async def test_subscribe_receives_events(self): + """Test subscribing to channels receives events.""" + event_bus = EventBus() + event_bus._connected = True + + project_id = uuid4() + test_event = EventBus.create_event( + event_type=EventType.AGENT_MESSAGE, + project_id=project_id, + actor_type="agent", + payload={"message": "Test"}, + ) + serialized_event = test_event.model_dump_json() + + # Create mock pubsub that returns one message then times out + mock_pubsub = AsyncMock() + message_queue = [ + {"type": "message", "data": serialized_event, "channel": "test:channel"}, + ] + call_count = 0 + + async def get_message_side_effect(**kwargs): + nonlocal call_count + if call_count < len(message_queue): + result = message_queue[call_count] + call_count += 1 + return result + # After messages exhausted, raise TimeoutError to end subscription + raise TimeoutError() + + mock_pubsub.get_message = get_message_side_effect + mock_pubsub.subscribe = AsyncMock() + mock_pubsub.unsubscribe = AsyncMock() + mock_pubsub.close = AsyncMock() + + mock_redis = AsyncMock() + mock_redis.pubsub = MagicMock(return_value=mock_pubsub) + event_bus._redis_client = mock_redis + + received_events = [] + # Use max_wait to ensure test doesn't hang + async for event in event_bus.subscribe(["test:channel"], max_wait=0.5): + received_events.append(event) + + assert len(received_events) == 1 + assert received_events[0].type == EventType.AGENT_MESSAGE + assert received_events[0].payload["message"] == "Test" + + @pytest.mark.asyncio + async def test_subscribe_handles_invalid_json(self): + """Test subscribe skips invalid JSON messages.""" + event_bus = EventBus() + event_bus._connected = True + + # Create mock pubsub with invalid JSON + mock_pubsub = AsyncMock() + message_queue = [ + {"type": "message", "data": "not valid json", "channel": "test"}, + ] + call_count = 0 + + async def get_message_side_effect(**kwargs): + nonlocal call_count + if call_count < len(message_queue): + result = message_queue[call_count] + call_count += 1 + return result + raise TimeoutError() + + mock_pubsub.get_message = get_message_side_effect + mock_pubsub.subscribe = AsyncMock() + mock_pubsub.unsubscribe = AsyncMock() + mock_pubsub.close = AsyncMock() + + mock_redis = AsyncMock() + mock_redis.pubsub = MagicMock(return_value=mock_pubsub) + event_bus._redis_client = mock_redis + + received_events = [] + async for event in event_bus.subscribe(["test:channel"], max_wait=0.5): + received_events.append(event) + + # Should receive no events (invalid JSON was skipped) + assert len(received_events) == 0 + + @pytest.mark.asyncio + async def test_subscribe_handles_invalid_event_schema(self): + """Test subscribe skips messages with invalid Event schema.""" + event_bus = EventBus() + event_bus._connected = True + + # Valid JSON but invalid Event schema (missing required fields) + invalid_event_json = json.dumps({"type": "test", "data": "incomplete"}) + + mock_pubsub = AsyncMock() + message_queue = [ + {"type": "message", "data": invalid_event_json, "channel": "test"}, + ] + call_count = 0 + + async def get_message_side_effect(**kwargs): + nonlocal call_count + if call_count < len(message_queue): + result = message_queue[call_count] + call_count += 1 + return result + raise TimeoutError() + + mock_pubsub.get_message = get_message_side_effect + mock_pubsub.subscribe = AsyncMock() + mock_pubsub.unsubscribe = AsyncMock() + mock_pubsub.close = AsyncMock() + + mock_redis = AsyncMock() + mock_redis.pubsub = MagicMock(return_value=mock_pubsub) + event_bus._redis_client = mock_redis + + received_events = [] + async for event in event_bus.subscribe(["test:channel"], max_wait=0.5): + received_events.append(event) + + # Should receive no events (invalid schema was skipped) + assert len(received_events) == 0 + + @pytest.mark.asyncio + async def test_subscribe_max_wait(self): + """Test subscribe respects max_wait.""" + event_bus = EventBus() + event_bus._connected = True + + mock_pubsub = AsyncMock() + + # Simulate slow get_message that will trigger timeout + async def slow_get_message(**kwargs): + # Sleep longer than max_wait to trigger timeout + await asyncio.sleep(1.0) + return None + + mock_pubsub.get_message = slow_get_message + mock_pubsub.subscribe = AsyncMock() + mock_pubsub.unsubscribe = AsyncMock() + mock_pubsub.close = AsyncMock() + + mock_redis = AsyncMock() + mock_redis.pubsub = MagicMock(return_value=mock_pubsub) + event_bus._redis_client = mock_redis + + received_events = [] + start_time = asyncio.get_running_loop().time() + async for event in event_bus.subscribe(["test:channel"], max_wait=0.2): + received_events.append(event) + elapsed = asyncio.get_running_loop().time() - start_time + + assert len(received_events) == 0 + # Should have timed out around 0.2 seconds + assert elapsed < 1.0 + + +class TestEventBusChannelIsolation: + """Tests for channel isolation in EventBus.""" + + @pytest.mark.asyncio + async def test_events_only_received_on_subscribed_channel(self): + """Test events are only received on subscribed channels.""" + event_bus = EventBus() + event_bus._connected = True + + project1_id = uuid4() + project2_id = uuid4() # noqa: F841 + + # Event for project 1 + event1 = EventBus.create_event( + event_type=EventType.AGENT_MESSAGE, + project_id=project1_id, + actor_type="agent", + payload={"message": "For project 1"}, + ) + serialized_event1 = event1.model_dump_json() + + # Mock pubsub only returns event1 for project1 channel + mock_pubsub = AsyncMock() + message_queue = [ + { + "type": "message", + "data": serialized_event1, + "channel": f"project:{project1_id}", + }, + ] + call_count = 0 + + async def get_message_side_effect(**kwargs): + nonlocal call_count + if call_count < len(message_queue): + result = message_queue[call_count] + call_count += 1 + return result + raise TimeoutError() + + mock_pubsub.get_message = get_message_side_effect + mock_pubsub.subscribe = AsyncMock() + mock_pubsub.unsubscribe = AsyncMock() + mock_pubsub.close = AsyncMock() + + mock_redis = AsyncMock() + mock_redis.pubsub = MagicMock(return_value=mock_pubsub) + event_bus._redis_client = mock_redis + + # Subscribe to project1 channel + received = [] + channel = event_bus.get_project_channel(project1_id) + async for event in event_bus.subscribe([channel], max_wait=0.5): + received.append(event) + + assert len(received) == 1 + assert received[0].project_id == project1_id + + def test_channel_names_are_unique(self): + """Test different entity types have unique channel names.""" + event_bus = EventBus() + same_id = uuid4() + + project_channel = event_bus.get_project_channel(same_id) + agent_channel = event_bus.get_agent_channel(same_id) + user_channel = event_bus.get_user_channel(same_id) + + # All channels should be different even with same ID + assert project_channel != agent_channel + assert agent_channel != user_channel + assert project_channel != user_channel + + +class TestEventBusSingleton: + """Tests for EventBus singleton management.""" + + @pytest.mark.asyncio + async def test_get_event_bus_returns_singleton(self): + """Test get_event_bus returns same instance.""" + # Reset singleton + import app.services.event_bus as event_bus_module + + event_bus_module._event_bus = None + + bus1 = get_event_bus() + bus2 = get_event_bus() + + assert bus1 is bus2 + + # Cleanup + event_bus_module._event_bus = None + + @pytest.mark.asyncio + async def test_get_connected_event_bus(self): + """Test get_connected_event_bus returns connected instance.""" + import app.services.event_bus as event_bus_module + + event_bus_module._event_bus = None + + mock_redis = AsyncMock() + mock_redis.ping = AsyncMock(return_value=True) + mock_redis.pubsub = MagicMock(return_value=AsyncMock()) + + with patch("redis.asyncio.from_url", return_value=mock_redis): + bus = await get_connected_event_bus() + assert bus.is_connected + + # Cleanup + await close_event_bus() + + @pytest.mark.asyncio + async def test_close_event_bus(self): + """Test close_event_bus cleans up singleton.""" + import app.services.event_bus as event_bus_module + + # Create a connected instance + event_bus_module._event_bus = EventBus() + event_bus_module._event_bus._connected = True + event_bus_module._event_bus._redis_client = AsyncMock() + event_bus_module._event_bus._redis_client.aclose = AsyncMock() + event_bus_module._event_bus._pubsub = AsyncMock() + event_bus_module._event_bus._pubsub.unsubscribe = AsyncMock() + event_bus_module._event_bus._pubsub.close = AsyncMock() + + await close_event_bus() + + assert event_bus_module._event_bus is None + + +class TestEventBusSubscribeWithCallback: + """Tests for subscribe_with_callback method.""" + + @pytest.mark.asyncio + async def test_subscribe_with_callback_processes_events(self): + """Test subscribe_with_callback calls callback for each event.""" + event_bus = EventBus() + event_bus._connected = True + + project_id = uuid4() + test_event = EventBus.create_event( + event_type=EventType.AGENT_MESSAGE, + project_id=project_id, + actor_type="agent", + ) + serialized_event = test_event.model_dump_json() + + mock_pubsub = AsyncMock() + message_queue = [ + {"type": "message", "data": serialized_event, "channel": "test"}, + ] + call_count = 0 + + async def get_message_side_effect(**kwargs): + nonlocal call_count + if call_count < len(message_queue): + result = message_queue[call_count] + call_count += 1 + return result + raise TimeoutError() + + mock_pubsub.get_message = get_message_side_effect + mock_pubsub.subscribe = AsyncMock() + mock_pubsub.unsubscribe = AsyncMock() + mock_pubsub.close = AsyncMock() + + mock_redis = AsyncMock() + mock_redis.pubsub = MagicMock(return_value=mock_pubsub) + event_bus._redis_client = mock_redis + + received_events = [] + + async def callback(event: Event): + received_events.append(event) + + # Create stop event and set it after a delay + stop_event = asyncio.Event() + + async def stop_after_delay(): + await asyncio.sleep(0.3) + stop_event.set() + + # Run subscription with callback in background + task = asyncio.create_task( + event_bus.subscribe_with_callback(["test"], callback, stop_event) + ) + stop_task = asyncio.create_task(stop_after_delay()) + + await asyncio.gather(task, stop_task, return_exceptions=True) + + assert len(received_events) == 1 + assert received_events[0].type == EventType.AGENT_MESSAGE + + @pytest.mark.asyncio + async def test_subscribe_with_callback_handles_callback_errors(self): + """Test subscribe_with_callback continues after callback errors.""" + event_bus = EventBus() + event_bus._connected = True + + project_id = uuid4() + test_event = EventBus.create_event( + event_type=EventType.AGENT_MESSAGE, + project_id=project_id, + actor_type="agent", + ) + serialized_event = test_event.model_dump_json() + + mock_pubsub = AsyncMock() + message_queue = [ + {"type": "message", "data": serialized_event, "channel": "test"}, + {"type": "message", "data": serialized_event, "channel": "test"}, + ] + call_count = 0 + + async def get_message_side_effect(**kwargs): + nonlocal call_count + if call_count < len(message_queue): + result = message_queue[call_count] + call_count += 1 + return result + raise TimeoutError() + + mock_pubsub.get_message = get_message_side_effect + mock_pubsub.subscribe = AsyncMock() + mock_pubsub.unsubscribe = AsyncMock() + mock_pubsub.close = AsyncMock() + + mock_redis = AsyncMock() + mock_redis.pubsub = MagicMock(return_value=mock_pubsub) + event_bus._redis_client = mock_redis + + call_count = 0 + + async def failing_callback(event: Event): + nonlocal call_count + call_count += 1 + if call_count == 1: + raise ValueError("First call fails") + # Second call succeeds + + stop_event = asyncio.Event() + + async def stop_after_delay(): + await asyncio.sleep(0.3) + stop_event.set() + + task = asyncio.create_task( + event_bus.subscribe_with_callback(["test"], failing_callback, stop_event) + ) + stop_task = asyncio.create_task(stop_after_delay()) + + await asyncio.gather(task, stop_task, return_exceptions=True) + + # Should have processed both events despite first callback failing + assert call_count == 2 + + +class TestEventBusSSESubscription: + """Tests for SSE subscription functionality.""" + + @pytest.mark.asyncio + async def test_subscribe_sse_yields_json_strings(self): + """Test subscribe_sse yields JSON string data.""" + event_bus = EventBus() + event_bus._connected = True + + project_id = uuid4() + test_event = EventBus.create_event( + event_type=EventType.AGENT_MESSAGE, + project_id=project_id, + actor_type="agent", + payload={"message": "SSE test"}, + ) + serialized_event = test_event.model_dump_json() + + mock_pubsub = AsyncMock() + message_queue = [ + {"type": "message", "data": serialized_event, "channel": "test"}, + ] + call_count = 0 + + async def get_message_side_effect(**kwargs): + nonlocal call_count + if call_count < len(message_queue): + result = message_queue[call_count] + call_count += 1 + return result + # Raise TimeoutError after messages are exhausted + raise TimeoutError() + + mock_pubsub.get_message = get_message_side_effect + mock_pubsub.subscribe = AsyncMock() + mock_pubsub.unsubscribe = AsyncMock() + mock_pubsub.close = AsyncMock() + + mock_redis = AsyncMock() + mock_redis.pubsub = MagicMock(return_value=mock_pubsub) + event_bus._redis_client = mock_redis + + received = [] + count = 0 + async for data in event_bus.subscribe_sse(project_id, keepalive_interval=0.1): + if data: # Skip keepalive empty strings + received.append(data) + count += 1 + if count >= 2: # Get one message and one keepalive + break + + assert len(received) == 1 + # Should be valid JSON + parsed = json.loads(received[0]) + assert parsed["type"] == "agent.message" + assert parsed["payload"]["message"] == "SSE test"