Files
syndarix/backend/tests/services/test_event_bus.py
Felipe Cardoso 3c24a8c522 feat(backend): Add EventBus service with Redis Pub/Sub
- Add EventBus class for real-time event communication
- Add Event schema with type-safe event types (agent, issue, sprint events)
- Add typed payload schemas (AgentSpawnedPayload, AgentMessagePayload)
- Add channel helpers for project/agent/user scoping
- Add subscribe_sse generator for SSE streaming
- Add reconnection support via Last-Event-ID
- Add keepalive mechanism for connection health
- Add 44 comprehensive tests with mocked Redis

Implements #33

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2025-12-30 02:07:51 +01:00

1036 lines
34 KiB
Python

# tests/services/test_event_bus.py
"""
Tests for the EventBus service.
These tests verify:
- Event creation and serialization
- Publishing events to channels
- Subscribing to channels and receiving events
- Channel isolation (events only go to intended channels)
- Error handling for connection failures
- Multiple subscriptions and concurrent publishing
"""
import asyncio
import json
from datetime import UTC, datetime
from unittest.mock import AsyncMock, MagicMock, patch
from uuid import uuid4
import pytest
from pydantic import ValidationError
from app.schemas.events import (
AgentMessagePayload,
AgentSpawnedPayload,
Event,
EventType,
)
from app.services.event_bus import (
EventBus,
EventBusConnectionError,
EventBusPublishError,
close_event_bus,
get_connected_event_bus,
get_event_bus,
)
class TestEventSchema:
    """Exercise the Event model: construction, JSON round-trips and validation."""

    def test_create_event_with_all_fields(self):
        """Test creating an Event with all fields populated."""
        proj_id, author_id = uuid4(), uuid4()
        stamp = datetime.now(UTC)

        evt = Event(
            id="test-event-id",
            type=EventType.AGENT_MESSAGE,
            timestamp=stamp,
            project_id=proj_id,
            actor_id=author_id,
            actor_type="agent",
            payload={"message": "Hello, world!"},
        )

        assert evt.id == "test-event-id"
        assert evt.type == EventType.AGENT_MESSAGE
        assert evt.timestamp == stamp
        assert evt.project_id == proj_id
        assert evt.actor_id == author_id
        assert evt.actor_type == "agent"
        assert evt.payload == {"message": "Hello, world!"}

    def test_create_event_with_minimal_fields(self):
        """Test creating an Event with only required fields."""
        evt = Event(
            id="test-event-id",
            type=EventType.ISSUE_CREATED,
            timestamp=datetime.now(UTC),
            project_id=uuid4(),
            actor_type="system",
        )

        # Fields that were not supplied take their defaults.
        assert evt.id == "test-event-id"
        assert evt.actor_id is None
        assert evt.payload == {}

    def test_event_serialization_json(self):
        """Test Event serializes to JSON correctly."""
        evt = Event(
            id="test-event-id",
            type=EventType.SPRINT_STARTED,
            timestamp=datetime(2024, 1, 15, 10, 30, 0, tzinfo=UTC),
            project_id=uuid4(),
            actor_type="user",
            payload={"sprint_name": "Sprint 1"},
        )

        decoded = json.loads(evt.model_dump_json())

        # The event type appears in JSON as its string value.
        assert decoded["id"] == "test-event-id"
        assert decoded["type"] == "sprint.started"
        assert decoded["actor_type"] == "user"
        assert decoded["payload"]["sprint_name"] == "Sprint 1"

    def test_event_deserialization_json(self):
        """Test Event deserializes from JSON correctly."""
        proj_id = uuid4()
        raw = {
            "id": "test-event-id",
            "type": "agent.spawned",
            "timestamp": "2024-01-15T10:30:00Z",
            "project_id": str(proj_id),
            "actor_id": None,
            "actor_type": "system",
            "payload": {"agent_name": "PO Agent"},
        }

        evt = Event.model_validate(raw)

        assert evt.id == "test-event-id"
        assert evt.type == EventType.AGENT_SPAWNED
        # The string UUID must compare equal to the original UUID object.
        assert evt.project_id == proj_id
        assert evt.payload["agent_name"] == "PO Agent"

    def test_event_type_enum_values(self):
        """Test all EventType enum values are accessible."""
        wire_values = [
            # Agent events
            (EventType.AGENT_SPAWNED, "agent.spawned"),
            (EventType.AGENT_STATUS_CHANGED, "agent.status_changed"),
            (EventType.AGENT_MESSAGE, "agent.message"),
            (EventType.AGENT_TERMINATED, "agent.terminated"),
            # Issue events
            (EventType.ISSUE_CREATED, "issue.created"),
            (EventType.ISSUE_UPDATED, "issue.updated"),
            (EventType.ISSUE_ASSIGNED, "issue.assigned"),
            (EventType.ISSUE_CLOSED, "issue.closed"),
            # Sprint events
            (EventType.SPRINT_STARTED, "sprint.started"),
            (EventType.SPRINT_COMPLETED, "sprint.completed"),
            # Approval events
            (EventType.APPROVAL_REQUESTED, "approval.requested"),
            (EventType.APPROVAL_GRANTED, "approval.granted"),
            (EventType.APPROVAL_DENIED, "approval.denied"),
        ]
        for member, expected in wire_values:
            assert member.value == expected

    def test_actor_type_literal(self):
        """Test ActorType accepts valid values."""
        proj_id = uuid4()
        for kind in ("agent", "user", "system"):
            evt = Event(
                id="test",
                type=EventType.AGENT_MESSAGE,
                timestamp=datetime.now(UTC),
                project_id=proj_id,
                actor_type=kind,
            )
            assert evt.actor_type == kind

    def test_invalid_actor_type_rejected(self):
        """Test invalid actor type is rejected."""
        bad_kind = "invalid"  # not one of agent/user/system
        with pytest.raises(ValidationError):
            Event(
                id="test",
                type=EventType.AGENT_MESSAGE,
                timestamp=datetime.now(UTC),
                project_id=uuid4(),
                actor_type=bad_kind,
            )
