# app/crud/syndarix/agent_instance.py
"""Async CRUD operations for AgentInstance model using SQLAlchemy 2.0 patterns."""

import logging
from datetime import UTC, datetime
from decimal import Decimal
from typing import Any
from uuid import UUID

from sqlalchemy import func, select, update
from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import joinedload

from app.crud.base import CRUDBase
from app.models.syndarix import AgentInstance, Issue
from app.models.syndarix.enums import AgentStatus
from app.schemas.syndarix import AgentInstanceCreate, AgentInstanceUpdate

logger = logging.getLogger(__name__)


class CRUDAgentInstance(CRUDBase[AgentInstance, AgentInstanceCreate, AgentInstanceUpdate]):
    """Async CRUD operations for AgentInstance model."""

    async def create(
        self, db: AsyncSession, *, obj_in: AgentInstanceCreate
    ) -> AgentInstance:
        """Create and persist a new agent instance.

        Args:
            db: Async database session. It is committed on success and
                rolled back on any failure.
            obj_in: Validated creation payload.

        Returns:
            The newly created instance, refreshed from the database.

        Raises:
            ValueError: If the insert violates a database integrity
                constraint. The original ``IntegrityError`` is chained as
                ``__cause__``.
            Exception: Any other error is logged and re-raised unchanged.
        """
        try:
            db_obj = AgentInstance(
                agent_type_id=obj_in.agent_type_id,
                project_id=obj_in.project_id,
                name=obj_in.name,
                status=obj_in.status,
                current_task=obj_in.current_task,
                short_term_memory=obj_in.short_term_memory,
                long_term_memory_ref=obj_in.long_term_memory_ref,
                session_id=obj_in.session_id,
            )
            db.add(db_obj)
            await db.commit()
            await db.refresh(db_obj)
            return db_obj
        except IntegrityError as e:
            await db.rollback()
            # ``orig`` wraps the driver-level exception and normally carries
            # the most specific message. It can be None, so fall back to the
            # SQLAlchemy wrapper's own text rather than reporting "None".
            orig = getattr(e, "orig", None)
            error_msg = str(orig) if orig is not None else str(e)
            logger.error(f"Integrity error creating agent instance: {error_msg}")
            raise ValueError(f"Database integrity error: {error_msg}") from e
        except Exception as e:
            await db.rollback()
            logger.error(
                f"Unexpected error creating agent instance: {e!s}", exc_info=True
            )
            raise

    async def get_with_details(
        self,
        db: AsyncSession,
        *,
        instance_id: UUID,
    ) -> dict[str, Any] | None:
        """
        Get an agent instance with full details including related entities.

        The agent type and project are eager-loaded in the same query. The
        number of issues assigned to this instance is counted separately.

        Args:
            db: Async database session (read-only use).
            instance_id: Primary key of the agent instance.

        Returns:
            None if no instance has ``instance_id``. Otherwise a dictionary
            with these keys:

            - ``instance``: the ORM object.
            - ``agent_type_name`` / ``agent_type_slug``: None when the agent
              type relationship is empty.
            - ``project_name`` / ``project_slug``: None when the project
              relationship is empty.
            - ``assigned_issues_count``: number of ``Issue`` rows whose
              ``assigned_agent_id`` equals ``instance_id``.
        """
        try:
            # Get instance with joined relationships
            result = await db.execute(
                select(AgentInstance)
                .options(
                    joinedload(AgentInstance.agent_type),
                    joinedload(AgentInstance.project),
                )
                .where(AgentInstance.id == instance_id)
            )
            instance = result.scalar_one_or_none()

            if not instance:
                return None

            # Get assigned issues count
            issues_count_result = await db.execute(
                select(func.count(Issue.id)).where(
                    Issue.assigned_agent_id == instance_id
                )
            )
            assigned_issues_count = issues_count_result.scalar_one()

            agent_type = instance.agent_type
            project = instance.project
            return {
                "instance": instance,
                "agent_type_name": agent_type.name if agent_type else None,
                "agent_type_slug": agent_type.slug if agent_type else None,
                "project_name": project.name if project else None,
                "project_slug": project.slug if project else None,
                "assigned_issues_count": assigned_issues_count,
            }
        except Exception as e:
            logger.error(
                f"Error getting agent instance with details {instance_id}: {e!s}",
                exc_info=True,
            )
            raise

    async def get_by_project(
        self,
        db: AsyncSession,
        *,
        project_id: UUID,
        status: AgentStatus | None = None,
        skip: int = 0,
        limit: int = 100,
    ) -> tuple[list[AgentInstance], int]:
        """Get a page of agent instances for a specific project.

        Args:
            db: Async database session (read-only use).
            project_id: Project whose instances are listed.
            status: If given, only instances in this status are included.
            skip: Number of rows to skip (offset).
            limit: Maximum number of rows to return.

        Returns:
            A tuple ``(instances, total)``. ``instances`` is the requested
            page, ordered by ``created_at`` descending (newest first).
            ``total`` counts every matching row, ignoring ``skip`` and
            ``limit``.
        """
        try:
            query = select(AgentInstance).where(
                AgentInstance.project_id == project_id
            )

            if status is not None:
                query = query.where(AgentInstance.status == status)

            # Count over the filtered (un-paginated, un-ordered) query.
            # ``subquery()`` is the SQLAlchemy 2.0 spelling; ``alias()`` on a
            # Select is a synonym for it.
            count_query = select(func.count()).select_from(query.subquery())
            count_result = await db.execute(count_query)
            total = count_result.scalar_one()

            # Apply pagination
            query = query.order_by(AgentInstance.created_at.desc())
            query = query.offset(skip).limit(limit)
            result = await db.execute(query)
            instances = list(result.scalars().all())

            return instances, total
        except Exception as e:
            logger.error(
                f"Error getting instances by project {project_id}: {e!s}",
                exc_info=True,
            )
            raise

    async def get_by_agent_type(
        self,
        db: AsyncSession,
        *,
        agent_type_id: UUID,
        status: AgentStatus | None = None,
    ) -> list[AgentInstance]:
        """Get all instances of a specific agent type.

        Args:
            db: Async database session (read-only use).
            agent_type_id: Agent type whose instances are listed.
            status: If given, only instances in this status are included.

        Returns:
            All matching instances, ordered by ``created_at`` descending.
            The result is not paginated.
        """
        try:
            query = select(AgentInstance).where(
                AgentInstance.agent_type_id == agent_type_id
            )

            if status is not None:
                query = query.where(AgentInstance.status == status)

            query = query.order_by(AgentInstance.created_at.desc())
            result = await db.execute(query)
            return list(result.scalars().all())
        except Exception as e:
            logger.error(
                f"Error getting instances by agent type {agent_type_id}: {e!s}",
                exc_info=True,
            )
            raise

    async def update_status(
        self,
        db: AsyncSession,
        *,
        instance_id: UUID,
        status: AgentStatus,
        current_task: str | None = None,
    ) -> AgentInstance | None:
        """Update the status of an agent instance.

        Also stamps ``last_activity_at`` with the current UTC time.

        Args:
            db: Async database session. It is committed on success and
                rolled back on error.
            instance_id: Primary key of the agent instance.
            status: New status.
            current_task: New task description. ``None`` leaves the stored
                value untouched, so this method cannot clear the field.

        Returns:
            The refreshed instance, or None if it does not exist.
        """
        try:
            result = await db.execute(
                select(AgentInstance).where(AgentInstance.id == instance_id)
            )
            instance = result.scalar_one_or_none()

            if not instance:
                return None

            instance.status = status
            instance.last_activity_at = datetime.now(UTC)
            if current_task is not None:
                instance.current_task = current_task

            await db.commit()
            await db.refresh(instance)
            return instance
        except Exception as e:
            await db.rollback()
            logger.error(
                f"Error updating instance status {instance_id}: {e!s}", exc_info=True
            )
            raise

    async def terminate(
        self,
        db: AsyncSession,
        *,
        instance_id: UUID,
    ) -> AgentInstance | None:
        """Terminate an agent instance.

        Also unassigns all issues from this agent to prevent orphaned
        assignments. Sets the status to ``TERMINATED``, stamps
        ``terminated_at`` and clears ``current_task`` and ``session_id``.
        All changes are committed in a single transaction.

        Args:
            db: Async database session. It is committed on success and
                rolled back on error.
            instance_id: Primary key of the agent instance.

        Returns:
            The refreshed instance, or None if it does not exist.
        """
        try:
            result = await db.execute(
                select(AgentInstance).where(AgentInstance.id == instance_id)
            )
            instance = result.scalar_one_or_none()

            if not instance:
                return None

            # Unassign all issues from this agent before terminating
            await db.execute(
                update(Issue)
                .where(Issue.assigned_agent_id == instance_id)
                .values(assigned_agent_id=None)
            )

            instance.status = AgentStatus.TERMINATED
            instance.terminated_at = datetime.now(UTC)
            instance.current_task = None
            instance.session_id = None

            await db.commit()
            await db.refresh(instance)
            return instance
        except Exception as e:
            await db.rollback()
            logger.error(
                f"Error terminating instance {instance_id}: {e!s}", exc_info=True
            )
            raise

    async def record_task_completion(
        self,
        db: AsyncSession,
        *,
        instance_id: UUID,
        tokens_used: int,
        cost_incurred: Decimal,
    ) -> AgentInstance | None:
        """Record a completed task and update metrics.

        Uses atomic SQL UPDATE to prevent lost updates under concurrent load.
        This avoids the read-modify-write race condition that occurs when
        multiple task completions happen simultaneously.

        Args:
            db: Async database session. It is committed on success and
                rolled back on error.
            instance_id: Primary key of the agent instance.
            tokens_used: Tokens to add to the running total.
            cost_incurred: Cost to add to the running total.

        Returns:
            The updated, refreshed instance, or None if no row matched
            ``instance_id``.
        """
        try:
            now = datetime.now(UTC)

            # Use atomic SQL UPDATE to increment counters without race conditions
            # This is safe for concurrent updates - no read-modify-write pattern
            result = await db.execute(
                update(AgentInstance)
                .where(AgentInstance.id == instance_id)
                .values(
                    tasks_completed=AgentInstance.tasks_completed + 1,
                    tokens_used=AgentInstance.tokens_used + tokens_used,
                    cost_incurred=AgentInstance.cost_incurred + cost_incurred,
                    last_activity_at=now,
                    updated_at=now,
                )
                .returning(AgentInstance)
            )
            instance = result.scalar_one_or_none()

            if not instance:
                return None

            await db.commit()
            # Refresh after commit, as the other mutators here do. If the
            # session expires objects on commit, attribute access on the
            # returned object would otherwise need an implicit lazy load,
            # which an AsyncSession cannot perform.
            await db.refresh(instance)
            return instance
        except Exception as e:
            await db.rollback()
            logger.error(
                f"Error recording task completion {instance_id}: {e!s}", exc_info=True
            )
            raise

    async def get_project_metrics(
        self,
        db: AsyncSession,
        *,
        project_id: UUID,
    ) -> dict[str, Any]:
        """Get aggregated metrics for all agents in a project.

        Args:
            db: Async database session (read-only use).
            project_id: Project to aggregate over.

        Returns:
            A dictionary with these keys:

            - ``total_instances``: all instances in the project, in any
              status.
            - ``active_instances``: instances with status ``WORKING``.
            - ``idle_instances``: instances with status ``IDLE``.
            - ``total_tasks_completed``, ``total_tokens_used``,
              ``total_cost_incurred``: column sums.

            SQL ``SUM`` over zero rows is NULL, so each value falls back to
            0 (or ``Decimal("0.0000")`` for cost) when the project has no
            instances.
        """
        try:
            result = await db.execute(
                select(
                    func.count(AgentInstance.id).label("total_instances"),
                    func.count(AgentInstance.id)
                    .filter(AgentInstance.status == AgentStatus.WORKING)
                    .label("active_instances"),
                    func.count(AgentInstance.id)
                    .filter(AgentInstance.status == AgentStatus.IDLE)
                    .label("idle_instances"),
                    func.sum(AgentInstance.tasks_completed).label("total_tasks"),
                    func.sum(AgentInstance.tokens_used).label("total_tokens"),
                    func.sum(AgentInstance.cost_incurred).label("total_cost"),
                ).where(AgentInstance.project_id == project_id)
            )
            row = result.one()

            return {
                "total_instances": row.total_instances or 0,
                "active_instances": row.active_instances or 0,
                "idle_instances": row.idle_instances or 0,
                "total_tasks_completed": row.total_tasks or 0,
                "total_tokens_used": row.total_tokens or 0,
                "total_cost_incurred": row.total_cost or Decimal("0.0000"),
            }
        except Exception as e:
            logger.error(
                f"Error getting project metrics {project_id}: {e!s}", exc_info=True
            )
            raise

    async def bulk_terminate_by_project(
        self,
        db: AsyncSession,
        *,
        project_id: UUID,
    ) -> int:
        """Terminate all active instances in a project.

        Also unassigns all issues from these agents to prevent orphaned
        assignments. A single ``UPDATE ... RETURNING`` terminates the agents.
        Issues are then unassigned for exactly the IDs that statement
        reported. This avoids the gap left by selecting IDs first and
        updating separately, where an agent could be terminated without its
        issues being unassigned, or vice versa. Both statements are committed
        together.

        Args:
            db: Async database session. It is committed on success and
                rolled back on error.
            project_id: Project whose non-terminated instances are
                terminated.

        Returns:
            Number of instances that were terminated by this call.
        """
        try:
            now = datetime.now(UTC)
            terminate_stmt = (
                update(AgentInstance)
                .where(
                    AgentInstance.project_id == project_id,
                    AgentInstance.status != AgentStatus.TERMINATED,
                )
                .values(
                    status=AgentStatus.TERMINATED,
                    terminated_at=now,
                    current_task=None,
                    session_id=None,
                    updated_at=now,
                )
                .returning(AgentInstance.id)
            )
            result = await db.execute(terminate_stmt)
            terminated_ids = list(result.scalars().all())

            # Unassign issues only from the agents this call actually terminated
            if terminated_ids:
                await db.execute(
                    update(Issue)
                    .where(Issue.assigned_agent_id.in_(terminated_ids))
                    .values(assigned_agent_id=None)
                )

            await db.commit()

            terminated_count = len(terminated_ids)
            logger.info(
                f"Bulk terminated {terminated_count} instances in project {project_id}"
            )
            return terminated_count
        except Exception as e:
            await db.rollback()
            logger.error(
                f"Error bulk terminating instances for project {project_id}: {e!s}",
                exc_info=True,
            )
            raise


# Create a singleton instance for use across the application
agent_instance = CRUDAgentInstance(AgentInstance)