# tests/crud/syndarix/test_agent_instance.py
"""Tests for AgentInstance CRUD operations."""

import uuid
from decimal import Decimal
from unittest.mock import patch

import pytest
import pytest_asyncio
from sqlalchemy.exc import IntegrityError, OperationalError

from app.crud.syndarix.agent_instance import agent_instance
from app.models.syndarix import AgentInstance, AgentType, Project
from app.models.syndarix.enums import (
    AgentStatus,
    ProjectStatus,
)
from app.schemas.syndarix import AgentInstanceCreate


def _connection_lost():
    """Return a fresh OperationalError simulating a dropped DB connection."""
    return OperationalError("Connection lost", {}, Exception())


def _fail_on(session, method, exc=None):
    """Patch ``session.<method>`` so that calling it raises ``exc``.

    ``unittest.mock.patch.object`` substitutes an ``AsyncMock`` when the
    patched attribute is an async function, so the exception surfaces at
    the point where the CRUD code awaits the call.

    Args:
        session: The session under test.
        method: Name of the session method to break (e.g. ``"commit"``).
        exc: Exception instance to raise; defaults to a new
            "Connection lost" ``OperationalError``.

    Returns:
        The patcher, not yet entered, for use in a ``with`` statement.
    """
    if exc is None:
        exc = _connection_lost()
    return patch.object(session, method, side_effect=exc)


@pytest_asyncio.fixture
async def db_session(async_test_db):
    """Create a database session for tests."""
    _, AsyncTestingSessionLocal = async_test_db
    async with AsyncTestingSessionLocal() as session:
        yield session


@pytest_asyncio.fixture
async def test_project(db_session):
    """Create and persist an ACTIVE test project with a unique slug."""
    project = Project(
        id=uuid.uuid4(),
        name="Test Project",
        slug=f"test-project-{uuid.uuid4().hex[:8]}",
        status=ProjectStatus.ACTIVE,
    )
    db_session.add(project)
    await db_session.commit()
    await db_session.refresh(project)
    return project


@pytest_asyncio.fixture
async def test_agent_type(db_session):
    """Create and persist a test agent type with a unique slug."""
    agent_type = AgentType(
        id=uuid.uuid4(),
        name="Test Agent Type",
        slug=f"test-agent-type-{uuid.uuid4().hex[:8]}",
        primary_model="claude-3-opus",
        personality_prompt="You are a helpful test agent.",
    )
    db_session.add(agent_type)
    await db_session.commit()
    await db_session.refresh(agent_type)
    return agent_type


@pytest_asyncio.fixture
async def test_agent_instance(db_session, test_project, test_agent_type):
    """Create and persist one IDLE agent instance in ``test_project``."""
    instance = AgentInstance(
        id=uuid.uuid4(),
        agent_type_id=test_agent_type.id,
        project_id=test_project.id,
        name="Test Agent",
        status=AgentStatus.IDLE,
    )
    db_session.add(instance)
    await db_session.commit()
    await db_session.refresh(instance)
    return instance


