forked from cardosofelipe/fast-next-template
- Add EventBus class for real-time event communication - Add Event schema with type-safe event types (agent, issue, sprint events) - Add typed payload schemas (AgentSpawnedPayload, AgentMessagePayload) - Add channel helpers for project/agent/user scoping - Add subscribe_sse generator for SSE streaming - Add reconnection support via Last-Event-ID - Add keepalive mechanism for connection health - Add 44 comprehensive tests with mocked Redis Implements #33 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
1036 lines
34 KiB
Python
1036 lines
34 KiB
Python
# tests/services/test_event_bus.py
|
|
"""
Tests for the EventBus service.

These tests verify:
- Event creation and serialization
- Publishing events to channels
- Subscribing to channels and receiving events
- Channel isolation (events only go to intended channels)
- Error handling for connection failures
- Multiple subscriptions and concurrent publishing
"""
|
|
|
|
import asyncio
|
|
import json
|
|
from datetime import UTC, datetime
|
|
from unittest.mock import AsyncMock, MagicMock, patch
|
|
from uuid import uuid4
|
|
|
|
import pytest
|
|
from pydantic import ValidationError
|
|
|
|
from app.schemas.events import (
|
|
AgentMessagePayload,
|
|
AgentSpawnedPayload,
|
|
Event,
|
|
EventType,
|
|
)
|
|
from app.services.event_bus import (
|
|
EventBus,
|
|
EventBusConnectionError,
|
|
EventBusPublishError,
|
|
close_event_bus,
|
|
get_connected_event_bus,
|
|
get_event_bus,
|
|
)
|
|
|
|
|
|
class TestEventSchema:
    """Checks on the Event model: construction, JSON round-trips, enums."""

    def test_create_event_with_all_fields(self):
        """An Event built with every field keeps each value as supplied."""
        supplied = {
            "id": "test-event-id",
            "type": EventType.AGENT_MESSAGE,
            "timestamp": datetime.now(UTC),
            "project_id": uuid4(),
            "actor_id": uuid4(),
            "actor_type": "agent",
            "payload": {"message": "Hello, world!"},
        }

        event = Event(**supplied)

        # Each attribute must compare equal to the value passed in.
        for field_name, expected in supplied.items():
            assert getattr(event, field_name) == expected

    def test_create_event_with_minimal_fields(self):
        """Omitted optional fields fall back to None / an empty payload."""
        event = Event(
            id="test-event-id",
            type=EventType.ISSUE_CREATED,
            timestamp=datetime.now(UTC),
            project_id=uuid4(),
            actor_type="system",
        )

        assert event.id == "test-event-id"
        assert event.actor_id is None
        assert event.payload == {}

    def test_event_serialization_json(self):
        """model_dump_json() emits the expected id, type, actor and payload."""
        event = Event(
            id="test-event-id",
            type=EventType.SPRINT_STARTED,
            timestamp=datetime(2024, 1, 15, 10, 30, 0, tzinfo=UTC),
            project_id=uuid4(),
            actor_type="user",
            payload={"sprint_name": "Sprint 1"},
        )

        decoded = json.loads(event.model_dump_json())

        assert decoded["id"] == "test-event-id"
        assert decoded["type"] == "sprint.started"
        assert decoded["actor_type"] == "user"
        assert decoded["payload"]["sprint_name"] == "Sprint 1"

    def test_event_deserialization_json(self):
        """model_validate() rebuilds an Event from JSON-style primitives."""
        project_id = uuid4()
        raw = {
            "id": "test-event-id",
            "type": "agent.spawned",
            "timestamp": "2024-01-15T10:30:00Z",
            "project_id": str(project_id),
            "actor_id": None,
            "actor_type": "system",
            "payload": {"agent_name": "PO Agent"},
        }

        event = Event.model_validate(raw)

        assert event.id == "test-event-id"
        assert event.type == EventType.AGENT_SPAWNED
        assert event.project_id == project_id
        assert event.payload["agent_name"] == "PO Agent"

    def test_event_type_enum_values(self):
        """Each EventType member exposes the expected dotted string value."""
        expected_values = [
            # Agent events
            (EventType.AGENT_SPAWNED, "agent.spawned"),
            (EventType.AGENT_STATUS_CHANGED, "agent.status_changed"),
            (EventType.AGENT_MESSAGE, "agent.message"),
            (EventType.AGENT_TERMINATED, "agent.terminated"),
            # Issue events
            (EventType.ISSUE_CREATED, "issue.created"),
            (EventType.ISSUE_UPDATED, "issue.updated"),
            (EventType.ISSUE_ASSIGNED, "issue.assigned"),
            (EventType.ISSUE_CLOSED, "issue.closed"),
            # Sprint events
            (EventType.SPRINT_STARTED, "sprint.started"),
            (EventType.SPRINT_COMPLETED, "sprint.completed"),
            # Approval events
            (EventType.APPROVAL_REQUESTED, "approval.requested"),
            (EventType.APPROVAL_GRANTED, "approval.granted"),
            (EventType.APPROVAL_DENIED, "approval.denied"),
        ]

        for member, wire_value in expected_values:
            assert member.value == wire_value

    def test_actor_type_literal(self):
        """The three supported actor types are all accepted."""
        project_id = uuid4()

        for candidate in ("agent", "user", "system"):
            event = Event(
                id="test",
                type=EventType.AGENT_MESSAGE,
                timestamp=datetime.now(UTC),
                project_id=project_id,
                actor_type=candidate,
            )
            assert event.actor_type == candidate

    def test_invalid_actor_type_rejected(self):
        """An unknown actor type fails validation."""
        fields = {
            "id": "test",
            "type": EventType.AGENT_MESSAGE,
            "timestamp": datetime.now(UTC),
            "project_id": uuid4(),
            "actor_type": "invalid",  # not one of the accepted actor types
        }

        with pytest.raises(ValidationError):
            Event(**fields)