class TestAgentPayloadSchemas:
    """Exercise the typed payload models carried by agent events."""

    def test_agent_spawned_payload(self):
        """Test AgentSpawnedPayload schema."""
        spawned = AgentSpawnedPayload(
            agent_instance_id=uuid4(),
            agent_type_id=uuid4(),
            agent_name="Product Owner Agent",
            role="product_owner",
        )

        assert spawned.agent_name == "Product Owner Agent"
        assert spawned.role == "product_owner"

    def test_agent_message_payload(self):
        """Test AgentMessagePayload schema."""
        msg = AgentMessagePayload(
            agent_instance_id=uuid4(),
            message="Processing requirements...",
            message_type="info",
            metadata={"tokens_used": 150},
        )

        assert (msg.message, msg.message_type) == (
            "Processing requirements...",
            "info",
        )
        assert msg.metadata["tokens_used"] == 150

    def test_agent_message_payload_defaults(self):
        """Test AgentMessagePayload has correct defaults."""
        msg = AgentMessagePayload(agent_instance_id=uuid4(), message="Test message")

        # Fields left out fall back to the schema defaults.
        assert msg.message_type == "info"
        assert msg.metadata == {}
class TestEventBusChannels:
    """Check the channel-name helpers on EventBus and the prefixes behind them."""

    def test_get_project_channel_with_uuid(self):
        """Test get_project_channel with UUID."""
        proj_id = uuid4()
        assert EventBus().get_project_channel(proj_id) == f"project:{proj_id}"

    def test_get_project_channel_with_string(self):
        """Test get_project_channel with string."""
        name = EventBus().get_project_channel("test-project-123")
        assert name == "project:test-project-123"

    def test_get_agent_channel_with_uuid(self):
        """Test get_agent_channel with UUID."""
        agent = uuid4()
        assert EventBus().get_agent_channel(agent) == f"agent:{agent}"

    def test_get_user_channel_with_uuid(self):
        """Test get_user_channel with UUID."""
        user = uuid4()
        assert EventBus().get_user_channel(user) == f"user:{user}"

    def test_channel_prefixes(self):
        """Test channel prefix constants."""
        expected = {
            "PROJECT_CHANNEL_PREFIX": "project",
            "AGENT_CHANNEL_PREFIX": "agent",
            "USER_CHANNEL_PREFIX": "user",
            "GLOBAL_CHANNEL": "syndarix:global",
        }
        for attr_name, value in expected.items():
            assert getattr(EventBus, attr_name) == value
