Files
fast-next-template/backend/app/crud/syndarix/agent_instance.py
Felipe Cardoso acd18ff694 chore(backend): standardize multiline formatting across modules
Reformatted multiline function calls, object definitions, and queries for improved code readability and consistency. Adjusted imports and constraints where necessary.
2026-01-03 01:35:18 +01:00

395 lines
13 KiB
Python

# app/crud/syndarix/agent_instance.py
"""Async CRUD operations for AgentInstance model using SQLAlchemy 2.0 patterns."""
import logging
from datetime import UTC, datetime
from decimal import Decimal
from typing import Any
from uuid import UUID
from sqlalchemy import func, select, update
from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import joinedload
from app.crud.base import CRUDBase
from app.models.syndarix import AgentInstance, Issue
from app.models.syndarix.enums import AgentStatus
from app.schemas.syndarix import AgentInstanceCreate, AgentInstanceUpdate
# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)
class CRUDAgentInstance(
    CRUDBase[AgentInstance, AgentInstanceCreate, AgentInstanceUpdate]
):
    """Async CRUD operations for the AgentInstance model.

    Adds agent-specific queries and state transitions on top of ``CRUDBase``.
    Methods that write commit on success and roll the session back before
    re-raising on failure; read-only methods log the error and re-raise.
    """

    async def create(
        self, db: AsyncSession, *, obj_in: AgentInstanceCreate
    ) -> AgentInstance:
        """Create a new agent instance.

        Args:
            db: Async session used for the insert; committed on success.
            obj_in: Creation payload whose fields are copied onto the new row.

        Returns:
            The persisted instance, refreshed from the database.

        Raises:
            ValueError: If the insert violates a database integrity
                constraint. The original ``IntegrityError`` is chained as
                the cause.
            Exception: Any other error is logged and re-raised after the
                session has been rolled back.
        """
        try:
            db_obj = AgentInstance(
                agent_type_id=obj_in.agent_type_id,
                project_id=obj_in.project_id,
                name=obj_in.name,
                status=obj_in.status,
                current_task=obj_in.current_task,
                short_term_memory=obj_in.short_term_memory,
                long_term_memory_ref=obj_in.long_term_memory_ref,
                session_id=obj_in.session_id,
            )
            db.add(db_obj)
            await db.commit()
            await db.refresh(db_obj)
            return db_obj
        except IntegrityError as e:
            await db.rollback()
            error_msg = str(e.orig) if hasattr(e, "orig") else str(e)
            logger.error(f"Integrity error creating agent instance: {error_msg}")
            # Chain the database error so the root cause stays in the traceback.
            raise ValueError(f"Database integrity error: {error_msg}") from e
        except Exception as e:
            await db.rollback()
            logger.error(
                f"Unexpected error creating agent instance: {e!s}", exc_info=True
            )
            raise

    async def get_with_details(
        self,
        db: AsyncSession,
        *,
        instance_id: UUID,
    ) -> dict[str, Any] | None:
        """Get an agent instance together with details of related entities.

        Args:
            db: Async session used for the queries.
            instance_id: Primary key of the instance to load.

        Returns:
            ``None`` if no instance has this ID. Otherwise a dictionary with
            the keys ``instance``, ``agent_type_name``, ``agent_type_slug``,
            ``project_name``, ``project_slug`` and ``assigned_issues_count``.
            The name/slug entries are ``None`` when the corresponding
            relationship is not set.

        Raises:
            Exception: Any error is logged and re-raised.
        """
        try:
            # Eager-load both relationships in the same query so they can be
            # read below without further database access.
            result = await db.execute(
                select(AgentInstance)
                .options(
                    joinedload(AgentInstance.agent_type),
                    joinedload(AgentInstance.project),
                )
                .where(AgentInstance.id == instance_id)
            )
            instance = result.scalar_one_or_none()
            if instance is None:
                return None

            # Count the issues currently assigned to this agent.
            issues_count_result = await db.execute(
                select(func.count(Issue.id)).where(
                    Issue.assigned_agent_id == instance_id
                )
            )
            assigned_issues_count = issues_count_result.scalar_one()

            agent_type = instance.agent_type
            project = instance.project
            return {
                "instance": instance,
                "agent_type_name": agent_type.name if agent_type else None,
                "agent_type_slug": agent_type.slug if agent_type else None,
                "project_name": project.name if project else None,
                "project_slug": project.slug if project else None,
                "assigned_issues_count": assigned_issues_count,
            }
        except Exception as e:
            logger.error(
                f"Error getting agent instance with details {instance_id}: {e!s}",
                exc_info=True,
            )
            raise

    async def get_by_project(
        self,
        db: AsyncSession,
        *,
        project_id: UUID,
        status: AgentStatus | None = None,
        skip: int = 0,
        limit: int = 100,
    ) -> tuple[list[AgentInstance], int]:
        """Get a page of agent instances for a specific project.

        Args:
            db: Async session used for the queries.
            project_id: Project whose instances are listed.
            status: If given, only instances with this status are included.
            skip: Number of rows to skip (offset).
            limit: Maximum number of rows to return.

        Returns:
            A ``(instances, total)`` tuple: the requested page ordered by
            ``created_at`` descending, and the number of rows matching the
            filters before pagination.

        Raises:
            Exception: Any error is logged and re-raised.
        """
        try:
            query = select(AgentInstance).where(AgentInstance.project_id == project_id)
            if status is not None:
                query = query.where(AgentInstance.status == status)

            # Count over the filtered (not yet paginated) query.
            count_query = select(func.count()).select_from(query.subquery())
            count_result = await db.execute(count_query)
            total = count_result.scalar_one()

            # Apply ordering and pagination.
            query = query.order_by(AgentInstance.created_at.desc())
            query = query.offset(skip).limit(limit)
            result = await db.execute(query)
            instances = list(result.scalars().all())
            return instances, total
        except Exception as e:
            logger.error(
                f"Error getting instances by project {project_id}: {e!s}",
                exc_info=True,
            )
            raise

    async def get_by_agent_type(
        self,
        db: AsyncSession,
        *,
        agent_type_id: UUID,
        status: AgentStatus | None = None,
    ) -> list[AgentInstance]:
        """Get all instances of a specific agent type.

        Args:
            db: Async session used for the query.
            agent_type_id: Agent type whose instances are listed.
            status: If given, only instances with this status are included.

        Returns:
            Matching instances ordered by ``created_at`` descending.

        Raises:
            Exception: Any error is logged and re-raised.
        """
        try:
            query = select(AgentInstance).where(
                AgentInstance.agent_type_id == agent_type_id
            )
            if status is not None:
                query = query.where(AgentInstance.status == status)
            query = query.order_by(AgentInstance.created_at.desc())
            result = await db.execute(query)
            return list(result.scalars().all())
        except Exception as e:
            logger.error(
                f"Error getting instances by agent type {agent_type_id}: {e!s}",
                exc_info=True,
            )
            raise

    async def update_status(
        self,
        db: AsyncSession,
        *,
        instance_id: UUID,
        status: AgentStatus,
        current_task: str | None = None,
    ) -> AgentInstance | None:
        """Update the status of an agent instance.

        Also stamps ``last_activity_at`` with the current UTC time.

        Args:
            db: Async session; committed on success.
            instance_id: Primary key of the instance to update.
            status: New status value.
            current_task: New task description. When ``None`` the stored
                task is left unchanged.

        Returns:
            The refreshed instance, or ``None`` if no instance has this ID.

        Raises:
            Exception: Any error is logged and re-raised after rollback.
        """
        try:
            result = await db.execute(
                select(AgentInstance).where(AgentInstance.id == instance_id)
            )
            instance = result.scalar_one_or_none()
            if instance is None:
                return None

            instance.status = status
            instance.last_activity_at = datetime.now(UTC)
            if current_task is not None:
                instance.current_task = current_task

            await db.commit()
            await db.refresh(instance)
            return instance
        except Exception as e:
            await db.rollback()
            logger.error(
                f"Error updating instance status {instance_id}: {e!s}", exc_info=True
            )
            raise

    async def terminate(
        self,
        db: AsyncSession,
        *,
        instance_id: UUID,
    ) -> AgentInstance | None:
        """Terminate an agent instance.

        Also unassigns all issues from this agent to prevent orphaned
        assignments. The instance is marked ``TERMINATED``, ``terminated_at``
        is set to the current UTC time, and ``current_task`` and
        ``session_id`` are cleared.

        Args:
            db: Async session; committed on success.
            instance_id: Primary key of the instance to terminate.

        Returns:
            The refreshed instance, or ``None`` if no instance has this ID.

        Raises:
            Exception: Any error is logged and re-raised after rollback.
        """
        try:
            result = await db.execute(
                select(AgentInstance).where(AgentInstance.id == instance_id)
            )
            instance = result.scalar_one_or_none()
            if instance is None:
                return None

            # Unassign all issues from this agent before terminating.
            await db.execute(
                update(Issue)
                .where(Issue.assigned_agent_id == instance_id)
                .values(assigned_agent_id=None)
            )

            instance.status = AgentStatus.TERMINATED
            instance.terminated_at = datetime.now(UTC)
            instance.current_task = None
            instance.session_id = None

            await db.commit()
            await db.refresh(instance)
            return instance
        except Exception as e:
            await db.rollback()
            logger.error(
                f"Error terminating instance {instance_id}: {e!s}", exc_info=True
            )
            raise

    async def record_task_completion(
        self,
        db: AsyncSession,
        *,
        instance_id: UUID,
        tokens_used: int,
        cost_incurred: Decimal,
    ) -> AgentInstance | None:
        """Record a completed task and update metrics.

        Uses a single SQL UPDATE that increments the counters in the
        database, so concurrent completions do not overwrite each other the
        way a read-modify-write sequence would.

        Args:
            db: Async session; committed on success.
            instance_id: Primary key of the instance that finished a task.
            tokens_used: Tokens to add to the instance's running total.
            cost_incurred: Cost to add to the instance's running total.

        Returns:
            The refreshed instance, or ``None`` if no instance has this ID.

        Raises:
            Exception: Any error is logged and re-raised after rollback.
        """
        try:
            now = datetime.now(UTC)
            # Increment in SQL rather than in Python: no read-modify-write.
            result = await db.execute(
                update(AgentInstance)
                .where(AgentInstance.id == instance_id)
                .values(
                    tasks_completed=AgentInstance.tasks_completed + 1,
                    tokens_used=AgentInstance.tokens_used + tokens_used,
                    cost_incurred=AgentInstance.cost_incurred + cost_incurred,
                    last_activity_at=now,
                    updated_at=now,
                )
                .returning(AgentInstance)
            )
            instance = result.scalar_one_or_none()
            if instance is None:
                return None

            await db.commit()
            # Reload after commit, like the other write methods, so the
            # returned object carries the committed values and can be read
            # without triggering implicit I/O.
            await db.refresh(instance)
            return instance
        except Exception as e:
            await db.rollback()
            logger.error(
                f"Error recording task completion {instance_id}: {e!s}", exc_info=True
            )
            raise

    async def get_project_metrics(
        self,
        db: AsyncSession,
        *,
        project_id: UUID,
    ) -> dict[str, Any]:
        """Get aggregated metrics for all agents in a project.

        Args:
            db: Async session used for the query.
            project_id: Project whose instances are aggregated.

        Returns:
            A dictionary with ``total_instances``, ``active_instances``
            (status ``WORKING``), ``idle_instances`` (status ``IDLE``),
            ``total_tasks_completed``, ``total_tokens_used`` and
            ``total_cost_incurred``. Aggregates that come back empty are
            reported as zero.

        Raises:
            Exception: Any error is logged and re-raised.
        """
        try:
            result = await db.execute(
                select(
                    func.count(AgentInstance.id).label("total_instances"),
                    func.count(AgentInstance.id)
                    .filter(AgentInstance.status == AgentStatus.WORKING)
                    .label("active_instances"),
                    func.count(AgentInstance.id)
                    .filter(AgentInstance.status == AgentStatus.IDLE)
                    .label("idle_instances"),
                    func.sum(AgentInstance.tasks_completed).label("total_tasks"),
                    func.sum(AgentInstance.tokens_used).label("total_tokens"),
                    func.sum(AgentInstance.cost_incurred).label("total_cost"),
                ).where(AgentInstance.project_id == project_id)
            )
            row = result.one()
            # The `or` fallbacks replace empty aggregates with zero values.
            return {
                "total_instances": row.total_instances or 0,
                "active_instances": row.active_instances or 0,
                "idle_instances": row.idle_instances or 0,
                "total_tasks_completed": row.total_tasks or 0,
                "total_tokens_used": row.total_tokens or 0,
                "total_cost_incurred": row.total_cost or Decimal("0.0000"),
            }
        except Exception as e:
            logger.error(
                f"Error getting project metrics {project_id}: {e!s}", exc_info=True
            )
            raise

    async def bulk_terminate_by_project(
        self,
        db: AsyncSession,
        *,
        project_id: UUID,
    ) -> int:
        """Terminate all active instances in a project.

        Also unassigns all issues from these agents to prevent orphaned
        assignments. The IDs of the non-terminated agents are collected
        once and both updates are restricted to that set, so every agent
        that gets terminated has had its issues unassigned.

        Args:
            db: Async session; committed on success.
            project_id: Project whose non-terminated instances are terminated.

        Returns:
            The number of instances that were terminated.

        Raises:
            Exception: Any error is logged and re-raised after rollback.
        """
        try:
            # Collect the IDs of all agents that will be terminated.
            agents_to_terminate = await db.execute(
                select(AgentInstance.id).where(
                    AgentInstance.project_id == project_id,
                    AgentInstance.status != AgentStatus.TERMINATED,
                )
            )
            agent_ids = list(agents_to_terminate.scalars().all())

            terminated_count = 0
            if agent_ids:
                # Unassign issues from these agents first.
                await db.execute(
                    update(Issue)
                    .where(Issue.assigned_agent_id.in_(agent_ids))
                    .values(assigned_agent_id=None)
                )

                now = datetime.now(UTC)
                stmt = (
                    update(AgentInstance)
                    .where(
                        # Same set as the unassignment above; the status
                        # check keeps an agent terminated in the meantime
                        # from having its terminated_at overwritten.
                        AgentInstance.id.in_(agent_ids),
                        AgentInstance.status != AgentStatus.TERMINATED,
                    )
                    .values(
                        status=AgentStatus.TERMINATED,
                        terminated_at=now,
                        current_task=None,
                        session_id=None,
                        updated_at=now,
                    )
                )
                result = await db.execute(stmt)
                terminated_count = result.rowcount

            await db.commit()
            logger.info(
                f"Bulk terminated {terminated_count} instances in project {project_id}"
            )
            return terminated_count
        except Exception as e:
            await db.rollback()
            logger.error(
                f"Error bulk terminating instances for project {project_id}: {e!s}",
                exc_info=True,
            )
            raise
# Create a singleton instance for use across the application.
# It is constructed with the AgentInstance model class as its argument.
agent_instance = CRUDAgentInstance(AgentInstance)