|
|
|
|
|
|
class TestAgentPayloadSchemas:
    """Checks on the typed agent payload models."""

    def test_agent_spawned_payload(self):
        """AgentSpawnedPayload keeps the supplied name and role."""
        fields = {
            "agent_instance_id": uuid4(),
            "agent_type_id": uuid4(),
            "agent_name": "Product Owner Agent",
            "role": "product_owner",
        }

        spawned = AgentSpawnedPayload(**fields)

        assert spawned.agent_name == "Product Owner Agent"
        assert spawned.role == "product_owner"

    def test_agent_message_payload(self):
        """AgentMessagePayload keeps message, message type and metadata."""
        message_payload = AgentMessagePayload(
            agent_instance_id=uuid4(),
            message="Processing requirements...",
            message_type="info",
            metadata={"tokens_used": 150},
        )

        assert message_payload.message == "Processing requirements..."
        assert message_payload.message_type == "info"
        assert message_payload.metadata["tokens_used"] == 150

    def test_agent_message_payload_defaults(self):
        """Without overrides, message_type is "info" and metadata is empty."""
        message_payload = AgentMessagePayload(
            agent_instance_id=uuid4(),
            message="Test message",
        )

        assert message_payload.message_type == "info"
        assert message_payload.metadata == {}
|
|
|
|
|
|
class TestEventBusChannels:
    """Checks on the channel-name helpers exposed by EventBus."""

    def test_get_project_channel_with_uuid(self):
        """A UUID project id yields "project:<uuid>"."""
        project_id = uuid4()

        assert EventBus().get_project_channel(project_id) == f"project:{project_id}"

    def test_get_project_channel_with_string(self):
        """A plain string project id is embedded verbatim."""
        channel_name = EventBus().get_project_channel("test-project-123")

        assert channel_name == "project:test-project-123"

    def test_get_agent_channel_with_uuid(self):
        """A UUID agent id yields "agent:<uuid>"."""
        agent_id = uuid4()

        assert EventBus().get_agent_channel(agent_id) == f"agent:{agent_id}"

    def test_get_user_channel_with_uuid(self):
        """A UUID user id yields "user:<uuid>"."""
        user_id = uuid4()

        assert EventBus().get_user_channel(user_id) == f"user:{user_id}"

    def test_channel_prefixes(self):
        """The class-level prefix constants have their documented values."""
        expected_constants = (
            ("PROJECT_CHANNEL_PREFIX", "project"),
            ("AGENT_CHANNEL_PREFIX", "agent"),
            ("USER_CHANNEL_PREFIX", "user"),
            ("GLOBAL_CHANNEL", "syndarix:global"),
        )

        for attribute, value in expected_constants:
            assert getattr(EventBus, attribute) == value