class TestEventBusCreateEvent:
    """Check the EventBus.create_event factory, including the fields it fills in."""

    def test_create_event_with_required_fields(self):
        """Test create_event with only required fields."""
        proj_id = uuid4()

        evt = EventBus.create_event(
            event_type=EventType.AGENT_MESSAGE,
            project_id=proj_id,
            actor_type="agent",
        )

        assert evt.type == EventType.AGENT_MESSAGE
        assert evt.project_id == proj_id
        assert evt.actor_type == "agent"
        # Optional inputs that were omitted.
        assert evt.actor_id is None
        assert evt.payload == {}
        # Values the factory generates when not supplied.
        assert evt.id is not None
        assert evt.timestamp is not None

    def test_create_event_with_all_fields(self):
        """Test create_event with all fields."""
        proj_id, author_id = uuid4(), uuid4()
        stamp = datetime(2024, 1, 15, 10, 0, 0, tzinfo=UTC)

        evt = EventBus.create_event(
            event_type=EventType.APPROVAL_REQUESTED,
            project_id=proj_id,
            actor_type="agent",
            payload={"approval_type": "code_review"},
            actor_id=author_id,
            event_id="custom-event-id",
            timestamp=stamp,
        )

        # Explicit id/timestamp must override the generated ones.
        assert evt.id == "custom-event-id"
        assert evt.type == EventType.APPROVAL_REQUESTED
        assert evt.timestamp == stamp
        assert evt.project_id == proj_id
        assert evt.actor_id == author_id
        assert evt.actor_type == "agent"
        assert evt.payload == {"approval_type": "code_review"}

    def test_create_event_generates_unique_ids(self):
        """Test create_event generates unique IDs."""
        proj_id = uuid4()
        first, second = (
            EventBus.create_event(
                event_type=EventType.AGENT_MESSAGE,
                project_id=proj_id,
                actor_type="system",
            )
            for _ in range(2)
        )
        assert first.id != second.id
class TestEventBusConnection:
    """Tests for EventBus connection management.

    Redis is never contacted: either ``redis.asyncio.from_url`` is patched or
    the bus's private connection attributes are seeded with mocks directly.
    """

    def test_initial_state(self):
        """Test EventBus starts disconnected."""
        event_bus = EventBus()
        assert not event_bus.is_connected
        assert event_bus._redis_client is None
        assert event_bus._pubsub is None

    def test_redis_client_raises_when_disconnected(self):
        """Test accessing redis_client raises when not connected."""
        event_bus = EventBus()
        with pytest.raises(EventBusConnectionError) as exc_info:
            _ = event_bus.redis_client
        assert "not connected" in str(exc_info.value).lower()

    def test_pubsub_raises_when_disconnected(self):
        """Test accessing pubsub raises when not connected."""
        event_bus = EventBus()
        with pytest.raises(EventBusConnectionError) as exc_info:
            _ = event_bus.pubsub
        assert "not connected" in str(exc_info.value).lower()

    @pytest.mark.asyncio
    async def test_connect_success(self):
        """Test successful Redis connection."""
        event_bus = EventBus()
        mock_redis = AsyncMock()
        mock_redis.ping = AsyncMock(return_value=True)
        # pubsub() is invoked synchronously on the client, hence MagicMock.
        mock_redis.pubsub = MagicMock(return_value=AsyncMock())
        with patch("redis.asyncio.from_url", return_value=mock_redis):
            await event_bus.connect()
        assert event_bus.is_connected
        # connect() is expected to check the link with exactly one PING.
        mock_redis.ping.assert_called_once()

    @pytest.mark.asyncio
    async def test_connect_already_connected(self):
        """Test connect when already connected is idempotent.

        A second connect() must neither raise nor open a new Redis client.
        """
        event_bus = EventBus()
        existing_client = AsyncMock()
        event_bus._connected = True
        event_bus._redis_client = existing_client

        with patch("redis.asyncio.from_url") as mock_from_url:
            await event_bus.connect()

        # Verify the "no new connection" promise rather than only the flag.
        mock_from_url.assert_not_called()
        assert event_bus._redis_client is existing_client
        assert event_bus.is_connected

    @pytest.mark.asyncio
    async def test_connect_failure(self):
        """Test connection failure raises appropriate error."""
        import redis.asyncio as redis_async

        event_bus = EventBus()
        with patch(
            "redis.asyncio.from_url",
            side_effect=redis_async.ConnectionError("Connection refused"),
        ):
            with pytest.raises(EventBusConnectionError) as exc_info:
                await event_bus.connect()
        # The original Redis error text must survive the wrapping.
        assert "Connection refused" in str(exc_info.value)
        assert not event_bus.is_connected

    @pytest.mark.asyncio
    async def test_disconnect(self):
        """Test disconnection cleans up resources."""
        event_bus = EventBus()
        event_bus._connected = True
        event_bus._redis_client = AsyncMock()
        event_bus._redis_client.aclose = AsyncMock()
        event_bus._pubsub = AsyncMock()
        event_bus._pubsub.unsubscribe = AsyncMock()
        event_bus._pubsub.close = AsyncMock()

        await event_bus.disconnect()

        assert not event_bus.is_connected
        assert event_bus._redis_client is None
        assert event_bus._pubsub is None

    @pytest.mark.asyncio
    async def test_connection_context_manager(self):
        """Test connection context manager."""
        event_bus = EventBus()
        mock_redis = AsyncMock()
        mock_redis.ping = AsyncMock(return_value=True)
        mock_redis.pubsub = MagicMock(return_value=AsyncMock())
        mock_redis.aclose = AsyncMock()
        with patch("redis.asyncio.from_url", return_value=mock_redis):
            async with event_bus.connection() as bus:
                assert bus.is_connected
            # After context, should be disconnected
            assert not event_bus.is_connected
