""" Tests for the SSE events endpoint. This module tests the Server-Sent Events endpoint for project event streaming, including: - Authentication and authorization - SSE stream connection and format - Keepalive mechanism - Reconnection support (Last-Event-ID) - Connection cleanup """ import json import uuid from collections.abc import AsyncGenerator from datetime import UTC, datetime from unittest.mock import AsyncMock, patch import pytest import pytest_asyncio from fastapi import status from httpx import ASGITransport, AsyncClient from app.api.dependencies.event_bus import get_event_bus from app.core.database import get_db from app.main import app from app.schemas.events import Event, EventType from app.services.event_bus import EventBus class MockEventBus: """Mock EventBus for testing without Redis.""" def __init__(self): self.published_events: list[Event] = [] self._should_yield_events = True self._events_to_yield: list[str] = [] self._connected = True @property def is_connected(self) -> bool: return self._connected async def connect(self) -> None: self._connected = True async def disconnect(self) -> None: self._connected = False def get_project_channel(self, project_id: uuid.UUID | str) -> str: """Get the channel name for a project.""" return f"project:{project_id}" @staticmethod def create_event( event_type: EventType, project_id: uuid.UUID, actor_type: str, payload: dict | None = None, actor_id: uuid.UUID | None = None, event_id: str | None = None, timestamp: datetime | None = None, ) -> Event: """Create a new Event.""" return Event( id=event_id or str(uuid.uuid4()), type=event_type, timestamp=timestamp or datetime.now(UTC), project_id=project_id, actor_id=actor_id, actor_type=actor_type, payload=payload or {}, ) async def publish(self, channel: str, event: Event) -> int: """Publish an event to a channel.""" self.published_events.append(event) return 1 def add_event_to_yield(self, event_json: str) -> None: """Add an event JSON string to be yielded by subscribe_sse.""" self._events_to_yield.append(event_json) async def subscribe_sse( self, project_id: str | uuid.UUID, last_event_id: str | None = None, keepalive_interval: int = 30, ) -> AsyncGenerator[str, None]: """Mock subscribe_sse that yields pre-configured events then keepalive.""" # First yield any pre-configured events for event_data in self._events_to_yield: yield event_data # Then yield keepalive yield "" # Then stop to allow test to complete self._should_yield_events = False @pytest_asyncio.fixture async def mock_event_bus(): """Create a mock event bus for testing.""" return MockEventBus() @pytest_asyncio.fixture async def client_with_mock_bus(async_test_db, mock_event_bus): """ Create a FastAPI test client with mocked database and event bus. """ _test_engine, AsyncTestingSessionLocal = async_test_db async def override_get_db(): async with AsyncTestingSessionLocal() as session: try: yield session finally: pass async def override_get_event_bus(): return mock_event_bus app.dependency_overrides[get_db] = override_get_db app.dependency_overrides[get_event_bus] = override_get_event_bus transport = ASGITransport(app=app) async with AsyncClient(transport=transport, base_url="http://test") as test_client: yield test_client app.dependency_overrides.clear() @pytest_asyncio.fixture async def user_token_with_mock_bus(client_with_mock_bus, async_test_user): """Create an access token for the test user.""" response = await client_with_mock_bus.post( "/api/v1/auth/login", json={ "email": async_test_user.email, "password": "TestPassword123!", }, ) assert response.status_code == 200, f"Login failed: {response.text}" tokens = response.json() return tokens["access_token"] class TestSSEEndpointAuthentication: """Tests for SSE endpoint authentication.""" @pytest.mark.asyncio async def test_stream_events_requires_authentication(self, client_with_mock_bus): """Test that SSE endpoint requires authentication.""" project_id = uuid.uuid4() response = await client_with_mock_bus.get( f"/api/v1/projects/{project_id}/events/stream", ) assert response.status_code == status.HTTP_401_UNAUTHORIZED @pytest.mark.asyncio async def test_stream_events_with_invalid_token(self, client_with_mock_bus): """Test that SSE endpoint rejects invalid tokens.""" project_id = uuid.uuid4() response = await client_with_mock_bus.get( f"/api/v1/projects/{project_id}/events/stream", headers={"Authorization": "Bearer invalid_token"}, ) assert response.status_code == status.HTTP_401_UNAUTHORIZED class TestSSEEndpointStream: """Tests for SSE stream functionality.""" @pytest.mark.asyncio async def test_stream_events_returns_sse_response( self, client_with_mock_bus, user_token_with_mock_bus ): """Test that SSE endpoint returns proper SSE response.""" project_id = uuid.uuid4() # Make request with a timeout to avoid hanging response = await client_with_mock_bus.get( f"/api/v1/projects/{project_id}/events/stream", headers={"Authorization": f"Bearer {user_token_with_mock_bus}"}, timeout=5.0, ) # The response should start streaming assert response.status_code == status.HTTP_200_OK assert "text/event-stream" in response.headers.get("content-type", "") @pytest.mark.asyncio async def test_stream_events_with_events( self, client_with_mock_bus, user_token_with_mock_bus, mock_event_bus ): """Test that SSE endpoint yields events.""" project_id = uuid.uuid4() # Create a test event and add it to the mock bus test_event = Event( id=str(uuid.uuid4()), type=EventType.AGENT_MESSAGE, timestamp=datetime.now(UTC), project_id=project_id, actor_type="agent", payload={"message": "test"}, ) mock_event_bus.add_event_to_yield(test_event.model_dump_json()) # Request the stream response = await client_with_mock_bus.get( f"/api/v1/projects/{project_id}/events/stream", headers={"Authorization": f"Bearer {user_token_with_mock_bus}"}, timeout=5.0, ) assert response.status_code == status.HTTP_200_OK # Check response contains event data content = response.text assert "agent.message" in content or "data:" in content @pytest.mark.asyncio async def test_stream_events_with_last_event_id( self, client_with_mock_bus, user_token_with_mock_bus ): """Test that Last-Event-ID header is accepted.""" project_id = uuid.uuid4() last_event_id = str(uuid.uuid4()) response = await client_with_mock_bus.get( f"/api/v1/projects/{project_id}/events/stream", headers={ "Authorization": f"Bearer {user_token_with_mock_bus}", "Last-Event-ID": last_event_id, }, timeout=5.0, ) # Should accept the header and return OK assert response.status_code == status.HTTP_200_OK class TestSSEEndpointHeaders: """Tests for SSE response headers.""" @pytest.mark.asyncio async def test_stream_events_cache_control_header( self, client_with_mock_bus, user_token_with_mock_bus ): """Test that SSE response has no-cache header.""" project_id = uuid.uuid4() response = await client_with_mock_bus.get( f"/api/v1/projects/{project_id}/events/stream", headers={"Authorization": f"Bearer {user_token_with_mock_bus}"}, timeout=5.0, ) assert response.status_code == status.HTTP_200_OK cache_control = response.headers.get("cache-control", "") assert "no-cache" in cache_control.lower() class TestTestEventEndpoint: """Tests for the test event endpoint.""" @pytest.mark.asyncio async def test_send_test_event_requires_auth(self, client_with_mock_bus): """Test that test event endpoint requires authentication.""" project_id = uuid.uuid4() response = await client_with_mock_bus.post( f"/api/v1/projects/{project_id}/events/test", ) assert response.status_code == status.HTTP_401_UNAUTHORIZED @pytest.mark.asyncio async def test_send_test_event_success( self, client_with_mock_bus, user_token_with_mock_bus, mock_event_bus ): """Test sending a test event.""" project_id = uuid.uuid4() response = await client_with_mock_bus.post( f"/api/v1/projects/{project_id}/events/test", headers={"Authorization": f"Bearer {user_token_with_mock_bus}"}, ) assert response.status_code == status.HTTP_200_OK data = response.json() assert data["success"] is True assert "event_id" in data assert data["event_type"] == "agent.message" # Verify event was published assert len(mock_event_bus.published_events) == 1 published = mock_event_bus.published_events[0] assert published.type == EventType.AGENT_MESSAGE assert published.project_id == project_id class TestEventSchema: """Tests for the Event schema.""" def test_event_creation(self): """Test Event creation with required fields.""" project_id = uuid.uuid4() event = Event( id=str(uuid.uuid4()), type=EventType.AGENT_MESSAGE, timestamp=datetime.now(UTC), project_id=project_id, actor_type="agent", payload={"message": "test"}, ) assert event.id is not None assert event.type == EventType.AGENT_MESSAGE assert event.project_id == project_id assert event.actor_type == "agent" assert event.payload == {"message": "test"} def test_event_json_serialization(self): """Test Event JSON serialization.""" project_id = uuid.uuid4() event = Event( id="test-id", type=EventType.AGENT_STATUS_CHANGED, timestamp=datetime.now(UTC), project_id=project_id, actor_type="system", payload={"status": "running"}, ) json_str = event.model_dump_json() parsed = json.loads(json_str) assert parsed["id"] == "test-id" assert parsed["type"] == "agent.status_changed" assert str(parsed["project_id"]) == str(project_id) assert parsed["payload"]["status"] == "running" class TestEventBusUnit: """Unit tests for EventBus class.""" @pytest.mark.asyncio async def test_event_bus_not_connected_raises(self): """Test that accessing redis_client before connect raises.""" from app.services.event_bus import EventBusConnectionError bus = EventBus() with pytest.raises(EventBusConnectionError, match="not connected"): _ = bus.redis_client @pytest.mark.asyncio async def test_event_bus_channel_names(self): """Test channel name generation.""" bus = EventBus() project_id = uuid.uuid4() agent_id = uuid.uuid4() user_id = uuid.uuid4() assert bus.get_project_channel(project_id) == f"project:{project_id}" assert bus.get_agent_channel(agent_id) == f"agent:{agent_id}" assert bus.get_user_channel(user_id) == f"user:{user_id}" def test_event_bus_create_event(self): """Test EventBus.create_event factory method.""" project_id = uuid.uuid4() actor_id = uuid.uuid4() event = EventBus.create_event( event_type=EventType.ISSUE_CREATED, project_id=project_id, actor_type="user", actor_id=actor_id, payload={"title": "Test Issue"}, ) assert event.type == EventType.ISSUE_CREATED assert event.project_id == project_id assert event.actor_id == actor_id assert event.actor_type == "user" assert event.payload == {"title": "Test Issue"} class TestEventBusIntegration: """Integration tests for EventBus with mocked Redis.""" @pytest.mark.asyncio async def test_event_bus_connect_disconnect(self): """Test EventBus connect and disconnect.""" with patch("app.services.event_bus.redis.from_url") as mock_redis: mock_client = AsyncMock() mock_redis.return_value = mock_client mock_client.ping = AsyncMock() mock_client.pubsub = lambda: AsyncMock() bus = EventBus(redis_url="redis://localhost:6379/0") # Connect await bus.connect() mock_client.ping.assert_called_once() assert bus._redis_client is not None assert bus.is_connected # Disconnect await bus.disconnect() mock_client.aclose.assert_called_once() assert bus._redis_client is None assert not bus.is_connected @pytest.mark.asyncio async def test_event_bus_publish(self): """Test EventBus event publishing.""" with patch("app.services.event_bus.redis.from_url") as mock_redis: mock_client = AsyncMock() mock_redis.return_value = mock_client mock_client.ping = AsyncMock() mock_client.publish = AsyncMock(return_value=1) mock_client.pubsub = lambda: AsyncMock() bus = EventBus() await bus.connect() project_id = uuid.uuid4() event = EventBus.create_event( event_type=EventType.AGENT_SPAWNED, project_id=project_id, actor_type="system", payload={"agent_name": "test-agent"}, ) channel = bus.get_project_channel(project_id) result = await bus.publish(channel, event) # Verify publish was called mock_client.publish.assert_called_once() call_args = mock_client.publish.call_args # Check channel name assert call_args[0][0] == f"project:{project_id}" # Check result assert result == 1 await bus.disconnect() @pytest.mark.asyncio async def test_event_bus_connect_failure(self): """Test EventBus handles connection failure.""" from app.services.event_bus import EventBusConnectionError with patch("app.services.event_bus.redis.from_url") as mock_redis: mock_client = AsyncMock() mock_redis.return_value = mock_client import redis.asyncio as redis_async mock_client.ping = AsyncMock( side_effect=redis_async.ConnectionError("Connection refused") ) bus = EventBus() with pytest.raises(EventBusConnectionError, match="Failed to connect"): await bus.connect() @pytest.mark.asyncio async def test_event_bus_already_connected(self): """Test EventBus connect when already connected is a no-op.""" with patch("app.services.event_bus.redis.from_url") as mock_redis: mock_client = AsyncMock() mock_redis.return_value = mock_client mock_client.ping = AsyncMock() mock_client.pubsub = lambda: AsyncMock() bus = EventBus() # First connect await bus.connect() assert mock_client.ping.call_count == 1 # Second connect should be a no-op await bus.connect() assert mock_client.ping.call_count == 1 await bus.disconnect()