class TestAgentInstanceCreate:
    """Tests for agent instance creation."""

    @pytest.mark.asyncio
    async def test_create_instance_success(
        self, db_session, test_project, test_agent_type
    ):
        """Test successful agent instance creation."""
        # Capture ids before create() commits, so the comparisons below do
        # not depend on the session's expire-on-commit behaviour.
        agent_type_id = test_agent_type.id
        project_id = test_project.id
        instance_data = AgentInstanceCreate(
            agent_type_id=agent_type_id,
            project_id=project_id,
            name="New Agent",
        )

        created = await agent_instance.create(db_session, obj_in=instance_data)

        assert created.id is not None
        assert created.name == "New Agent"
        assert created.agent_type_id == agent_type_id
        assert created.project_id == project_id
        # No status was supplied, so the default must apply.
        assert created.status == AgentStatus.IDLE

    @pytest.mark.asyncio
    async def test_create_instance_with_all_fields(
        self, db_session, test_project, test_agent_type
    ):
        """Test agent instance creation with all optional fields."""
        instance_data = AgentInstanceCreate(
            agent_type_id=test_agent_type.id,
            project_id=test_project.id,
            name="Full Agent",
            status=AgentStatus.WORKING,
            current_task="Processing request",
            short_term_memory={"context": "test context", "history": []},
            long_term_memory_ref="ref-123",
            session_id="session-456",
        )

        created = await agent_instance.create(db_session, obj_in=instance_data)

        # Every optional field passed in must be persisted, not just some.
        assert created.name == "Full Agent"
        assert created.status == AgentStatus.WORKING
        assert created.current_task == "Processing request"
        assert created.short_term_memory == {
            "context": "test context",
            "history": [],
        }
        assert created.long_term_memory_ref == "ref-123"
        assert created.session_id == "session-456"

    @pytest.mark.asyncio
    async def test_create_instance_integrity_error(
        self, db_session, test_project, test_agent_type
    ):
        """Test agent instance creation with integrity error."""
        instance_data = AgentInstanceCreate(
            agent_type_id=test_agent_type.id,
            project_id=test_project.id,
            name="Test Agent",
        )
        # The CRUD layer is expected to translate IntegrityError to ValueError.
        with _fail_on(
            db_session, "commit", IntegrityError("", {}, Exception())
        ), pytest.raises(ValueError, match="Database integrity error"):
            await agent_instance.create(db_session, obj_in=instance_data)

    @pytest.mark.asyncio
    async def test_create_instance_unexpected_error(
        self, db_session, test_project, test_agent_type
    ):
        """Test agent instance creation with unexpected error."""
        instance_data = AgentInstanceCreate(
            agent_type_id=test_agent_type.id,
            project_id=test_project.id,
            name="Test Agent",
        )
        # Non-integrity errors must propagate unchanged.
        with _fail_on(
            db_session, "commit", RuntimeError("Unexpected error")
        ), pytest.raises(RuntimeError, match="Unexpected error"):
            await agent_instance.create(db_session, obj_in=instance_data)


class TestAgentInstanceGetWithDetails:
    """Tests for getting agent instance with details."""

    @pytest.mark.asyncio
    async def test_get_with_details_not_found(self, db_session):
        """Test getting non-existent agent instance with details."""
        result = await agent_instance.get_with_details(
            db_session, instance_id=uuid.uuid4()
        )
        assert result is None

    @pytest.mark.asyncio
    async def test_get_with_details_success(self, db_session, test_agent_instance):
        """Test getting agent instance with details."""
        result = await agent_instance.get_with_details(
            db_session, instance_id=test_agent_instance.id
        )
        assert result is not None
        assert result["instance"].id == test_agent_instance.id
        # The joined agent-type name must match the fixture's agent type.
        assert result["agent_type_name"] == "Test Agent Type"
        assert "assigned_issues_count" in result

    @pytest.mark.asyncio
    async def test_get_with_details_db_error(self, db_session, test_agent_instance):
        """Test getting agent instance with details when DB error occurs."""
        with _fail_on(db_session, "execute"), pytest.raises(OperationalError):
            await agent_instance.get_with_details(
                db_session, instance_id=test_agent_instance.id
            )


class TestAgentInstanceGetByProject:
    """Tests for getting agent instances by project."""

    @pytest.mark.asyncio
    async def test_get_by_project_success(
        self, db_session, test_project, test_agent_instance
    ):
        """Test getting agent instances by project."""
        instances, total = await agent_instance.get_by_project(
            db_session, project_id=test_project.id
        )
        assert len(instances) == 1
        assert total == 1

    @pytest.mark.asyncio
    async def test_get_by_project_with_status_filter(
        self, db_session, test_project, test_agent_instance
    ):
        """Test getting agent instances with status filter."""
        instances, total = await agent_instance.get_by_project(
            db_session,
            project_id=test_project.id,
            status=AgentStatus.IDLE,
        )
        assert len(instances) == 1
        # The project holds exactly one instance, so total is 1 regardless
        # of whether it counts filtered or all rows.
        assert total == 1
        assert instances[0].status == AgentStatus.IDLE

    @pytest.mark.asyncio
    async def test_get_by_project_db_error(self, db_session, test_project):
        """Test getting agent instances when DB error occurs."""
        with _fail_on(db_session, "execute"), pytest.raises(OperationalError):
            await agent_instance.get_by_project(
                db_session, project_id=test_project.id
            )