class TestEventBusPublish:
    """Tests for EventBus publish functionality.

    The Redis client is an AsyncMock whose ``publish`` returns a fixed
    subscriber count, so the tests can inspect exactly what was sent where.
    """

    @pytest.mark.asyncio
    async def test_publish_requires_connection(self):
        """Test publish raises when not connected."""
        event_bus = EventBus()
        project_id = uuid4()
        event = EventBus.create_event(
            event_type=EventType.AGENT_MESSAGE,
            project_id=project_id,
            actor_type="system",
        )
        with pytest.raises(EventBusConnectionError):
            await event_bus.publish("test:channel", event)

    @pytest.mark.asyncio
    async def test_publish_success(self):
        """Test successful event publishing."""
        event_bus = EventBus()
        event_bus._connected = True
        mock_redis = AsyncMock()
        mock_redis.publish = AsyncMock(return_value=2)  # 2 subscribers
        event_bus._redis_client = mock_redis
        project_id = uuid4()
        event = EventBus.create_event(
            event_type=EventType.AGENT_MESSAGE,
            project_id=project_id,
            actor_type="agent",
            payload={"message": "Test message"},
        )

        subscriber_count = await event_bus.publish("project:test", event)

        # publish() reports the subscriber count returned by Redis.
        assert subscriber_count == 2
        mock_redis.publish.assert_called_once()
        # Verify the published message is valid JSON
        call_args = mock_redis.publish.call_args
        published_data = call_args[0][1]
        parsed = json.loads(published_data)
        assert parsed["type"] == "agent.message"
        assert parsed["payload"]["message"] == "Test message"

    @pytest.mark.asyncio
    async def test_publish_to_project(self):
        """Test publish_to_project convenience method."""
        event_bus = EventBus()
        event_bus._connected = True
        mock_redis = AsyncMock()
        mock_redis.publish = AsyncMock(return_value=1)
        event_bus._redis_client = mock_redis
        project_id = uuid4()
        event = EventBus.create_event(
            event_type=EventType.ISSUE_CREATED,
            project_id=project_id,
            actor_type="user",
        )

        await event_bus.publish_to_project(event)

        # Verify published to correct channel (first positional argument).
        call_args = mock_redis.publish.call_args
        channel = call_args[0][0]
        assert channel == f"project:{project_id}"

    @pytest.mark.asyncio
    async def test_publish_multi(self):
        """Test publishing to multiple channels.

        Each requested channel must receive its own publish call, not merely
        three calls in total.
        """
        event_bus = EventBus()
        event_bus._connected = True
        mock_redis = AsyncMock()
        mock_redis.publish = AsyncMock(return_value=1)
        event_bus._redis_client = mock_redis
        project_id = uuid4()
        event = EventBus.create_event(
            event_type=EventType.AGENT_MESSAGE,
            project_id=project_id,
            actor_type="agent",
        )
        channels = ["project:1", "agent:2", "user:3"]

        results = await event_bus.publish_multi(channels, event)

        assert len(results) == 3
        assert all(count == 1 for count in results.values())
        assert mock_redis.publish.call_count == 3
        # The channel is the first positional argument to redis.publish;
        # every requested channel must have been targeted exactly once.
        published_channels = [c.args[0] for c in mock_redis.publish.call_args_list]
        assert sorted(published_channels) == sorted(channels)

    @pytest.mark.asyncio
    async def test_publish_redis_error(self):
        """Test publish handles Redis errors."""
        import redis.asyncio as redis_async

        event_bus = EventBus()
        event_bus._connected = True
        mock_redis = AsyncMock()
        mock_redis.publish = AsyncMock(
            side_effect=redis_async.RedisError("Publish failed")
        )
        event_bus._redis_client = mock_redis
        project_id = uuid4()
        event = EventBus.create_event(
            event_type=EventType.AGENT_MESSAGE,
            project_id=project_id,
            actor_type="system",
        )
        with pytest.raises(EventBusPublishError) as exc_info:
            await event_bus.publish("test:channel", event)
        # The underlying Redis error text must survive the wrapping.
        assert "Publish failed" in str(exc_info.value)
