forked from cardosofelipe/pragma-stack
Reformatted multiline function calls, object definitions, and queries for improved code readability and consistency. Adjusted imports and constraints where necessary.
395 lines
13 KiB
Python
395 lines
13 KiB
Python
# app/crud/syndarix/agent_instance.py
|
|
"""Async CRUD operations for AgentInstance model using SQLAlchemy 2.0 patterns."""
|
|
|
|
import logging
|
|
from datetime import UTC, datetime
|
|
from decimal import Decimal
|
|
from typing import Any
|
|
from uuid import UUID
|
|
|
|
from sqlalchemy import func, select, update
|
|
from sqlalchemy.exc import IntegrityError
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
from sqlalchemy.orm import joinedload
|
|
|
|
from app.crud.base import CRUDBase
|
|
from app.models.syndarix import AgentInstance, Issue
|
|
from app.models.syndarix.enums import AgentStatus
|
|
from app.schemas.syndarix import AgentInstanceCreate, AgentInstanceUpdate
|
|
|
|
# Module-scoped logger shared by every CRUD method in this file.
logger: logging.Logger = logging.getLogger(__name__)
|
|
|
|
|
|
class CRUDAgentInstance(
    CRUDBase[AgentInstance, AgentInstanceCreate, AgentInstanceUpdate]
):
    """Async CRUD operations for AgentInstance model.

    Every mutating method commits its own transaction and, on any exception,
    rolls the session back and logs before re-raising.
    """

    async def create(
        self, db: AsyncSession, *, obj_in: AgentInstanceCreate
    ) -> AgentInstance:
        """Create a new agent instance and commit it.

        Args:
            db: Async database session.
            obj_in: Creation payload whose fields are copied onto the model.

        Returns:
            The persisted instance, refreshed from the database.

        Raises:
            ValueError: If the insert violates a database constraint. The
                original ``IntegrityError`` is chained as ``__cause__``.
            Exception: Any other error is logged and re-raised unchanged.
        """
        try:
            db_obj = AgentInstance(
                agent_type_id=obj_in.agent_type_id,
                project_id=obj_in.project_id,
                name=obj_in.name,
                status=obj_in.status,
                current_task=obj_in.current_task,
                short_term_memory=obj_in.short_term_memory,
                long_term_memory_ref=obj_in.long_term_memory_ref,
                session_id=obj_in.session_id,
            )
            db.add(db_obj)
            await db.commit()
            await db.refresh(db_obj)
            return db_obj
        except IntegrityError as e:
            await db.rollback()
            # Prefer the DBAPI driver's message when SQLAlchemy exposes it.
            error_msg = str(e.orig) if hasattr(e, "orig") else str(e)
            logger.error(f"Integrity error creating agent instance: {error_msg}")
            # Chain the cause so the underlying DB error survives in tracebacks.
            raise ValueError(f"Database integrity error: {error_msg}") from e
        except Exception as e:
            await db.rollback()
            logger.error(
                f"Unexpected error creating agent instance: {e!s}", exc_info=True
            )
            raise

    async def get_with_details(
        self,
        db: AsyncSession,
        *,
        instance_id: UUID,
    ) -> dict[str, Any] | None:
        """
        Get an agent instance with full details including related entities.

        Eager-loads the ``agent_type`` and ``project`` relationships in the
        same query. A second query counts the issues assigned to the instance.

        Args:
            db: Async database session.
            instance_id: Primary key of the agent instance.

        Returns:
            Dictionary with the instance, the names/slugs of its agent type
            and project (``None`` when the relationship is unset), and the
            assigned issue count. Returns ``None`` if the instance does not
            exist.
        """
        try:
            # Get instance with joined relationships
            result = await db.execute(
                select(AgentInstance)
                .options(
                    joinedload(AgentInstance.agent_type),
                    joinedload(AgentInstance.project),
                )
                .where(AgentInstance.id == instance_id)
            )
            instance = result.scalar_one_or_none()

            if not instance:
                return None

            # Get assigned issues count
            issues_count_result = await db.execute(
                select(func.count(Issue.id)).where(
                    Issue.assigned_agent_id == instance_id
                )
            )
            assigned_issues_count = issues_count_result.scalar_one()

            return {
                "instance": instance,
                "agent_type_name": instance.agent_type.name
                if instance.agent_type
                else None,
                "agent_type_slug": instance.agent_type.slug
                if instance.agent_type
                else None,
                "project_name": instance.project.name if instance.project else None,
                "project_slug": instance.project.slug if instance.project else None,
                "assigned_issues_count": assigned_issues_count,
            }
        except Exception as e:
            logger.error(
                f"Error getting agent instance with details {instance_id}: {e!s}",
                exc_info=True,
            )
            raise

    async def get_by_project(
        self,
        db: AsyncSession,
        *,
        project_id: UUID,
        status: AgentStatus | None = None,
        skip: int = 0,
        limit: int = 100,
    ) -> tuple[list[AgentInstance], int]:
        """Get a page of agent instances for a specific project.

        Args:
            db: Async database session.
            project_id: Project whose instances are listed.
            status: Optional status filter; ``None`` means all statuses.
            skip: Number of rows to skip (offset).
            limit: Maximum number of rows to return.

        Returns:
            ``(instances, total)``: the requested page, newest first, and the
            total number of rows matching the filters, ignoring pagination.
        """
        try:
            query = select(AgentInstance).where(AgentInstance.project_id == project_id)

            if status is not None:
                query = query.where(AgentInstance.status == status)

            # Count over the filtered query before ORDER BY / OFFSET / LIMIT
            # are attached, so the total reflects every matching row.
            count_query = select(func.count()).select_from(query.subquery())
            count_result = await db.execute(count_query)
            total = count_result.scalar_one()

            # ``id`` breaks ties between rows sharing ``created_at`` so offset
            # pagination returns a stable, non-overlapping sequence of pages.
            query = query.order_by(
                AgentInstance.created_at.desc(), AgentInstance.id.desc()
            )
            query = query.offset(skip).limit(limit)
            result = await db.execute(query)
            instances = list(result.scalars().all())

            return instances, total
        except Exception as e:
            logger.error(
                f"Error getting instances by project {project_id}: {e!s}",
                exc_info=True,
            )
            raise

    async def get_by_agent_type(
        self,
        db: AsyncSession,
        *,
        agent_type_id: UUID,
        status: AgentStatus | None = None,
    ) -> list[AgentInstance]:
        """Get all instances of a specific agent type, newest first.

        Args:
            db: Async database session.
            agent_type_id: Agent type whose instances are listed.
            status: Optional status filter; ``None`` means all statuses.

        Returns:
            Every matching instance, ordered by ``created_at`` descending.
            The result is not paginated.
        """
        try:
            query = select(AgentInstance).where(
                AgentInstance.agent_type_id == agent_type_id
            )

            if status is not None:
                query = query.where(AgentInstance.status == status)

            query = query.order_by(AgentInstance.created_at.desc())
            result = await db.execute(query)
            return list(result.scalars().all())
        except Exception as e:
            logger.error(
                f"Error getting instances by agent type {agent_type_id}: {e!s}",
                exc_info=True,
            )
            raise

    async def update_status(
        self,
        db: AsyncSession,
        *,
        instance_id: UUID,
        status: AgentStatus,
        current_task: str | None = None,
    ) -> AgentInstance | None:
        """Update the status of an agent instance and stamp its activity time.

        Args:
            db: Async database session.
            instance_id: Primary key of the agent instance.
            status: New status to store.
            current_task: New task description. ``None`` leaves the existing
                value untouched, so this method cannot clear the field.

        Returns:
            The refreshed instance, or ``None`` if it does not exist.
        """
        try:
            result = await db.execute(
                select(AgentInstance).where(AgentInstance.id == instance_id)
            )
            instance = result.scalar_one_or_none()

            if not instance:
                return None

            instance.status = status
            instance.last_activity_at = datetime.now(UTC)
            if current_task is not None:
                instance.current_task = current_task

            await db.commit()
            await db.refresh(instance)
            return instance
        except Exception as e:
            await db.rollback()
            logger.error(
                f"Error updating instance status {instance_id}: {e!s}", exc_info=True
            )
            raise

    async def terminate(
        self,
        db: AsyncSession,
        *,
        instance_id: UUID,
    ) -> AgentInstance | None:
        """Terminate an agent instance.

        Also unassigns all issues from this agent to prevent orphaned
        assignments. Both changes are committed in the same transaction.
        The instance's ``current_task`` and ``session_id`` are cleared, and
        ``terminated_at`` is stamped with the current UTC time.

        Args:
            db: Async database session.
            instance_id: Primary key of the agent instance.

        Returns:
            The refreshed, terminated instance, or ``None`` if it does not
            exist.
        """
        try:
            result = await db.execute(
                select(AgentInstance).where(AgentInstance.id == instance_id)
            )
            instance = result.scalar_one_or_none()

            if not instance:
                return None

            # Unassign all issues from this agent before terminating
            await db.execute(
                update(Issue)
                .where(Issue.assigned_agent_id == instance_id)
                .values(assigned_agent_id=None)
            )

            instance.status = AgentStatus.TERMINATED
            instance.terminated_at = datetime.now(UTC)
            instance.current_task = None
            instance.session_id = None

            await db.commit()
            await db.refresh(instance)
            return instance
        except Exception as e:
            await db.rollback()
            logger.error(
                f"Error terminating instance {instance_id}: {e!s}", exc_info=True
            )
            raise

    async def record_task_completion(
        self,
        db: AsyncSession,
        *,
        instance_id: UUID,
        tokens_used: int,
        cost_incurred: Decimal,
    ) -> AgentInstance | None:
        """Record a completed task and update metrics.

        Uses atomic SQL UPDATE to prevent lost updates under concurrent load.
        This avoids the read-modify-write race condition that occurs when
        multiple task completions happen simultaneously.

        Args:
            db: Async database session.
            instance_id: Primary key of the agent instance.
            tokens_used: Tokens to add to the instance's running total.
            cost_incurred: Cost to add to the instance's running total.

        Returns:
            The updated instance, refreshed after commit, or ``None`` if no
            row matched ``instance_id``.
        """
        try:
            now = datetime.now(UTC)

            # Use atomic SQL UPDATE to increment counters without race conditions
            # This is safe for concurrent updates - no read-modify-write pattern
            result = await db.execute(
                update(AgentInstance)
                .where(AgentInstance.id == instance_id)
                .values(
                    tasks_completed=AgentInstance.tasks_completed + 1,
                    tokens_used=AgentInstance.tokens_used + tokens_used,
                    cost_incurred=AgentInstance.cost_incurred + cost_incurred,
                    last_activity_at=now,
                    updated_at=now,
                )
                .returning(AgentInstance)
            )
            instance = result.scalar_one_or_none()

            if not instance:
                return None

            await db.commit()
            # Like the other mutating methods: reload after commit so the
            # returned object is not left expired. Touching expired attributes
            # on an AsyncSession would otherwise trigger implicit IO.
            await db.refresh(instance)
            return instance
        except Exception as e:
            await db.rollback()
            logger.error(
                f"Error recording task completion {instance_id}: {e!s}", exc_info=True
            )
            raise

    async def get_project_metrics(
        self,
        db: AsyncSession,
        *,
        project_id: UUID,
    ) -> dict[str, Any]:
        """Get aggregated metrics for all agents in a project.

        Computes everything in one aggregate query. The per-status counts use
        the SQL ``FILTER (WHERE ...)`` aggregate clause. NOTE(review): this
        requires backend support (e.g. PostgreSQL); confirm for other
        databases.

        Args:
            db: Async database session.
            project_id: Project whose agent instances are aggregated.

        Returns:
            Dictionary with instance counts (total, ``WORKING`` reported as
            active, ``IDLE``) and summed tasks, tokens and cost. Sums default
            to zero when the project has no instances.
        """
        try:
            result = await db.execute(
                select(
                    func.count(AgentInstance.id).label("total_instances"),
                    func.count(AgentInstance.id)
                    .filter(AgentInstance.status == AgentStatus.WORKING)
                    .label("active_instances"),
                    func.count(AgentInstance.id)
                    .filter(AgentInstance.status == AgentStatus.IDLE)
                    .label("idle_instances"),
                    func.sum(AgentInstance.tasks_completed).label("total_tasks"),
                    func.sum(AgentInstance.tokens_used).label("total_tokens"),
                    func.sum(AgentInstance.cost_incurred).label("total_cost"),
                ).where(AgentInstance.project_id == project_id)
            )
            row = result.one()

            # SUM over zero rows yields NULL, hence the ``or`` defaults.
            return {
                "total_instances": row.total_instances or 0,
                "active_instances": row.active_instances or 0,
                "idle_instances": row.idle_instances or 0,
                "total_tasks_completed": row.total_tasks or 0,
                "total_tokens_used": row.total_tokens or 0,
                "total_cost_incurred": row.total_cost or Decimal("0.0000"),
            }
        except Exception as e:
            logger.error(
                f"Error getting project metrics {project_id}: {e!s}", exc_info=True
            )
            raise

    async def bulk_terminate_by_project(
        self,
        db: AsyncSession,
        *,
        project_id: UUID,
    ) -> int:
        """Terminate all active instances in a project.

        Also unassigns all issues from these agents to prevent orphaned
        assignments. "Active" here means any status other than
        ``TERMINATED``. Both the issue unassignment and the status change are
        committed together.

        Args:
            db: Async database session.
            project_id: Project whose instances are terminated.

        Returns:
            Number of instance rows updated, as reported by the driver's
            ``rowcount``.
        """
        try:
            # First, unassign all issues from agents in this project
            # Get all agent IDs that will be terminated
            agents_to_terminate = await db.execute(
                select(AgentInstance.id).where(
                    AgentInstance.project_id == project_id,
                    AgentInstance.status != AgentStatus.TERMINATED,
                )
            )
            agent_ids = [row[0] for row in agents_to_terminate.fetchall()]

            # Unassign issues from these agents
            if agent_ids:
                await db.execute(
                    update(Issue)
                    .where(Issue.assigned_agent_id.in_(agent_ids))
                    .values(assigned_agent_id=None)
                )

            now = datetime.now(UTC)
            stmt = (
                update(AgentInstance)
                .where(
                    AgentInstance.project_id == project_id,
                    AgentInstance.status != AgentStatus.TERMINATED,
                )
                .values(
                    status=AgentStatus.TERMINATED,
                    terminated_at=now,
                    current_task=None,
                    session_id=None,
                    updated_at=now,
                )
            )

            result = await db.execute(stmt)
            await db.commit()

            terminated_count = result.rowcount
            logger.info(
                f"Bulk terminated {terminated_count} instances in project {project_id}"
            )
            return terminated_count
        except Exception as e:
            await db.rollback()
            logger.error(
                f"Error bulk terminating instances for project {project_id}: {e!s}",
                exc_info=True,
            )
            raise
|
|
|
|
|
|
# Shared, module-level CRUD helper bound to the AgentInstance model; import
# this object rather than instantiating CRUDAgentInstance elsewhere.
agent_instance: CRUDAgentInstance = CRUDAgentInstance(AgentInstance)
|