class TestAgentInstanceGetByAgentType:
    """Tests for getting agent instances by agent type."""

    @pytest.mark.asyncio
    async def test_get_by_agent_type_success(
        self, db_session, test_agent_type, test_agent_instance
    ):
        """Test getting agent instances by agent type."""
        instances = await agent_instance.get_by_agent_type(
            db_session, agent_type_id=test_agent_type.id
        )
        assert len(instances) == 1

    @pytest.mark.asyncio
    async def test_get_by_agent_type_with_status_filter(
        self, db_session, test_agent_type, test_agent_instance
    ):
        """Test getting agent instances by agent type with status filter."""
        instances = await agent_instance.get_by_agent_type(
            db_session,
            agent_type_id=test_agent_type.id,
            status=AgentStatus.IDLE,
        )
        assert len(instances) == 1
        assert instances[0].status == AgentStatus.IDLE

    @pytest.mark.asyncio
    async def test_get_by_agent_type_db_error(self, db_session, test_agent_type):
        """Test getting agent instances by agent type when DB error occurs."""
        with _fail_on(db_session, "execute"), pytest.raises(OperationalError):
            await agent_instance.get_by_agent_type(
                db_session, agent_type_id=test_agent_type.id
            )


class TestAgentInstanceStatusOperations:
    """Tests for agent instance status operations."""

    @pytest.mark.asyncio
    async def test_update_status_not_found(self, db_session):
        """Test updating status for non-existent agent instance."""
        result = await agent_instance.update_status(
            db_session,
            instance_id=uuid.uuid4(),
            status=AgentStatus.WORKING,
        )
        assert result is None

    @pytest.mark.asyncio
    async def test_update_status_success(self, db_session, test_agent_instance):
        """Test successfully updating agent instance status."""
        result = await agent_instance.update_status(
            db_session,
            instance_id=test_agent_instance.id,
            status=AgentStatus.WORKING,
            current_task="Processing task",
        )
        assert result is not None
        assert result.status == AgentStatus.WORKING
        assert result.current_task == "Processing task"

    @pytest.mark.asyncio
    async def test_update_status_db_error(self, db_session, test_agent_instance):
        """Test updating status when DB error occurs."""
        with _fail_on(db_session, "commit"), pytest.raises(OperationalError):
            await agent_instance.update_status(
                db_session,
                instance_id=test_agent_instance.id,
                status=AgentStatus.WORKING,
            )


class TestAgentInstanceTerminate:
    """Tests for agent instance termination."""

    @pytest.mark.asyncio
    async def test_terminate_not_found(self, db_session):
        """Test terminating non-existent agent instance."""
        result = await agent_instance.terminate(db_session, instance_id=uuid.uuid4())
        assert result is None

    @pytest.mark.asyncio
    async def test_terminate_success(self, db_session, test_agent_instance):
        """Test successfully terminating agent instance."""
        result = await agent_instance.terminate(
            db_session, instance_id=test_agent_instance.id
        )
        assert result is not None
        assert result.status == AgentStatus.TERMINATED
        assert result.terminated_at is not None

    @pytest.mark.asyncio
    async def test_terminate_db_error(self, db_session, test_agent_instance):
        """Test terminating agent instance when DB error occurs."""
        with _fail_on(db_session, "commit"), pytest.raises(OperationalError):
            await agent_instance.terminate(
                db_session, instance_id=test_agent_instance.id
            )