|
|
|
|
|
|
class TestEventBusCreateEvent:
    """Checks on the EventBus.create_event factory."""

    def test_create_event_with_required_fields(self):
        """With only required arguments, optional fields get defaults."""
        project_id = uuid4()

        created = EventBus.create_event(
            event_type=EventType.AGENT_MESSAGE,
            project_id=project_id,
            actor_type="agent",
        )

        assert created.type == EventType.AGENT_MESSAGE
        assert created.project_id == project_id
        assert created.actor_type == "agent"
        assert created.actor_id is None
        assert created.payload == {}
        # id and timestamp were not supplied, so the factory must fill them in.
        assert created.id is not None
        assert created.timestamp is not None

    def test_create_event_with_all_fields(self):
        """Explicit id, timestamp, actor and payload are passed through."""
        project_id, actor_id = uuid4(), uuid4()
        fixed_time = datetime(2024, 1, 15, 10, 0, 0, tzinfo=UTC)

        created = EventBus.create_event(
            event_type=EventType.APPROVAL_REQUESTED,
            project_id=project_id,
            actor_type="agent",
            payload={"approval_type": "code_review"},
            actor_id=actor_id,
            event_id="custom-event-id",
            timestamp=fixed_time,
        )

        assert created.id == "custom-event-id"
        assert created.type == EventType.APPROVAL_REQUESTED
        assert created.timestamp == fixed_time
        assert created.project_id == project_id
        assert created.actor_id == actor_id
        assert created.actor_type == "agent"
        assert created.payload == {"approval_type": "code_review"}

    def test_create_event_generates_unique_ids(self):
        """Two events created with identical arguments get different ids."""
        shared_arguments = {
            "event_type": EventType.AGENT_MESSAGE,
            "project_id": uuid4(),
            "actor_type": "system",
        }

        first = EventBus.create_event(**shared_arguments)
        second = EventBus.create_event(**shared_arguments)

        assert first.id != second.id