class TestEventBusSubscribe:
    """Tests for EventBus subscribe functionality.

    Redis is replaced by mocks; the pubsub's ``get_message`` either replays a
    fixed list of raw messages or stalls, depending on the test.
    """

    @staticmethod
    def _replaying(messages):
        """Return a ``get_message`` stand-in that hands out *messages* in order.

        Once the list is exhausted, every further call raises ``TimeoutError``.
        Together with ``max_wait``, the tests rely on this to end the
        subscription.
        """
        backlog = iter(messages)

        async def get_message(**kwargs):
            queued = next(backlog, None)
            if queued is None:
                raise TimeoutError()
            return queued

        return get_message

    @staticmethod
    def _bus_with_pubsub(get_message):
        """Return an EventBus flagged as connected whose pubsub uses *get_message*."""
        pubsub = AsyncMock()
        pubsub.get_message = get_message
        pubsub.subscribe = AsyncMock()
        pubsub.unsubscribe = AsyncMock()
        pubsub.close = AsyncMock()

        client = AsyncMock()
        client.pubsub = MagicMock(return_value=pubsub)

        bus = EventBus()
        bus._connected = True
        bus._redis_client = client
        return bus

    @pytest.mark.asyncio
    async def test_subscribe_requires_connection(self):
        """Test subscribe raises when not connected."""
        with pytest.raises(EventBusConnectionError):
            async for _ in EventBus().subscribe(["test:channel"]):
                pass

    @pytest.mark.asyncio
    async def test_subscribe_receives_events(self):
        """Test subscribing to channels receives events."""
        published = EventBus.create_event(
            event_type=EventType.AGENT_MESSAGE,
            project_id=uuid4(),
            actor_type="agent",
            payload={"message": "Test"},
        )
        bus = self._bus_with_pubsub(
            self._replaying(
                [
                    {
                        "type": "message",
                        "data": published.model_dump_json(),
                        "channel": "test:channel",
                    },
                ]
            )
        )

        # max_wait bounds the subscription so the test cannot hang.
        received = [
            event async for event in bus.subscribe(["test:channel"], max_wait=0.5)
        ]

        assert len(received) == 1
        assert received[0].type == EventType.AGENT_MESSAGE
        assert received[0].payload["message"] == "Test"

    @pytest.mark.asyncio
    async def test_subscribe_handles_invalid_json(self):
        """Test subscribe skips invalid JSON messages."""
        bus = self._bus_with_pubsub(
            self._replaying(
                [{"type": "message", "data": "not valid json", "channel": "test"}]
            )
        )

        received = [
            event async for event in bus.subscribe(["test:channel"], max_wait=0.5)
        ]

        # The unparsable message is dropped rather than surfaced.
        assert len(received) == 0

    @pytest.mark.asyncio
    async def test_subscribe_handles_invalid_event_schema(self):
        """Test subscribe skips messages with invalid Event schema."""
        # Well-formed JSON that lacks the fields an Event requires.
        incomplete = json.dumps({"type": "test", "data": "incomplete"})
        bus = self._bus_with_pubsub(
            self._replaying([{"type": "message", "data": incomplete, "channel": "test"}])
        )

        received = [
            event async for event in bus.subscribe(["test:channel"], max_wait=0.5)
        ]

        # The schema-invalid message is dropped rather than surfaced.
        assert len(received) == 0

    @pytest.mark.asyncio
    async def test_subscribe_max_wait(self):
        """Test subscribe respects max_wait."""

        async def stalled_get_message(**kwargs):
            # Outlasts max_wait, so the bus has to give up on its own.
            await asyncio.sleep(1.0)
            return None

        bus = self._bus_with_pubsub(stalled_get_message)
        loop = asyncio.get_running_loop()

        started = loop.time()
        received = [
            event async for event in bus.subscribe(["test:channel"], max_wait=0.2)
        ]
        elapsed = loop.time() - started

        assert len(received) == 0
        # Should have timed out around 0.2 seconds
        assert elapsed < 1.0
