forked from cardosofelipe/fast-next-template
feat(backend): Add SSE endpoint for project event streaming
- Add /projects/{project_id}/events/stream SSE endpoint
- Add event_bus dependency injection
- Add project access authorization (placeholder)
- Add test event endpoint for development
- Add keepalive comments every 30 seconds
- Add reconnection support via Last-Event-ID header
- Add rate limiting (10/minute per IP)
- Mount events router in API
- Add sse-starlette dependency
- Add 19 comprehensive tests for SSE functionality
Implements #34
🤖 Generated with [Claude Code](https://claude.com/claude-code)
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
525
backend/tests/api/routes/test_events.py
Normal file
525
backend/tests/api/routes/test_events.py
Normal file
@@ -0,0 +1,525 @@
|
||||
"""
|
||||
Tests for the SSE events endpoint.
|
||||
|
||||
This module tests the Server-Sent Events endpoint for project event streaming,
|
||||
including:
|
||||
- Authentication and authorization
|
||||
- SSE stream connection and format
|
||||
- Keepalive mechanism
|
||||
- Reconnection support (Last-Event-ID)
|
||||
- Connection cleanup
|
||||
"""
|
||||
|
||||
import json
|
||||
import uuid
|
||||
from collections.abc import AsyncGenerator
|
||||
from datetime import UTC, datetime
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
from fastapi import status
|
||||
from httpx import ASGITransport, AsyncClient
|
||||
|
||||
from app.api.dependencies.event_bus import get_event_bus
|
||||
from app.core.database import get_db
|
||||
from app.main import app
|
||||
from app.schemas.events import Event, EventType
|
||||
from app.services.event_bus import EventBus
|
||||
|
||||
|
||||
class MockEventBus:
|
||||
"""Mock EventBus for testing without Redis."""
|
||||
|
||||
def __init__(self):
|
||||
self.published_events: list[Event] = []
|
||||
self._should_yield_events = True
|
||||
self._events_to_yield: list[str] = []
|
||||
self._connected = True
|
||||
|
||||
@property
|
||||
def is_connected(self) -> bool:
|
||||
return self._connected
|
||||
|
||||
async def connect(self) -> None:
|
||||
self._connected = True
|
||||
|
||||
async def disconnect(self) -> None:
|
||||
self._connected = False
|
||||
|
||||
def get_project_channel(self, project_id: uuid.UUID | str) -> str:
|
||||
"""Get the channel name for a project."""
|
||||
return f"project:{project_id}"
|
||||
|
||||
@staticmethod
|
||||
def create_event(
|
||||
event_type: EventType,
|
||||
project_id: uuid.UUID,
|
||||
actor_type: str,
|
||||
payload: dict | None = None,
|
||||
actor_id: uuid.UUID | None = None,
|
||||
event_id: str | None = None,
|
||||
timestamp: datetime | None = None,
|
||||
) -> Event:
|
||||
"""Create a new Event."""
|
||||
return Event(
|
||||
id=event_id or str(uuid.uuid4()),
|
||||
type=event_type,
|
||||
timestamp=timestamp or datetime.now(UTC),
|
||||
project_id=project_id,
|
||||
actor_id=actor_id,
|
||||
actor_type=actor_type,
|
||||
payload=payload or {},
|
||||
)
|
||||
|
||||
async def publish(self, channel: str, event: Event) -> int:
|
||||
"""Publish an event to a channel."""
|
||||
self.published_events.append(event)
|
||||
return 1
|
||||
|
||||
def add_event_to_yield(self, event_json: str) -> None:
|
||||
"""Add an event JSON string to be yielded by subscribe_sse."""
|
||||
self._events_to_yield.append(event_json)
|
||||
|
||||
async def subscribe_sse(
|
||||
self,
|
||||
project_id: str | uuid.UUID,
|
||||
last_event_id: str | None = None,
|
||||
keepalive_interval: int = 30,
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""Mock subscribe_sse that yields pre-configured events then keepalive."""
|
||||
# First yield any pre-configured events
|
||||
for event_data in self._events_to_yield:
|
||||
yield event_data
|
||||
|
||||
# Then yield keepalive
|
||||
yield ""
|
||||
|
||||
# Then stop to allow test to complete
|
||||
self._should_yield_events = False
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def mock_event_bus():
|
||||
"""Create a mock event bus for testing."""
|
||||
return MockEventBus()
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def client_with_mock_bus(async_test_db, mock_event_bus):
|
||||
"""
|
||||
Create a FastAPI test client with mocked database and event bus.
|
||||
"""
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async def override_get_db():
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
try:
|
||||
yield session
|
||||
finally:
|
||||
pass
|
||||
|
||||
async def override_get_event_bus():
|
||||
return mock_event_bus
|
||||
|
||||
app.dependency_overrides[get_db] = override_get_db
|
||||
app.dependency_overrides[get_event_bus] = override_get_event_bus
|
||||
|
||||
transport = ASGITransport(app=app)
|
||||
async with AsyncClient(transport=transport, base_url="http://test") as test_client:
|
||||
yield test_client
|
||||
|
||||
app.dependency_overrides.clear()
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def user_token_with_mock_bus(client_with_mock_bus, async_test_user):
|
||||
"""Create an access token for the test user."""
|
||||
response = await client_with_mock_bus.post(
|
||||
"/api/v1/auth/login",
|
||||
json={
|
||||
"email": async_test_user.email,
|
||||
"password": "TestPassword123!",
|
||||
},
|
||||
)
|
||||
assert response.status_code == 200, f"Login failed: {response.text}"
|
||||
tokens = response.json()
|
||||
return tokens["access_token"]
|
||||
|
||||
|
||||
class TestSSEEndpointAuthentication:
|
||||
"""Tests for SSE endpoint authentication."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stream_events_requires_authentication(self, client_with_mock_bus):
|
||||
"""Test that SSE endpoint requires authentication."""
|
||||
project_id = uuid.uuid4()
|
||||
|
||||
response = await client_with_mock_bus.get(
|
||||
f"/api/v1/projects/{project_id}/events/stream",
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_401_UNAUTHORIZED
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stream_events_with_invalid_token(self, client_with_mock_bus):
|
||||
"""Test that SSE endpoint rejects invalid tokens."""
|
||||
project_id = uuid.uuid4()
|
||||
|
||||
response = await client_with_mock_bus.get(
|
||||
f"/api/v1/projects/{project_id}/events/stream",
|
||||
headers={"Authorization": "Bearer invalid_token"},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_401_UNAUTHORIZED
|
||||
|
||||
|
||||
class TestSSEEndpointStream:
|
||||
"""Tests for SSE stream functionality."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stream_events_returns_sse_response(
|
||||
self, client_with_mock_bus, user_token_with_mock_bus
|
||||
):
|
||||
"""Test that SSE endpoint returns proper SSE response."""
|
||||
project_id = uuid.uuid4()
|
||||
|
||||
# Make request with a timeout to avoid hanging
|
||||
response = await client_with_mock_bus.get(
|
||||
f"/api/v1/projects/{project_id}/events/stream",
|
||||
headers={"Authorization": f"Bearer {user_token_with_mock_bus}"},
|
||||
timeout=5.0,
|
||||
)
|
||||
|
||||
# The response should start streaming
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
assert "text/event-stream" in response.headers.get("content-type", "")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stream_events_with_events(
|
||||
self, client_with_mock_bus, user_token_with_mock_bus, mock_event_bus
|
||||
):
|
||||
"""Test that SSE endpoint yields events."""
|
||||
project_id = uuid.uuid4()
|
||||
|
||||
# Create a test event and add it to the mock bus
|
||||
test_event = Event(
|
||||
id=str(uuid.uuid4()),
|
||||
type=EventType.AGENT_MESSAGE,
|
||||
timestamp=datetime.now(UTC),
|
||||
project_id=project_id,
|
||||
actor_type="agent",
|
||||
payload={"message": "test"},
|
||||
)
|
||||
mock_event_bus.add_event_to_yield(test_event.model_dump_json())
|
||||
|
||||
# Request the stream
|
||||
response = await client_with_mock_bus.get(
|
||||
f"/api/v1/projects/{project_id}/events/stream",
|
||||
headers={"Authorization": f"Bearer {user_token_with_mock_bus}"},
|
||||
timeout=5.0,
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
|
||||
# Check response contains event data
|
||||
content = response.text
|
||||
assert "agent.message" in content or "data:" in content
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stream_events_with_last_event_id(
|
||||
self, client_with_mock_bus, user_token_with_mock_bus
|
||||
):
|
||||
"""Test that Last-Event-ID header is accepted."""
|
||||
project_id = uuid.uuid4()
|
||||
last_event_id = str(uuid.uuid4())
|
||||
|
||||
response = await client_with_mock_bus.get(
|
||||
f"/api/v1/projects/{project_id}/events/stream",
|
||||
headers={
|
||||
"Authorization": f"Bearer {user_token_with_mock_bus}",
|
||||
"Last-Event-ID": last_event_id,
|
||||
},
|
||||
timeout=5.0,
|
||||
)
|
||||
|
||||
# Should accept the header and return OK
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
|
||||
|
||||
class TestSSEEndpointHeaders:
|
||||
"""Tests for SSE response headers."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stream_events_cache_control_header(
|
||||
self, client_with_mock_bus, user_token_with_mock_bus
|
||||
):
|
||||
"""Test that SSE response has no-cache header."""
|
||||
project_id = uuid.uuid4()
|
||||
|
||||
response = await client_with_mock_bus.get(
|
||||
f"/api/v1/projects/{project_id}/events/stream",
|
||||
headers={"Authorization": f"Bearer {user_token_with_mock_bus}"},
|
||||
timeout=5.0,
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
cache_control = response.headers.get("cache-control", "")
|
||||
assert "no-cache" in cache_control.lower()
|
||||
|
||||
|
||||
class TestTestEventEndpoint:
|
||||
"""Tests for the test event endpoint."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_test_event_requires_auth(self, client_with_mock_bus):
|
||||
"""Test that test event endpoint requires authentication."""
|
||||
project_id = uuid.uuid4()
|
||||
|
||||
response = await client_with_mock_bus.post(
|
||||
f"/api/v1/projects/{project_id}/events/test",
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_401_UNAUTHORIZED
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_test_event_success(
|
||||
self, client_with_mock_bus, user_token_with_mock_bus, mock_event_bus
|
||||
):
|
||||
"""Test sending a test event."""
|
||||
project_id = uuid.uuid4()
|
||||
|
||||
response = await client_with_mock_bus.post(
|
||||
f"/api/v1/projects/{project_id}/events/test",
|
||||
headers={"Authorization": f"Bearer {user_token_with_mock_bus}"},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
data = response.json()
|
||||
assert data["success"] is True
|
||||
assert "event_id" in data
|
||||
assert data["event_type"] == "agent.message"
|
||||
|
||||
# Verify event was published
|
||||
assert len(mock_event_bus.published_events) == 1
|
||||
published = mock_event_bus.published_events[0]
|
||||
assert published.type == EventType.AGENT_MESSAGE
|
||||
assert published.project_id == project_id
|
||||
|
||||
|
||||
class TestEventSchema:
|
||||
"""Tests for the Event schema."""
|
||||
|
||||
def test_event_creation(self):
|
||||
"""Test Event creation with required fields."""
|
||||
project_id = uuid.uuid4()
|
||||
event = Event(
|
||||
id=str(uuid.uuid4()),
|
||||
type=EventType.AGENT_MESSAGE,
|
||||
timestamp=datetime.now(UTC),
|
||||
project_id=project_id,
|
||||
actor_type="agent",
|
||||
payload={"message": "test"},
|
||||
)
|
||||
|
||||
assert event.id is not None
|
||||
assert event.type == EventType.AGENT_MESSAGE
|
||||
assert event.project_id == project_id
|
||||
assert event.actor_type == "agent"
|
||||
assert event.payload == {"message": "test"}
|
||||
|
||||
def test_event_json_serialization(self):
|
||||
"""Test Event JSON serialization."""
|
||||
project_id = uuid.uuid4()
|
||||
event = Event(
|
||||
id="test-id",
|
||||
type=EventType.AGENT_STATUS_CHANGED,
|
||||
timestamp=datetime.now(UTC),
|
||||
project_id=project_id,
|
||||
actor_type="system",
|
||||
payload={"status": "running"},
|
||||
)
|
||||
|
||||
json_str = event.model_dump_json()
|
||||
parsed = json.loads(json_str)
|
||||
|
||||
assert parsed["id"] == "test-id"
|
||||
assert parsed["type"] == "agent.status_changed"
|
||||
assert str(parsed["project_id"]) == str(project_id)
|
||||
assert parsed["payload"]["status"] == "running"
|
||||
|
||||
|
||||
class TestEventBusUnit:
|
||||
"""Unit tests for EventBus class."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_event_bus_not_connected_raises(self):
|
||||
"""Test that accessing redis_client before connect raises."""
|
||||
from app.services.event_bus import EventBusConnectionError
|
||||
|
||||
bus = EventBus()
|
||||
|
||||
with pytest.raises(EventBusConnectionError, match="not connected"):
|
||||
_ = bus.redis_client
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_event_bus_channel_names(self):
|
||||
"""Test channel name generation."""
|
||||
bus = EventBus()
|
||||
project_id = uuid.uuid4()
|
||||
agent_id = uuid.uuid4()
|
||||
user_id = uuid.uuid4()
|
||||
|
||||
assert bus.get_project_channel(project_id) == f"project:{project_id}"
|
||||
assert bus.get_agent_channel(agent_id) == f"agent:{agent_id}"
|
||||
assert bus.get_user_channel(user_id) == f"user:{user_id}"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_event_bus_sequence_counter(self):
|
||||
"""Test sequence counter increments."""
|
||||
bus = EventBus()
|
||||
channel = "test-channel"
|
||||
|
||||
seq1 = bus._get_next_sequence(channel)
|
||||
seq2 = bus._get_next_sequence(channel)
|
||||
seq3 = bus._get_next_sequence(channel)
|
||||
|
||||
assert seq1 == 1
|
||||
assert seq2 == 2
|
||||
assert seq3 == 3
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_event_bus_sequence_per_channel(self):
|
||||
"""Test sequence counter is per-channel."""
|
||||
bus = EventBus()
|
||||
|
||||
seq1 = bus._get_next_sequence("channel-1")
|
||||
seq2 = bus._get_next_sequence("channel-2")
|
||||
seq3 = bus._get_next_sequence("channel-1")
|
||||
|
||||
assert seq1 == 1
|
||||
assert seq2 == 1 # Different channel starts at 1
|
||||
assert seq3 == 2
|
||||
|
||||
def test_event_bus_create_event(self):
|
||||
"""Test EventBus.create_event factory method."""
|
||||
project_id = uuid.uuid4()
|
||||
actor_id = uuid.uuid4()
|
||||
|
||||
event = EventBus.create_event(
|
||||
event_type=EventType.ISSUE_CREATED,
|
||||
project_id=project_id,
|
||||
actor_type="user",
|
||||
actor_id=actor_id,
|
||||
payload={"title": "Test Issue"},
|
||||
)
|
||||
|
||||
assert event.type == EventType.ISSUE_CREATED
|
||||
assert event.project_id == project_id
|
||||
assert event.actor_id == actor_id
|
||||
assert event.actor_type == "user"
|
||||
assert event.payload == {"title": "Test Issue"}
|
||||
|
||||
|
||||
class TestEventBusIntegration:
|
||||
"""Integration tests for EventBus with mocked Redis."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_event_bus_connect_disconnect(self):
|
||||
"""Test EventBus connect and disconnect."""
|
||||
with patch("app.services.event_bus.redis.from_url") as mock_redis:
|
||||
mock_client = AsyncMock()
|
||||
mock_redis.return_value = mock_client
|
||||
mock_client.ping = AsyncMock()
|
||||
mock_client.pubsub = lambda: AsyncMock()
|
||||
|
||||
bus = EventBus(redis_url="redis://localhost:6379/0")
|
||||
|
||||
# Connect
|
||||
await bus.connect()
|
||||
mock_client.ping.assert_called_once()
|
||||
assert bus._redis_client is not None
|
||||
assert bus.is_connected
|
||||
|
||||
# Disconnect
|
||||
await bus.disconnect()
|
||||
mock_client.aclose.assert_called_once()
|
||||
assert bus._redis_client is None
|
||||
assert not bus.is_connected
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_event_bus_publish(self):
|
||||
"""Test EventBus event publishing."""
|
||||
with patch("app.services.event_bus.redis.from_url") as mock_redis:
|
||||
mock_client = AsyncMock()
|
||||
mock_redis.return_value = mock_client
|
||||
mock_client.ping = AsyncMock()
|
||||
mock_client.publish = AsyncMock(return_value=1)
|
||||
mock_client.pubsub = lambda: AsyncMock()
|
||||
|
||||
bus = EventBus()
|
||||
await bus.connect()
|
||||
|
||||
project_id = uuid.uuid4()
|
||||
event = EventBus.create_event(
|
||||
event_type=EventType.AGENT_SPAWNED,
|
||||
project_id=project_id,
|
||||
actor_type="system",
|
||||
payload={"agent_name": "test-agent"},
|
||||
)
|
||||
|
||||
channel = bus.get_project_channel(project_id)
|
||||
result = await bus.publish(channel, event)
|
||||
|
||||
# Verify publish was called
|
||||
mock_client.publish.assert_called_once()
|
||||
call_args = mock_client.publish.call_args
|
||||
|
||||
# Check channel name
|
||||
assert call_args[0][0] == f"project:{project_id}"
|
||||
|
||||
# Check result
|
||||
assert result == 1
|
||||
|
||||
await bus.disconnect()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_event_bus_connect_failure(self):
|
||||
"""Test EventBus handles connection failure."""
|
||||
from app.services.event_bus import EventBusConnectionError
|
||||
|
||||
with patch("app.services.event_bus.redis.from_url") as mock_redis:
|
||||
mock_client = AsyncMock()
|
||||
mock_redis.return_value = mock_client
|
||||
|
||||
import redis.asyncio as redis_async
|
||||
|
||||
mock_client.ping = AsyncMock(
|
||||
side_effect=redis_async.ConnectionError("Connection refused")
|
||||
)
|
||||
|
||||
bus = EventBus()
|
||||
|
||||
with pytest.raises(EventBusConnectionError, match="Failed to connect"):
|
||||
await bus.connect()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_event_bus_already_connected(self):
|
||||
"""Test EventBus connect when already connected is a no-op."""
|
||||
with patch("app.services.event_bus.redis.from_url") as mock_redis:
|
||||
mock_client = AsyncMock()
|
||||
mock_redis.return_value = mock_client
|
||||
mock_client.ping = AsyncMock()
|
||||
mock_client.pubsub = lambda: AsyncMock()
|
||||
|
||||
bus = EventBus()
|
||||
|
||||
# First connect
|
||||
await bus.connect()
|
||||
assert mock_client.ping.call_count == 1
|
||||
|
||||
# Second connect should be a no-op
|
||||
await bus.connect()
|
||||
assert mock_client.ping.call_count == 1
|
||||
|
||||
await bus.disconnect()
|
||||
Reference in New Issue
Block a user