|
|
|
|
|
|
class TestEventBusConnection:
    """Checks on EventBus connect / disconnect behaviour."""

    def test_initial_state(self):
        """A fresh EventBus is disconnected and holds no Redis handles."""
        bus = EventBus()

        assert not bus.is_connected
        assert bus._redis_client is None
        assert bus._pubsub is None

    def test_redis_client_raises_when_disconnected(self):
        """Reading redis_client before connecting raises a connection error."""
        bus = EventBus()

        with pytest.raises(EventBusConnectionError) as exc_info:
            _ = bus.redis_client

        assert "not connected" in str(exc_info.value).lower()

    def test_pubsub_raises_when_disconnected(self):
        """Reading pubsub before connecting raises a connection error."""
        bus = EventBus()

        with pytest.raises(EventBusConnectionError) as exc_info:
            _ = bus.pubsub

        assert "not connected" in str(exc_info.value).lower()

    @pytest.mark.asyncio
    async def test_connect_success(self):
        """connect() pings the patched Redis client once and marks the bus connected."""
        bus = EventBus()

        redis_stub = AsyncMock()
        redis_stub.ping = AsyncMock(return_value=True)
        # pubsub() is replaced with a plain MagicMock so it returns directly.
        redis_stub.pubsub = MagicMock(return_value=AsyncMock())

        with patch("redis.asyncio.from_url", return_value=redis_stub):
            await bus.connect()

        assert bus.is_connected
        redis_stub.ping.assert_called_once()

    @pytest.mark.asyncio
    async def test_connect_already_connected(self):
        """connect() on an already-connected bus completes without error."""
        bus = EventBus()
        bus._connected = True
        bus._redis_client = AsyncMock()

        # Must neither raise nor leave the bus disconnected.
        await bus.connect()

        assert bus.is_connected

    @pytest.mark.asyncio
    async def test_connect_failure(self):
        """A Redis ConnectionError surfaces as EventBusConnectionError."""
        import redis.asyncio as redis_async

        bus = EventBus()
        refusal = redis_async.ConnectionError("Connection refused")

        with (
            patch("redis.asyncio.from_url", side_effect=refusal),
            pytest.raises(EventBusConnectionError) as exc_info,
        ):
            await bus.connect()

        assert "Connection refused" in str(exc_info.value)
        assert not bus.is_connected

    @pytest.mark.asyncio
    async def test_disconnect(self):
        """disconnect() clears the connected flag and both Redis handles."""
        bus = EventBus()
        bus._connected = True

        redis_stub = AsyncMock()
        redis_stub.aclose = AsyncMock()
        bus._redis_client = redis_stub

        pubsub_stub = AsyncMock()
        pubsub_stub.unsubscribe = AsyncMock()
        pubsub_stub.close = AsyncMock()
        bus._pubsub = pubsub_stub

        await bus.disconnect()

        assert not bus.is_connected
        assert bus._redis_client is None
        assert bus._pubsub is None

    @pytest.mark.asyncio
    async def test_connection_context_manager(self):
        """connection() yields a connected bus and disconnects on exit."""
        bus = EventBus()

        redis_stub = AsyncMock()
        redis_stub.ping = AsyncMock(return_value=True)
        redis_stub.pubsub = MagicMock(return_value=AsyncMock())
        redis_stub.aclose = AsyncMock()

        with patch("redis.asyncio.from_url", return_value=redis_stub):
            async with bus.connection() as active_bus:
                assert active_bus.is_connected

        # Leaving the context must have torn the connection down.
        assert not bus.is_connected