class TestEventBusChannelIsolation:
    """Tests for channel isolation in EventBus."""

    @pytest.mark.asyncio
    async def test_events_only_received_on_subscribed_channel(self):
        """Test events are only received on subscribed channels.

        Delivery filtering is done by Redis, which is mocked here. So besides
        checking the received event, the test also asserts that the bus asked
        the pubsub for project 1's channel and never for project 2's.
        """
        event_bus = EventBus()
        event_bus._connected = True
        project1_id = uuid4()
        project2_id = uuid4()
        channel1 = event_bus.get_project_channel(project1_id)
        channel2 = event_bus.get_project_channel(project2_id)

        # Event for project 1
        event1 = EventBus.create_event(
            event_type=EventType.AGENT_MESSAGE,
            project_id=project1_id,
            actor_type="agent",
            payload={"message": "For project 1"},
        )
        serialized_event1 = event1.model_dump_json()

        # Mock pubsub only returns event1 for project1 channel
        mock_pubsub = AsyncMock()
        message_queue = [
            {"type": "message", "data": serialized_event1, "channel": channel1},
        ]
        next_index = 0

        async def get_message_side_effect(**kwargs):
            nonlocal next_index
            if next_index < len(message_queue):
                result = message_queue[next_index]
                next_index += 1
                return result
            # Queue drained: signal idleness so max_wait can end the loop.
            raise TimeoutError()

        mock_pubsub.get_message = get_message_side_effect
        mock_pubsub.subscribe = AsyncMock()
        mock_pubsub.unsubscribe = AsyncMock()
        mock_pubsub.close = AsyncMock()
        mock_redis = AsyncMock()
        mock_redis.pubsub = MagicMock(return_value=mock_pubsub)
        event_bus._redis_client = mock_redis

        # Subscribe to project1 channel
        received = []
        async for event in event_bus.subscribe([channel1], max_wait=0.5):
            received.append(event)

        assert len(received) == 1
        assert received[0].project_id == project1_id

        # Gather every channel handed to pubsub.subscribe, whether the bus
        # passed them as separate positional args, one collection, or kwargs.
        subscribed = set()
        for recorded in mock_pubsub.subscribe.call_args_list:
            for arg in recorded.args:
                if isinstance(arg, (list, tuple, set, frozenset, dict)):
                    subscribed.update(arg)
                else:
                    subscribed.add(arg)
            subscribed.update(recorded.kwargs)
        assert channel1 in subscribed
        assert channel2 not in subscribed

    def test_channel_names_are_unique(self):
        """Test different entity types have unique channel names."""
        event_bus = EventBus()
        same_id = uuid4()
        project_channel = event_bus.get_project_channel(same_id)
        agent_channel = event_bus.get_agent_channel(same_id)
        user_channel = event_bus.get_user_channel(same_id)
        # All channels should be different even with same ID
        assert project_channel != agent_channel
        assert agent_channel != user_channel
        assert project_channel != user_channel