class TestAgentInstanceTaskCompletion:
    """Tests for recording task completion."""

    @pytest.mark.asyncio
    async def test_record_task_completion_not_found(self, db_session):
        """Test recording task completion for non-existent agent instance."""
        result = await agent_instance.record_task_completion(
            db_session,
            instance_id=uuid.uuid4(),
            tokens_used=100,
            cost_incurred=Decimal("0.01"),
        )
        assert result is None

    @pytest.mark.asyncio
    async def test_record_task_completion_success(
        self, db_session, test_agent_instance
    ):
        """Test successfully recording task completion."""
        result = await agent_instance.record_task_completion(
            db_session,
            instance_id=test_agent_instance.id,
            tokens_used=1000,
            cost_incurred=Decimal("0.05"),
        )
        assert result is not None
        assert result.tasks_completed == 1
        assert result.tokens_used == 1000

    @pytest.mark.asyncio
    async def test_record_task_completion_db_error(
        self, db_session, test_agent_instance
    ):
        """Test recording task completion when DB error occurs."""
        with _fail_on(db_session, "commit"), pytest.raises(OperationalError):
            await agent_instance.record_task_completion(
                db_session,
                instance_id=test_agent_instance.id,
                tokens_used=100,
                cost_incurred=Decimal("0.01"),
            )


class TestAgentInstanceMetrics:
    """Tests for agent instance metrics."""

    @pytest.mark.asyncio
    async def test_get_project_metrics_empty(self, db_session, test_project):
        """Test getting project metrics with no agent instances."""
        result = await agent_instance.get_project_metrics(
            db_session, project_id=test_project.id
        )
        assert result["total_instances"] == 0
        assert result["active_instances"] == 0

    @pytest.mark.asyncio
    async def test_get_project_metrics_with_data(
        self, db_session, test_project, test_agent_instance
    ):
        """Test getting project metrics with agent instances."""
        result = await agent_instance.get_project_metrics(
            db_session, project_id=test_project.id
        )
        assert result["total_instances"] == 1
        assert result["idle_instances"] == 1

    @pytest.mark.asyncio
    async def test_get_project_metrics_db_error(self, db_session, test_project):
        """Test getting project metrics when DB error occurs."""
        with _fail_on(db_session, "execute"), pytest.raises(OperationalError):
            await agent_instance.get_project_metrics(
                db_session, project_id=test_project.id
            )


class TestAgentInstanceBulkTerminate:
    """Tests for bulk termination."""

    @pytest.mark.asyncio
    async def test_bulk_terminate_by_project_empty(self, db_session, test_project):
        """Test bulk terminating with no agent instances."""
        count = await agent_instance.bulk_terminate_by_project(
            db_session, project_id=test_project.id
        )
        assert count == 0

    @pytest.mark.asyncio
    async def test_bulk_terminate_by_project_success(
        self, db_session, test_project, test_agent_instance, test_agent_type
    ):
        """Test successfully bulk terminating agent instances."""
        # Add a second, WORKING instance alongside the IDLE fixture instance.
        instance2 = AgentInstance(
            id=uuid.uuid4(),
            agent_type_id=test_agent_type.id,
            project_id=test_project.id,
            name="Test Agent 2",
            status=AgentStatus.WORKING,
        )
        db_session.add(instance2)
        await db_session.commit()

        count = await agent_instance.bulk_terminate_by_project(
            db_session, project_id=test_project.id
        )
        assert count == 2

    @pytest.mark.asyncio
    async def test_bulk_terminate_by_project_db_error(
        self, db_session, test_project, test_agent_instance
    ):
        """Test bulk terminating when DB error occurs."""
        with _fail_on(db_session, "execute"), pytest.raises(OperationalError):
            await agent_instance.bulk_terminate_by_project(
                db_session, project_id=test_project.id
            )