|
|
|
|
|
|
class TestEventBusPublish:
    """Checks on the EventBus publish family of methods."""

    @staticmethod
    def _connected_bus(publish_mock):
        """Return (bus, redis stub) with the bus flagged connected.

        The stub's ``publish`` attribute is the given mock, so each test
        controls the return value or side effect of publishing.
        """
        bus = EventBus()
        bus._connected = True

        redis_stub = AsyncMock()
        redis_stub.publish = publish_mock
        bus._redis_client = redis_stub
        return bus, redis_stub

    @pytest.mark.asyncio
    async def test_publish_requires_connection(self):
        """publish() on a disconnected bus raises EventBusConnectionError."""
        bus = EventBus()
        event = EventBus.create_event(
            event_type=EventType.AGENT_MESSAGE,
            project_id=uuid4(),
            actor_type="system",
        )

        with pytest.raises(EventBusConnectionError):
            await bus.publish("test:channel", event)

    @pytest.mark.asyncio
    async def test_publish_success(self):
        """publish() returns Redis' subscriber count and sends valid JSON."""
        # The stubbed Redis reports two subscribers.
        bus, redis_stub = self._connected_bus(AsyncMock(return_value=2))

        event = EventBus.create_event(
            event_type=EventType.AGENT_MESSAGE,
            project_id=uuid4(),
            actor_type="agent",
            payload={"message": "Test message"},
        )

        delivered_to = await bus.publish("project:test", event)

        assert delivered_to == 2
        redis_stub.publish.assert_called_once()

        # The second positional argument is the serialized event.
        wire_payload = json.loads(redis_stub.publish.call_args.args[1])
        assert wire_payload["type"] == "agent.message"
        assert wire_payload["payload"]["message"] == "Test message"

    @pytest.mark.asyncio
    async def test_publish_to_project(self):
        """publish_to_project() targets the event's own project channel."""
        bus, redis_stub = self._connected_bus(AsyncMock(return_value=1))

        project_id = uuid4()
        event = EventBus.create_event(
            event_type=EventType.ISSUE_CREATED,
            project_id=project_id,
            actor_type="user",
        )

        await bus.publish_to_project(event)

        # The first positional argument is the channel name.
        target_channel = redis_stub.publish.call_args.args[0]
        assert target_channel == f"project:{project_id}"

    @pytest.mark.asyncio
    async def test_publish_multi(self):
        """publish_multi() publishes once per channel and reports each count."""
        bus, redis_stub = self._connected_bus(AsyncMock(return_value=1))

        event = EventBus.create_event(
            event_type=EventType.AGENT_MESSAGE,
            project_id=uuid4(),
            actor_type="agent",
        )

        outcome = await bus.publish_multi(["project:1", "agent:2", "user:3"], event)

        assert len(outcome) == 3
        assert all(subscribers == 1 for subscribers in outcome.values())
        assert redis_stub.publish.call_count == 3

    @pytest.mark.asyncio
    async def test_publish_redis_error(self):
        """A RedisError from publish surfaces as EventBusPublishError."""
        import redis.asyncio as redis_async

        failing_publish = AsyncMock(
            side_effect=redis_async.RedisError("Publish failed")
        )
        bus, _ = self._connected_bus(failing_publish)

        event = EventBus.create_event(
            event_type=EventType.AGENT_MESSAGE,
            project_id=uuid4(),
            actor_type="system",
        )

        with pytest.raises(EventBusPublishError) as exc_info:
            await bus.publish("test:channel", event)

        assert "Publish failed" in str(exc_info.value)
|
|
|
|
|
|
class TestEventBusSubscribe:
    """Checks on the EventBus.subscribe async generator."""

    @staticmethod
    def _replaying(messages):
        """Build a get_message stand-in that replays ``messages`` in order.

        Once every message has been handed out, each further call raises
        TimeoutError.
        """
        remaining = iter(messages)

        async def get_message(**kwargs):
            for message in remaining:
                return message
            raise TimeoutError()

        return get_message

    @staticmethod
    def _bus_with_pubsub(get_message):
        """Return a bus flagged connected whose pubsub uses ``get_message``."""
        bus = EventBus()
        bus._connected = True

        pubsub_stub = AsyncMock()
        pubsub_stub.get_message = get_message
        pubsub_stub.subscribe = AsyncMock()
        pubsub_stub.unsubscribe = AsyncMock()
        pubsub_stub.close = AsyncMock()

        redis_stub = AsyncMock()
        redis_stub.pubsub = MagicMock(return_value=pubsub_stub)
        bus._redis_client = redis_stub
        return bus

    @pytest.mark.asyncio
    async def test_subscribe_requires_connection(self):
        """Iterating subscribe() on a disconnected bus raises."""
        bus = EventBus()

        with pytest.raises(EventBusConnectionError):
            async for _ in bus.subscribe(["test:channel"]):
                pass

    @pytest.mark.asyncio
    async def test_subscribe_receives_events(self):
        """A well-formed message is yielded as a parsed Event."""
        published = EventBus.create_event(
            event_type=EventType.AGENT_MESSAGE,
            project_id=uuid4(),
            actor_type="agent",
            payload={"message": "Test"},
        )
        bus = self._bus_with_pubsub(
            self._replaying(
                [
                    {
                        "type": "message",
                        "data": published.model_dump_json(),
                        "channel": "test:channel",
                    },
                ]
            )
        )

        # max_wait bounds the subscription so the test cannot hang.
        received = [
            event async for event in bus.subscribe(["test:channel"], max_wait=0.5)
        ]

        assert len(received) == 1
        assert received[0].type == EventType.AGENT_MESSAGE
        assert received[0].payload["message"] == "Test"

    @pytest.mark.asyncio
    async def test_subscribe_handles_invalid_json(self):
        """A message whose data is not JSON yields nothing."""
        bus = self._bus_with_pubsub(
            self._replaying(
                [{"type": "message", "data": "not valid json", "channel": "test"}]
            )
        )

        received = [
            event async for event in bus.subscribe(["test:channel"], max_wait=0.5)
        ]

        # The malformed message must have been skipped.
        assert len(received) == 0

    @pytest.mark.asyncio
    async def test_subscribe_handles_invalid_event_schema(self):
        """Valid JSON that does not match the Event schema yields nothing."""
        # Parseable JSON, but missing the fields an Event requires.
        incomplete = json.dumps({"type": "test", "data": "incomplete"})
        bus = self._bus_with_pubsub(
            self._replaying(
                [{"type": "message", "data": incomplete, "channel": "test"}]
            )
        )

        received = [
            event async for event in bus.subscribe(["test:channel"], max_wait=0.5)
        ]

        # The schema-invalid message must have been skipped.
        assert len(received) == 0

    @pytest.mark.asyncio
    async def test_subscribe_max_wait(self):
        """subscribe() stops before a slow get_message call would finish."""

        async def slow_get_message(**kwargs):
            # Sleeps far longer than max_wait, so the timeout fires first.
            await asyncio.sleep(1.0)
            return None

        bus = self._bus_with_pubsub(slow_get_message)

        loop = asyncio.get_running_loop()
        started_at = loop.time()
        received = [
            event async for event in bus.subscribe(["test:channel"], max_wait=0.2)
        ]
        elapsed = loop.time() - started_at

        assert len(received) == 0
        # Expected to stop near 0.2s; anything under the 1.0s sleep passes.
        assert elapsed < 1.0