class TestEventBusSingleton:
    """Tests for EventBus singleton management."""

    @pytest.fixture(autouse=True)
    def _reset_event_bus_singleton(self):
        """Clear the module-level singleton before and after every test.

        Doing the reset in a fixture instead of inline at the end of each test
        guarantees it also runs when an assertion fails. Otherwise a stale bus
        could leak into unrelated tests.
        """
        import app.services.event_bus as event_bus_module

        event_bus_module._event_bus = None
        yield
        event_bus_module._event_bus = None

    def test_get_event_bus_returns_singleton(self):
        """Test get_event_bus returns same instance."""
        bus1 = get_event_bus()
        bus2 = get_event_bus()
        assert bus1 is bus2

    @pytest.mark.asyncio
    async def test_get_connected_event_bus(self):
        """Test get_connected_event_bus returns connected instance."""
        mock_redis = AsyncMock()
        mock_redis.ping = AsyncMock(return_value=True)
        mock_redis.pubsub = MagicMock(return_value=AsyncMock())
        with patch("redis.asyncio.from_url", return_value=mock_redis):
            bus = await get_connected_event_bus()
            try:
                assert bus.is_connected
            finally:
                # Disconnect through the public API even if the check fails.
                await close_event_bus()

    @pytest.mark.asyncio
    async def test_close_event_bus(self):
        """Test close_event_bus cleans up singleton."""
        import app.services.event_bus as event_bus_module

        # Install a bus flagged as connected whose Redis/pubsub handles are mocks.
        bus = EventBus()
        bus._connected = True
        bus._redis_client = AsyncMock()
        bus._redis_client.aclose = AsyncMock()
        bus._pubsub = AsyncMock()
        bus._pubsub.unsubscribe = AsyncMock()
        bus._pubsub.close = AsyncMock()
        event_bus_module._event_bus = bus

        await close_event_bus()

        assert event_bus_module._event_bus is None
