Reformatted multiline function calls, object definitions, and queries for improved code readability and consistency. Adjusted imports and constraints where necessary.
395 lines
13 KiB
Python
395 lines
13 KiB
Python
# app/crud/syndarix/agent_instance.py
|
|
"""Async CRUD operations for AgentInstance model using SQLAlchemy 2.0 patterns."""
|
|
|
|
import logging
|
|
from datetime import UTC, datetime
|
|
from decimal import Decimal
|
|
from typing import Any
|
|
from uuid import UUID
|
|
|
|
from sqlalchemy import func, select, update
|
|
from sqlalchemy.exc import IntegrityError
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
from sqlalchemy.orm import joinedload
|
|
|
|
from app.crud.base import CRUDBase
|
|
from app.models.syndarix import AgentInstance, Issue
|
|
from app.models.syndarix.enums import AgentStatus
|
|
from app.schemas.syndarix import AgentInstanceCreate, AgentInstanceUpdate
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class CRUDAgentInstance(
    CRUDBase[AgentInstance, AgentInstanceCreate, AgentInstanceUpdate]
):
    """Async CRUD operations for AgentInstance model.

    Mutating methods commit the session themselves and roll it back before
    re-raising on failure; read-only methods only log and re-raise.
    """

    async def create(
        self, db: AsyncSession, *, obj_in: AgentInstanceCreate
    ) -> AgentInstance:
        """Create a new agent instance and commit it.

        Args:
            db: Async database session.
            obj_in: Creation payload whose fields are copied onto a new
                ``AgentInstance`` row.

        Returns:
            The persisted instance, refreshed from the database after commit.

        Raises:
            ValueError: If the insert raises ``IntegrityError``; the original
                exception is chained as ``__cause__``.
            Exception: Any other error is logged and re-raised after the
                session is rolled back.
        """
        try:
            db_obj = AgentInstance(
                agent_type_id=obj_in.agent_type_id,
                project_id=obj_in.project_id,
                name=obj_in.name,
                status=obj_in.status,
                current_task=obj_in.current_task,
                short_term_memory=obj_in.short_term_memory,
                long_term_memory_ref=obj_in.long_term_memory_ref,
                session_id=obj_in.session_id,
            )
            db.add(db_obj)
            await db.commit()
            await db.refresh(db_obj)
            return db_obj
        except IntegrityError as e:
            await db.rollback()
            # Prefer the underlying driver exception's text; fall back to the
            # wrapper's own message when no driver exception is attached.
            orig = getattr(e, "orig", None)
            error_msg = str(orig) if orig is not None else str(e)
            logger.error(f"Integrity error creating agent instance: {error_msg}")
            raise ValueError(f"Database integrity error: {error_msg}") from e
        except Exception as e:
            await db.rollback()
            logger.error(
                f"Unexpected error creating agent instance: {e!s}", exc_info=True
            )
            raise

    async def get_with_details(
        self,
        db: AsyncSession,
        *,
        instance_id: UUID,
    ) -> dict[str, Any] | None:
        """
        Get an agent instance with full details including related entities.

        The ``agent_type`` and ``project`` relationships are eager-loaded in
        the same query, and the number of issues assigned to the instance is
        counted with a second query.

        Args:
            db: Async database session.
            instance_id: Primary key of the agent instance.

        Returns:
            ``None`` if no instance has ``instance_id``; otherwise a dict with
            keys ``instance``, ``agent_type_name``, ``agent_type_slug``,
            ``project_name``, ``project_slug`` (each related-entity field is
            ``None`` when the relationship is unset) and
            ``assigned_issues_count``.
        """
        try:
            # Get instance with joined relationships
            result = await db.execute(
                select(AgentInstance)
                .options(
                    joinedload(AgentInstance.agent_type),
                    joinedload(AgentInstance.project),
                )
                .where(AgentInstance.id == instance_id)
            )
            instance = result.scalar_one_or_none()

            if not instance:
                return None

            # Get assigned issues count
            issues_count_result = await db.execute(
                select(func.count(Issue.id)).where(
                    Issue.assigned_agent_id == instance_id
                )
            )
            assigned_issues_count = issues_count_result.scalar_one()

            agent_type = instance.agent_type
            project = instance.project
            return {
                "instance": instance,
                "agent_type_name": agent_type.name if agent_type else None,
                "agent_type_slug": agent_type.slug if agent_type else None,
                "project_name": project.name if project else None,
                "project_slug": project.slug if project else None,
                "assigned_issues_count": assigned_issues_count,
            }
        except Exception as e:
            logger.error(
                f"Error getting agent instance with details {instance_id}: {e!s}",
                exc_info=True,
            )
            raise

    async def get_by_project(
        self,
        db: AsyncSession,
        *,
        project_id: UUID,
        status: AgentStatus | None = None,
        skip: int = 0,
        limit: int = 100,
    ) -> tuple[list[AgentInstance], int]:
        """Get a page of agent instances for a specific project.

        Args:
            db: Async database session.
            project_id: Project whose instances are listed.
            status: If given, only instances with this status are returned.
            skip: Number of rows to skip (OFFSET).
            limit: Maximum number of rows to return (LIMIT).

        Returns:
            ``(instances, total)`` where ``instances`` is the requested page,
            newest first, and ``total`` is the number of matching rows
            before pagination.
        """
        try:
            query = select(AgentInstance).where(AgentInstance.project_id == project_id)

            if status is not None:
                query = query.where(AgentInstance.status == status)

            # Get total count over the filtered (unpaginated) query
            count_query = select(func.count()).select_from(query.subquery())
            count_result = await db.execute(count_query)
            total = count_result.scalar_one()

            # Apply pagination. The id tiebreaker makes the order total, so
            # rows sharing a created_at cannot shift between pages.
            query = query.order_by(
                AgentInstance.created_at.desc(), AgentInstance.id.desc()
            )
            query = query.offset(skip).limit(limit)
            result = await db.execute(query)
            instances = list(result.scalars().all())

            return instances, total
        except Exception as e:
            logger.error(
                f"Error getting instances by project {project_id}: {e!s}",
                exc_info=True,
            )
            raise

    async def get_by_agent_type(
        self,
        db: AsyncSession,
        *,
        agent_type_id: UUID,
        status: AgentStatus | None = None,
    ) -> list[AgentInstance]:
        """Get all instances of a specific agent type, newest first.

        Args:
            db: Async database session.
            agent_type_id: Agent type whose instances are listed.
            status: If given, only instances with this status are returned.

        Returns:
            All matching instances (unpaginated), ordered by ``created_at``
            descending.
        """
        try:
            query = select(AgentInstance).where(
                AgentInstance.agent_type_id == agent_type_id
            )

            if status is not None:
                query = query.where(AgentInstance.status == status)

            query = query.order_by(AgentInstance.created_at.desc())
            result = await db.execute(query)
            return list(result.scalars().all())
        except Exception as e:
            logger.error(
                f"Error getting instances by agent type {agent_type_id}: {e!s}",
                exc_info=True,
            )
            raise

    async def update_status(
        self,
        db: AsyncSession,
        *,
        instance_id: UUID,
        status: AgentStatus,
        current_task: str | None = None,
    ) -> AgentInstance | None:
        """Update the status of an agent instance and bump its activity time.

        Args:
            db: Async database session.
            instance_id: Primary key of the agent instance.
            status: New status to store.
            current_task: If not ``None``, replaces the stored current task.
                ``None`` leaves it unchanged, so this method cannot clear it.

        Returns:
            The refreshed instance, or ``None`` if no instance has
            ``instance_id``.
        """
        try:
            result = await db.execute(
                select(AgentInstance).where(AgentInstance.id == instance_id)
            )
            instance = result.scalar_one_or_none()

            if not instance:
                return None

            instance.status = status
            instance.last_activity_at = datetime.now(UTC)
            if current_task is not None:
                instance.current_task = current_task

            await db.commit()
            await db.refresh(instance)
            return instance
        except Exception as e:
            await db.rollback()
            logger.error(
                f"Error updating instance status {instance_id}: {e!s}", exc_info=True
            )
            raise

    async def terminate(
        self,
        db: AsyncSession,
        *,
        instance_id: UUID,
    ) -> AgentInstance | None:
        """Terminate an agent instance.

        Also unassigns all issues from this agent to prevent orphaned
        assignments. The issue unassignment and the status change are
        committed together.

        Args:
            db: Async database session.
            instance_id: Primary key of the agent instance.

        Returns:
            The refreshed, terminated instance, or ``None`` if no instance has
            ``instance_id``.
        """
        try:
            result = await db.execute(
                select(AgentInstance).where(AgentInstance.id == instance_id)
            )
            instance = result.scalar_one_or_none()

            if not instance:
                return None

            # Unassign all issues from this agent before terminating
            await db.execute(
                update(Issue)
                .where(Issue.assigned_agent_id == instance_id)
                .values(assigned_agent_id=None)
            )

            instance.status = AgentStatus.TERMINATED
            instance.terminated_at = datetime.now(UTC)
            instance.current_task = None
            instance.session_id = None

            await db.commit()
            await db.refresh(instance)
            return instance
        except Exception as e:
            await db.rollback()
            logger.error(
                f"Error terminating instance {instance_id}: {e!s}", exc_info=True
            )
            raise

    async def record_task_completion(
        self,
        db: AsyncSession,
        *,
        instance_id: UUID,
        tokens_used: int,
        cost_incurred: Decimal,
    ) -> AgentInstance | None:
        """Record a completed task and update metrics.

        Uses atomic SQL UPDATE to prevent lost updates under concurrent load.
        This avoids the read-modify-write race condition that occurs when
        multiple task completions happen simultaneously.

        Args:
            db: Async database session.
            instance_id: Primary key of the agent instance.
            tokens_used: Tokens to add to the running total.
            cost_incurred: Cost to add to the running total.

        Returns:
            The updated instance (refreshed after commit), or ``None`` if no
            instance has ``instance_id``.
        """
        try:
            now = datetime.now(UTC)

            # Use atomic SQL UPDATE to increment counters without race conditions
            # This is safe for concurrent updates - no read-modify-write pattern
            result = await db.execute(
                update(AgentInstance)
                .where(AgentInstance.id == instance_id)
                .values(
                    tasks_completed=AgentInstance.tasks_completed + 1,
                    tokens_used=AgentInstance.tokens_used + tokens_used,
                    cost_incurred=AgentInstance.cost_incurred + cost_incurred,
                    last_activity_at=now,
                    updated_at=now,
                )
                .returning(AgentInstance)
            )
            instance = result.scalar_one_or_none()

            if not instance:
                return None

            await db.commit()
            # Refresh like the other mutators do. If the session expires
            # objects on commit (SQLAlchemy's default), the returned object's
            # attributes would otherwise need implicit lazy I/O to read, which
            # an AsyncSession cannot perform.
            await db.refresh(instance)
            return instance
        except Exception as e:
            await db.rollback()
            logger.error(
                f"Error recording task completion {instance_id}: {e!s}", exc_info=True
            )
            raise

    async def get_project_metrics(
        self,
        db: AsyncSession,
        *,
        project_id: UUID,
    ) -> dict[str, Any]:
        """Get aggregated metrics for all agents in a project.

        Aggregates over every instance with ``project_id`` regardless of
        status, so terminated instances contribute to the totals.

        Args:
            db: Async database session.
            project_id: Project to aggregate.

        Returns:
            Dict with ``total_instances``, ``active_instances`` (status
            ``WORKING``), ``idle_instances`` (status ``IDLE``),
            ``total_tasks_completed``, ``total_tokens_used`` and
            ``total_cost_incurred``. NULL sums (no rows) become ``0`` /
            ``Decimal("0.0000")``.
        """
        try:
            result = await db.execute(
                select(
                    func.count(AgentInstance.id).label("total_instances"),
                    func.count(AgentInstance.id)
                    .filter(AgentInstance.status == AgentStatus.WORKING)
                    .label("active_instances"),
                    func.count(AgentInstance.id)
                    .filter(AgentInstance.status == AgentStatus.IDLE)
                    .label("idle_instances"),
                    func.sum(AgentInstance.tasks_completed).label("total_tasks"),
                    func.sum(AgentInstance.tokens_used).label("total_tokens"),
                    func.sum(AgentInstance.cost_incurred).label("total_cost"),
                ).where(AgentInstance.project_id == project_id)
            )
            row = result.one()

            return {
                "total_instances": row.total_instances or 0,
                "active_instances": row.active_instances or 0,
                "idle_instances": row.idle_instances or 0,
                "total_tasks_completed": row.total_tasks or 0,
                "total_tokens_used": row.total_tokens or 0,
                "total_cost_incurred": row.total_cost or Decimal("0.0000"),
            }
        except Exception as e:
            logger.error(
                f"Error getting project metrics {project_id}: {e!s}", exc_info=True
            )
            raise

    async def bulk_terminate_by_project(
        self,
        db: AsyncSession,
        *,
        project_id: UUID,
    ) -> int:
        """Terminate all non-terminated instances in a project.

        Also unassigns all issues from these agents to prevent orphaned
        assignments. Both updates target the same snapshot of instance IDs, so
        every instance terminated here has had its issues unassigned first.
        An instance created after the snapshot is left untouched.

        Args:
            db: Async database session.
            project_id: Project whose instances are terminated.

        Returns:
            Number of instance rows moved to ``TERMINATED``.
        """
        try:
            # Snapshot the IDs of every instance that should be terminated.
            agents_to_terminate = await db.execute(
                select(AgentInstance.id).where(
                    AgentInstance.project_id == project_id,
                    AgentInstance.status != AgentStatus.TERMINATED,
                )
            )
            agent_ids = list(agents_to_terminate.scalars().all())

            terminated_count = 0
            if agent_ids:
                # Unassign issues from exactly these agents
                await db.execute(
                    update(Issue)
                    .where(Issue.assigned_agent_id.in_(agent_ids))
                    .values(assigned_agent_id=None)
                )

                now = datetime.now(UTC)
                # Terminate the same snapshot rather than re-evaluating the
                # project predicate. The status guard skips rows that were
                # terminated concurrently, which keeps their original
                # terminated_at.
                result = await db.execute(
                    update(AgentInstance)
                    .where(
                        AgentInstance.id.in_(agent_ids),
                        AgentInstance.status != AgentStatus.TERMINATED,
                    )
                    .values(
                        status=AgentStatus.TERMINATED,
                        terminated_at=now,
                        current_task=None,
                        session_id=None,
                        updated_at=now,
                    )
                )
                terminated_count = result.rowcount

            await db.commit()

            logger.info(
                f"Bulk terminated {terminated_count} instances in project {project_id}"
            )
            return terminated_count
        except Exception as e:
            await db.rollback()
            logger.error(
                f"Error bulk terminating instances for project {project_id}: {e!s}",
                exc_info=True,
            )
            raise
|
|
|
|
|
|
# Create a singleton instance for use across the application.
# It is constructed once at import time, bound to the AgentInstance model;
# callers supply their own AsyncSession to each method.
agent_instance: CRUDAgentInstance = CRUDAgentInstance(AgentInstance)
|