|
|
|
|
|
|
class TestEventBusChannelIsolation:
    """Tests for channel isolation in EventBus."""

    @pytest.mark.asyncio
    async def test_events_only_received_on_subscribed_channel(self):
        """A subscriber to one project's channel receives that project's event.

        The mocked pubsub hands out a single message tagged with the
        project's channel and then raises TimeoutError on every later call.
        """
        event_bus = EventBus()
        event_bus._connected = True

        project1_id = uuid4()

        # Event for project 1
        event1 = EventBus.create_event(
            event_type=EventType.AGENT_MESSAGE,
            project_id=project1_id,
            actor_type="agent",
            payload={"message": "For project 1"},
        )
        serialized_event1 = event1.model_dump_json()

        # Mock pubsub only returns event1 for the project1 channel.
        mock_pubsub = AsyncMock()
        message_queue = [
            {
                "type": "message",
                "data": serialized_event1,
                "channel": f"project:{project1_id}",
            },
        ]
        call_count = 0

        async def get_message_side_effect(**kwargs):
            nonlocal call_count
            if call_count < len(message_queue):
                result = message_queue[call_count]
                call_count += 1
                return result
            raise TimeoutError()

        mock_pubsub.get_message = get_message_side_effect
        mock_pubsub.subscribe = AsyncMock()
        mock_pubsub.unsubscribe = AsyncMock()
        mock_pubsub.close = AsyncMock()

        mock_redis = AsyncMock()
        mock_redis.pubsub = MagicMock(return_value=mock_pubsub)
        event_bus._redis_client = mock_redis

        # Subscribe to the project1 channel; max_wait bounds the iteration.
        received = []
        channel = event_bus.get_project_channel(project1_id)
        async for event in event_bus.subscribe([channel], max_wait=0.5):
            received.append(event)

        assert len(received) == 1
        assert received[0].project_id == project1_id

    def test_channel_names_are_unique(self):
        """Project, agent and user channels differ even for the same id."""
        event_bus = EventBus()
        same_id = uuid4()

        project_channel = event_bus.get_project_channel(same_id)
        agent_channel = event_bus.get_agent_channel(same_id)
        user_channel = event_bus.get_user_channel(same_id)

        # All channels should be different even with the same ID.
        assert project_channel != agent_channel
        assert agent_channel != user_channel
        assert project_channel != user_channel