class TestEventBusSubscribeWithCallback:
    """Tests for subscribe_with_callback method."""

    @staticmethod
    def _bus_replaying(messages):
        """Return an EventBus flagged as connected whose pubsub replays *messages*.

        After the last message, ``get_message`` raises ``TimeoutError`` on every
        further call. The replay index lives in this helper's scope, so it
        cannot collide with counters the tests themselves keep.
        """
        next_index = 0

        async def get_message(**kwargs):
            nonlocal next_index
            if next_index < len(messages):
                message = messages[next_index]
                next_index += 1
                return message
            raise TimeoutError()

        mock_pubsub = AsyncMock()
        mock_pubsub.get_message = get_message
        mock_pubsub.subscribe = AsyncMock()
        mock_pubsub.unsubscribe = AsyncMock()
        mock_pubsub.close = AsyncMock()
        mock_redis = AsyncMock()
        mock_redis.pubsub = MagicMock(return_value=mock_pubsub)

        event_bus = EventBus()
        event_bus._connected = True
        event_bus._redis_client = mock_redis
        return event_bus

    @staticmethod
    async def _run_until_stopped(event_bus, channels, callback, delay=0.3, timeout=5.0):
        """Run subscribe_with_callback and set its stop event after *delay* seconds.

        Exceptions from the subscription are collected, not raised (as with
        ``gather(return_exceptions=True)``). The run is bounded by *timeout*, so
        a stop event that is never honoured fails the test instead of hanging it.
        """
        stop_event = asyncio.Event()

        async def stop_after_delay():
            await asyncio.sleep(delay)
            stop_event.set()

        task = asyncio.create_task(
            event_bus.subscribe_with_callback(channels, callback, stop_event)
        )
        stop_task = asyncio.create_task(stop_after_delay())
        await asyncio.wait_for(
            asyncio.gather(task, stop_task, return_exceptions=True),
            timeout=timeout,
        )

    @pytest.mark.asyncio
    async def test_subscribe_with_callback_processes_events(self):
        """Test subscribe_with_callback calls callback for each event."""
        project_id = uuid4()
        test_event = EventBus.create_event(
            event_type=EventType.AGENT_MESSAGE,
            project_id=project_id,
            actor_type="agent",
        )
        serialized_event = test_event.model_dump_json()
        event_bus = self._bus_replaying(
            [{"type": "message", "data": serialized_event, "channel": "test"}]
        )

        received_events = []

        async def callback(event: Event):
            received_events.append(event)

        await self._run_until_stopped(event_bus, ["test"], callback)

        assert len(received_events) == 1
        assert received_events[0].type == EventType.AGENT_MESSAGE

    @pytest.mark.asyncio
    async def test_subscribe_with_callback_handles_callback_errors(self):
        """Test subscribe_with_callback continues after callback errors."""
        project_id = uuid4()
        test_event = EventBus.create_event(
            event_type=EventType.AGENT_MESSAGE,
            project_id=project_id,
            actor_type="agent",
        )
        serialized_event = test_event.model_dump_json()
        event_bus = self._bus_replaying(
            [
                {"type": "message", "data": serialized_event, "channel": "test"},
                {"type": "message", "data": serialized_event, "channel": "test"},
            ]
        )

        # Dedicated counter: previously this shared `call_count` with the mock's
        # replay index, so the callback never raised and the test passed vacuously.
        callback_calls = 0

        async def failing_callback(event: Event):
            nonlocal callback_calls
            callback_calls += 1
            if callback_calls == 1:
                raise ValueError("First call fails")
            # Second call succeeds

        await self._run_until_stopped(event_bus, ["test"], failing_callback)

        # Should have processed both events despite first callback failing
        assert callback_calls == 2
class TestEventBusSSESubscription:
"""Tests for SSE subscription functionality."""
@pytest.mark.asyncio
async def test_subscribe_sse_yields_json_strings(self):
    """Test subscribe_sse yields JSON string data.

    One event is queued. The loop stops after two yielded items, expected to
    be that event followed by a keepalive; keepalives are assumed to be empty
    strings and are skipped.
    """
    from contextlib import aclosing

    event_bus = EventBus()
    event_bus._connected = True
    project_id = uuid4()
    test_event = EventBus.create_event(
        event_type=EventType.AGENT_MESSAGE,
        project_id=project_id,
        actor_type="agent",
        payload={"message": "SSE test"},
    )
    serialized_event = test_event.model_dump_json()

    mock_pubsub = AsyncMock()
    message_queue = [
        {"type": "message", "data": serialized_event, "channel": "test"},
    ]
    next_index = 0

    async def get_message_side_effect(**kwargs):
        nonlocal next_index
        if next_index < len(message_queue):
            result = message_queue[next_index]
            next_index += 1
            return result
        # Raise TimeoutError after messages are exhausted
        raise TimeoutError()

    mock_pubsub.get_message = get_message_side_effect
    mock_pubsub.subscribe = AsyncMock()
    mock_pubsub.unsubscribe = AsyncMock()
    mock_pubsub.close = AsyncMock()
    mock_redis = AsyncMock()
    mock_redis.pubsub = MagicMock(return_value=mock_pubsub)
    event_bus._redis_client = mock_redis

    received = []
    count = 0
    # Breaking out of `async for` leaves an async generator suspended until it
    # is garbage-collected. aclosing() calls aclose() on exit so the stream's
    # cleanup runs right here instead.
    async with aclosing(
        event_bus.subscribe_sse(project_id, keepalive_interval=0.1)
    ) as stream:
        async for data in stream:
            if data:  # Skip keepalive empty strings
                received.append(data)
            count += 1
            if count >= 2:  # Get one message and one keepalive
                break

    assert len(received) == 1
    # Should be valid JSON
    parsed = json.loads(received[0])
    assert parsed["type"] == "agent.message"
    assert parsed["payload"]["message"] == "SSE test"