|
|
|
|
|
|
class TestEventBusSingleton:
    """Tests for EventBus singleton management.

    Each test resets the module-level ``_event_bus`` in a ``finally`` block,
    so a failing assertion cannot leak singleton state into later tests.
    """

    @pytest.mark.asyncio
    async def test_get_event_bus_returns_singleton(self):
        """Test get_event_bus returns the same instance on repeated calls."""
        import app.services.event_bus as event_bus_module

        # Start from a clean slate.
        event_bus_module._event_bus = None

        try:
            bus1 = get_event_bus()
            bus2 = get_event_bus()

            assert bus1 is bus2
        finally:
            # Cleanup runs even when the assertion fails.
            event_bus_module._event_bus = None

    @pytest.mark.asyncio
    async def test_get_connected_event_bus(self):
        """Test get_connected_event_bus returns a connected instance."""
        import app.services.event_bus as event_bus_module

        event_bus_module._event_bus = None

        mock_redis = AsyncMock()
        mock_redis.ping = AsyncMock(return_value=True)
        mock_redis.pubsub = MagicMock(return_value=AsyncMock())

        try:
            with patch("redis.asyncio.from_url", return_value=mock_redis):
                bus = await get_connected_event_bus()
                assert bus.is_connected
        finally:
            # Cleanup runs even when the assertion fails.
            await close_event_bus()

    @pytest.mark.asyncio
    async def test_close_event_bus(self):
        """Test close_event_bus clears the module-level singleton."""
        import app.services.event_bus as event_bus_module

        # Create a connected instance backed entirely by mocks.
        event_bus_module._event_bus = EventBus()
        event_bus_module._event_bus._connected = True
        event_bus_module._event_bus._redis_client = AsyncMock()
        event_bus_module._event_bus._redis_client.aclose = AsyncMock()
        event_bus_module._event_bus._pubsub = AsyncMock()
        event_bus_module._event_bus._pubsub.unsubscribe = AsyncMock()
        event_bus_module._event_bus._pubsub.close = AsyncMock()

        try:
            await close_event_bus()

            assert event_bus_module._event_bus is None
        finally:
            # Do not leave the mock-backed instance behind on failure.
            event_bus_module._event_bus = None
|
|
|
|
|
|
class TestEventBusSubscribeWithCallback:
    """Tests for subscribe_with_callback method."""

    @staticmethod
    def _bus_replaying(messages):
        """Return a bus flagged connected whose pubsub replays ``messages``.

        The replay cursor lives inside this helper, so it cannot collide
        with any counter a test keeps for its own callback. After the last
        message, every further get_message call raises TimeoutError.
        """
        remaining = iter(messages)

        async def get_message(**kwargs):
            for message in remaining:
                return message
            raise TimeoutError()

        mock_pubsub = AsyncMock()
        mock_pubsub.get_message = get_message
        mock_pubsub.subscribe = AsyncMock()
        mock_pubsub.unsubscribe = AsyncMock()
        mock_pubsub.close = AsyncMock()

        mock_redis = AsyncMock()
        mock_redis.pubsub = MagicMock(return_value=mock_pubsub)

        event_bus = EventBus()
        event_bus._connected = True
        event_bus._redis_client = mock_redis
        return event_bus

    @pytest.mark.asyncio
    async def test_subscribe_with_callback_processes_events(self):
        """Test subscribe_with_callback calls callback for each event."""
        test_event = EventBus.create_event(
            event_type=EventType.AGENT_MESSAGE,
            project_id=uuid4(),
            actor_type="agent",
        )
        serialized_event = test_event.model_dump_json()

        event_bus = self._bus_replaying(
            [{"type": "message", "data": serialized_event, "channel": "test"}]
        )

        received_events = []

        async def callback(event: Event):
            received_events.append(event)

        # The stop event is set after a short delay to end the subscription.
        stop_event = asyncio.Event()

        async def stop_after_delay():
            await asyncio.sleep(0.3)
            stop_event.set()

        # Run the subscription and the stopper concurrently.
        task = asyncio.create_task(
            event_bus.subscribe_with_callback(["test"], callback, stop_event)
        )
        stop_task = asyncio.create_task(stop_after_delay())

        await asyncio.gather(task, stop_task, return_exceptions=True)

        assert len(received_events) == 1
        assert received_events[0].type == EventType.AGENT_MESSAGE

    @pytest.mark.asyncio
    async def test_subscribe_with_callback_handles_callback_errors(self):
        """Test subscribe_with_callback continues after callback errors."""
        test_event = EventBus.create_event(
            event_type=EventType.AGENT_MESSAGE,
            project_id=uuid4(),
            actor_type="agent",
        )
        serialized_event = test_event.model_dump_json()

        # Two messages: the callback fails on the first, succeeds on the second.
        event_bus = self._bus_replaying(
            [
                {"type": "message", "data": serialized_event, "channel": "test"},
                {"type": "message", "data": serialized_event, "channel": "test"},
            ]
        )

        # Counts callback invocations only. It must stay separate from the
        # replay cursor: sharing one counter would mean the callback never
        # sees a count of 1 and the second message is never delivered.
        callback_calls = 0

        async def failing_callback(event: Event):
            nonlocal callback_calls
            callback_calls += 1
            if callback_calls == 1:
                raise ValueError("First call fails")
            # Second call succeeds

        stop_event = asyncio.Event()

        async def stop_after_delay():
            await asyncio.sleep(0.3)
            stop_event.set()

        task = asyncio.create_task(
            event_bus.subscribe_with_callback(["test"], failing_callback, stop_event)
        )
        stop_task = asyncio.create_task(stop_after_delay())

        await asyncio.gather(task, stop_task, return_exceptions=True)

        # Should have processed both events despite first callback failing
        assert callback_calls == 2
|
|
|
|
|
|
class TestEventBusSSESubscription:
    """Checks on the SSE-oriented subscription generator."""

    @pytest.mark.asyncio
    async def test_subscribe_sse_yields_json_strings(self):
        """subscribe_sse yields the published event as a JSON string."""
        bus = EventBus()
        bus._connected = True

        project_id = uuid4()
        published = EventBus.create_event(
            event_type=EventType.AGENT_MESSAGE,
            project_id=project_id,
            actor_type="agent",
            payload={"message": "SSE test"},
        )
        queued = iter(
            [
                {
                    "type": "message",
                    "data": published.model_dump_json(),
                    "channel": "test",
                },
            ]
        )

        async def get_message(**kwargs):
            # Hand out the queued message, then raise TimeoutError on every
            # later call.
            for message in queued:
                return message
            raise TimeoutError()

        pubsub_stub = AsyncMock()
        pubsub_stub.get_message = get_message
        pubsub_stub.subscribe = AsyncMock()
        pubsub_stub.unsubscribe = AsyncMock()
        pubsub_stub.close = AsyncMock()

        redis_stub = AsyncMock()
        redis_stub.pubsub = MagicMock(return_value=pubsub_stub)
        bus._redis_client = redis_stub

        non_empty_frames = []
        frames_seen = 0
        async for frame in bus.subscribe_sse(project_id, keepalive_interval=0.1):
            if frame:  # empty strings are treated as keepalives and dropped
                non_empty_frames.append(frame)
            frames_seen += 1
            if frames_seen >= 2:  # stop after two frames of any kind
                break

        assert len(non_empty_frames) == 1
        # The single data frame must parse as JSON.
        decoded = json.loads(non_empty_frames[0])
        assert decoded["type"] == "agent.message"
        assert decoded["payload"]["message"] == "SSE test"
|