forked from cardosofelipe/pragma-stack
Compare commits
6 Commits
d6db6af964
...
2055320058
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
2055320058 | ||
|
|
11da0d57a8 | ||
|
|
acfda1e9a9 | ||
|
|
3c24a8c522 | ||
|
|
ec111f9ce6 | ||
|
|
520a4d60fb |
@@ -0,0 +1,66 @@
|
||||
"""Enable pgvector extension
|
||||
|
||||
Revision ID: 0003
|
||||
Revises: 0002
|
||||
Create Date: 2025-12-30
|
||||
|
||||
This migration enables the pgvector extension for PostgreSQL, which provides
|
||||
vector similarity search capabilities required for the RAG (Retrieval-Augmented
|
||||
Generation) knowledge base system.
|
||||
|
||||
Vector Dimension Reference (per ADR-008 and SPIKE-006):
|
||||
---------------------------------------------------------
|
||||
The dimension size depends on the embedding model used:
|
||||
|
||||
| Model | Dimensions | Use Case |
|
||||
|----------------------------|------------|------------------------------|
|
||||
| text-embedding-3-small | 1536 | General docs, conversations |
|
||||
| text-embedding-3-large | 256-3072 | High accuracy (configurable) |
|
||||
| voyage-code-3 | 1024 | Code files (Python, JS, etc) |
|
||||
| voyage-3-large | 1024 | High quality general purpose |
|
||||
| nomic-embed-text (Ollama) | 768 | Local/fallback embedding |
|
||||
|
||||
Recommended defaults for Syndarix:
|
||||
- Documentation/conversations: 1536 (text-embedding-3-small)
|
||||
- Code files: 1024 (voyage-code-3)
|
||||
|
||||
Prerequisites:
|
||||
--------------
|
||||
This migration requires PostgreSQL with the pgvector extension installed.
|
||||
The Docker Compose configuration uses `pgvector/pgvector:pg17` which includes
|
||||
the extension pre-installed.
|
||||
|
||||
References:
|
||||
-----------
|
||||
- ADR-008: Knowledge Base and RAG Architecture
|
||||
- SPIKE-006: Knowledge Base with pgvector for RAG System
|
||||
- https://github.com/pgvector/pgvector
|
||||
"""
|
||||
|
||||
from collections.abc import Sequence
|
||||
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "0003"
|
||||
down_revision: str | None = "0002"
|
||||
branch_labels: str | Sequence[str] | None = None
|
||||
depends_on: str | Sequence[str] | None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
"""Enable the pgvector extension.
|
||||
|
||||
The CREATE EXTENSION IF NOT EXISTS statement is idempotent - it will
|
||||
succeed whether the extension already exists or not.
|
||||
"""
|
||||
op.execute("CREATE EXTENSION IF NOT EXISTS vector")
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
"""Drop the pgvector extension.
|
||||
|
||||
Note: This will fail if any tables with vector columns exist.
|
||||
Future migrations that create vector columns should be downgraded first.
|
||||
"""
|
||||
op.execute("DROP EXTENSION IF EXISTS vector")
|
||||
36
backend/app/api/dependencies/event_bus.py
Normal file
36
backend/app/api/dependencies/event_bus.py
Normal file
@@ -0,0 +1,36 @@
|
||||
"""
|
||||
Event bus dependency for FastAPI routes.
|
||||
|
||||
This module provides the FastAPI dependency for injecting the EventBus
|
||||
into route handlers. The event bus is a singleton that maintains
|
||||
Redis pub/sub connections for real-time event streaming.
|
||||
"""
|
||||
|
||||
from app.services.event_bus import (
|
||||
EventBus,
|
||||
get_connected_event_bus as _get_connected_event_bus,
|
||||
)
|
||||
|
||||
|
||||
async def get_event_bus() -> EventBus:
|
||||
"""
|
||||
FastAPI dependency that provides a connected EventBus instance.
|
||||
|
||||
The EventBus is a singleton that maintains Redis pub/sub connections.
|
||||
It's lazily initialized and connected on first access, and should be
|
||||
closed during application shutdown via close_event_bus().
|
||||
|
||||
Usage:
|
||||
@router.get("/events/stream")
|
||||
async def stream_events(
|
||||
event_bus: EventBus = Depends(get_event_bus)
|
||||
):
|
||||
...
|
||||
|
||||
Returns:
|
||||
EventBus: The global connected event bus instance
|
||||
|
||||
Raises:
|
||||
EventBusConnectionError: If connection to Redis fails
|
||||
"""
|
||||
return await _get_connected_event_bus()
|
||||
@@ -3,6 +3,7 @@ from fastapi import APIRouter
|
||||
from app.api.routes import (
|
||||
admin,
|
||||
auth,
|
||||
events,
|
||||
oauth,
|
||||
oauth_provider,
|
||||
organizations,
|
||||
@@ -22,3 +23,5 @@ api_router.include_router(admin.router, prefix="/admin", tags=["Admin"])
|
||||
api_router.include_router(
|
||||
organizations.router, prefix="/organizations", tags=["Organizations"]
|
||||
)
|
||||
# SSE events router - no prefix, routes define full paths
|
||||
api_router.include_router(events.router, tags=["Events"])
|
||||
|
||||
283
backend/app/api/routes/events.py
Normal file
283
backend/app/api/routes/events.py
Normal file
@@ -0,0 +1,283 @@
|
||||
"""
|
||||
SSE endpoint for real-time project event streaming.
|
||||
|
||||
This module provides Server-Sent Events (SSE) endpoints for streaming
|
||||
project events to connected clients. Events are scoped to projects,
|
||||
with authorization checks to ensure clients only receive events
|
||||
for projects they have access to.
|
||||
|
||||
Features:
|
||||
- Real-time event streaming via SSE
|
||||
- Project-scoped authorization
|
||||
- Automatic reconnection support (Last-Event-ID)
|
||||
- Keepalive messages every 30 seconds
|
||||
- Graceful connection cleanup
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter, Depends, Header, Request
|
||||
from slowapi import Limiter
|
||||
from slowapi.util import get_remote_address
|
||||
from sse_starlette.sse import EventSourceResponse
|
||||
|
||||
from app.api.dependencies.auth import get_current_user
|
||||
from app.api.dependencies.event_bus import get_event_bus
|
||||
from app.core.exceptions import AuthorizationError
|
||||
from app.models.user import User
|
||||
from app.schemas.errors import ErrorCode
|
||||
from app.schemas.events import EventType
|
||||
from app.services.event_bus import EventBus
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter()
|
||||
limiter = Limiter(key_func=get_remote_address)
|
||||
|
||||
# Keepalive interval in seconds
|
||||
KEEPALIVE_INTERVAL = 30
|
||||
|
||||
|
||||
async def check_project_access(
|
||||
project_id: UUID,
|
||||
user: User,
|
||||
) -> bool:
|
||||
"""
|
||||
Check if a user has access to a project's events.
|
||||
|
||||
This is a placeholder implementation that will be replaced
|
||||
with actual project authorization logic once the Project model
|
||||
is implemented. Currently allows access for all authenticated users.
|
||||
|
||||
Args:
|
||||
project_id: The project to check access for
|
||||
user: The authenticated user
|
||||
|
||||
Returns:
|
||||
bool: True if user has access, False otherwise
|
||||
|
||||
TODO: Implement actual project authorization
|
||||
- Check if user owns the project
|
||||
- Check if user is a member of the project
|
||||
- Check project visibility settings
|
||||
"""
|
||||
# Placeholder: Allow all authenticated users for now
|
||||
# This will be replaced with actual project ownership/membership check
|
||||
logger.debug(
|
||||
f"Project access check for user {user.id} on project {project_id} "
|
||||
"(placeholder: allowing all authenticated users)"
|
||||
)
|
||||
return True
|
||||
|
||||
|
||||
async def event_generator(
|
||||
project_id: UUID,
|
||||
event_bus: EventBus,
|
||||
last_event_id: str | None = None,
|
||||
):
|
||||
"""
|
||||
Generate SSE events for a project.
|
||||
|
||||
This async generator yields SSE-formatted events from the event bus,
|
||||
including keepalive comments to maintain the connection.
|
||||
|
||||
Args:
|
||||
project_id: The project to stream events for
|
||||
event_bus: The EventBus instance
|
||||
last_event_id: Optional last received event ID for reconnection
|
||||
|
||||
Yields:
|
||||
dict: SSE event data with 'event', 'data', and optional 'id' fields
|
||||
"""
|
||||
try:
|
||||
async for event_data in event_bus.subscribe_sse(
|
||||
project_id=project_id,
|
||||
last_event_id=last_event_id,
|
||||
keepalive_interval=KEEPALIVE_INTERVAL,
|
||||
):
|
||||
if event_data == "":
|
||||
# Keepalive - yield SSE comment
|
||||
yield {"comment": "keepalive"}
|
||||
else:
|
||||
# Parse event to extract type and id
|
||||
try:
|
||||
event_dict = json.loads(event_data)
|
||||
event_type = event_dict.get("type", "message")
|
||||
event_id = event_dict.get("id")
|
||||
|
||||
yield {
|
||||
"event": event_type,
|
||||
"data": event_data,
|
||||
"id": event_id,
|
||||
}
|
||||
except json.JSONDecodeError:
|
||||
# If we can't parse, send as generic message
|
||||
yield {
|
||||
"event": "message",
|
||||
"data": event_data,
|
||||
}
|
||||
|
||||
except asyncio.CancelledError:
|
||||
logger.info(f"Event stream cancelled for project {project_id}")
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error in event stream for project {project_id}: {e}")
|
||||
raise
|
||||
|
||||
|
||||
@router.get(
|
||||
"/projects/{project_id}/events/stream",
|
||||
summary="Stream Project Events",
|
||||
description="""
|
||||
Stream real-time events for a project via Server-Sent Events (SSE).
|
||||
|
||||
**Authentication**: Required (Bearer token)
|
||||
**Authorization**: Must have access to the project
|
||||
|
||||
**SSE Event Format**:
|
||||
```
|
||||
event: agent.status_changed
|
||||
id: 550e8400-e29b-41d4-a716-446655440000
|
||||
data: {"id": "...", "type": "agent.status_changed", "project_id": "...", ...}
|
||||
|
||||
: keepalive
|
||||
|
||||
event: issue.created
|
||||
id: 550e8400-e29b-41d4-a716-446655440001
|
||||
data: {...}
|
||||
```
|
||||
|
||||
**Reconnection**: Include the `Last-Event-ID` header with the last received
|
||||
event ID to resume from where you left off.
|
||||
|
||||
**Keepalive**: The server sends a comment (`: keepalive`) every 30 seconds
|
||||
to keep the connection alive.
|
||||
|
||||
**Rate Limit**: 10 connections/minute per IP
|
||||
""",
|
||||
response_class=EventSourceResponse,
|
||||
responses={
|
||||
200: {
|
||||
"description": "SSE stream established",
|
||||
"content": {"text/event-stream": {}},
|
||||
},
|
||||
401: {"description": "Not authenticated"},
|
||||
403: {"description": "Not authorized to access this project"},
|
||||
404: {"description": "Project not found"},
|
||||
},
|
||||
operation_id="stream_project_events",
|
||||
)
|
||||
@limiter.limit("10/minute")
|
||||
async def stream_project_events(
|
||||
request: Request,
|
||||
project_id: UUID,
|
||||
current_user: User = Depends(get_current_user),
|
||||
event_bus: EventBus = Depends(get_event_bus),
|
||||
last_event_id: str | None = Header(None, alias="Last-Event-ID"),
|
||||
):
|
||||
"""
|
||||
Stream real-time events for a project via SSE.
|
||||
|
||||
This endpoint establishes a persistent SSE connection that streams
|
||||
project events to the client in real-time. The connection includes:
|
||||
|
||||
- Event streaming: All project events (agent updates, issues, etc.)
|
||||
- Keepalive: Comment every 30 seconds to maintain connection
|
||||
- Reconnection: Use Last-Event-ID header to resume after disconnect
|
||||
|
||||
The connection is automatically cleaned up when the client disconnects.
|
||||
"""
|
||||
logger.info(
|
||||
f"SSE connection request for project {project_id} "
|
||||
f"by user {current_user.id} "
|
||||
f"(last_event_id={last_event_id})"
|
||||
)
|
||||
|
||||
# Check project access
|
||||
has_access = await check_project_access(project_id, current_user)
|
||||
if not has_access:
|
||||
raise AuthorizationError(
|
||||
message=f"You don't have access to project {project_id}",
|
||||
error_code=ErrorCode.INSUFFICIENT_PERMISSIONS,
|
||||
)
|
||||
|
||||
# Return SSE response
|
||||
return EventSourceResponse(
|
||||
event_generator(
|
||||
project_id=project_id,
|
||||
event_bus=event_bus,
|
||||
last_event_id=last_event_id,
|
||||
),
|
||||
media_type="text/event-stream",
|
||||
headers={
|
||||
"Cache-Control": "no-cache",
|
||||
"Connection": "keep-alive",
|
||||
"X-Accel-Buffering": "no", # Disable nginx buffering
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/projects/{project_id}/events/test",
|
||||
summary="Send Test Event (Development Only)",
|
||||
description="""
|
||||
Send a test event to a project's event stream. This endpoint is
|
||||
intended for development and testing purposes.
|
||||
|
||||
**Authentication**: Required (Bearer token)
|
||||
**Authorization**: Must have access to the project
|
||||
|
||||
**Note**: This endpoint should be disabled or restricted in production.
|
||||
""",
|
||||
response_model=dict,
|
||||
responses={
|
||||
200: {"description": "Test event sent"},
|
||||
401: {"description": "Not authenticated"},
|
||||
403: {"description": "Not authorized to access this project"},
|
||||
},
|
||||
operation_id="send_test_event",
|
||||
)
|
||||
async def send_test_event(
|
||||
project_id: UUID,
|
||||
current_user: User = Depends(get_current_user),
|
||||
event_bus: EventBus = Depends(get_event_bus),
|
||||
):
|
||||
"""
|
||||
Send a test event to the project's event stream.
|
||||
|
||||
This is useful for testing SSE connections during development.
|
||||
"""
|
||||
# Check project access
|
||||
has_access = await check_project_access(project_id, current_user)
|
||||
if not has_access:
|
||||
raise AuthorizationError(
|
||||
message=f"You don't have access to project {project_id}",
|
||||
error_code=ErrorCode.INSUFFICIENT_PERMISSIONS,
|
||||
)
|
||||
|
||||
# Create and publish test event using the Event schema
|
||||
event = EventBus.create_event(
|
||||
event_type=EventType.AGENT_MESSAGE,
|
||||
project_id=project_id,
|
||||
actor_type="user",
|
||||
actor_id=current_user.id,
|
||||
payload={
|
||||
"message": "Test event from SSE endpoint",
|
||||
"message_type": "info",
|
||||
},
|
||||
)
|
||||
|
||||
channel = event_bus.get_project_channel(project_id)
|
||||
await event_bus.publish(channel, event)
|
||||
|
||||
logger.info(f"Test event sent to project {project_id}: {event.id}")
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"event_id": event.id,
|
||||
"event_type": event.type.value,
|
||||
"message": "Test event sent successfully",
|
||||
}
|
||||
110
backend/app/celery_app.py
Normal file
110
backend/app/celery_app.py
Normal file
@@ -0,0 +1,110 @@
|
||||
# app/celery_app.py
|
||||
"""
|
||||
Celery application configuration for Syndarix.
|
||||
|
||||
This module configures the Celery app for background task processing:
|
||||
- Agent execution tasks (LLM calls, tool execution)
|
||||
- Git operations (clone, commit, push, PR creation)
|
||||
- Issue synchronization with external trackers
|
||||
- Workflow state management
|
||||
- Cost tracking and budget monitoring
|
||||
|
||||
Architecture:
|
||||
- Redis as message broker and result backend
|
||||
- Queue routing for task isolation
|
||||
- JSON serialization for cross-language compatibility
|
||||
- Beat scheduler for periodic tasks
|
||||
"""
|
||||
|
||||
from celery import Celery
|
||||
|
||||
from app.core.config import settings
|
||||
|
||||
# Create Celery application instance
|
||||
celery_app = Celery(
|
||||
"syndarix",
|
||||
broker=settings.celery_broker_url,
|
||||
backend=settings.celery_result_backend,
|
||||
)
|
||||
|
||||
# Define task queues with their own exchanges and routing keys
|
||||
TASK_QUEUES = {
|
||||
"agent": {"exchange": "agent", "routing_key": "agent"},
|
||||
"git": {"exchange": "git", "routing_key": "git"},
|
||||
"sync": {"exchange": "sync", "routing_key": "sync"},
|
||||
"default": {"exchange": "default", "routing_key": "default"},
|
||||
}
|
||||
|
||||
# Configure Celery
|
||||
celery_app.conf.update(
|
||||
# Serialization
|
||||
task_serializer="json",
|
||||
accept_content=["json"],
|
||||
result_serializer="json",
|
||||
# Timezone
|
||||
timezone="UTC",
|
||||
enable_utc=True,
|
||||
# Task imports for auto-discovery
|
||||
imports=("app.tasks",),
|
||||
# Default queue
|
||||
task_default_queue="default",
|
||||
# Task queues configuration
|
||||
task_queues=TASK_QUEUES,
|
||||
# Task routing - route tasks to appropriate queues
|
||||
task_routes={
|
||||
"app.tasks.agent.*": {"queue": "agent"},
|
||||
"app.tasks.git.*": {"queue": "git"},
|
||||
"app.tasks.sync.*": {"queue": "sync"},
|
||||
"app.tasks.*": {"queue": "default"},
|
||||
},
|
||||
# Time limits per ADR-003
|
||||
task_soft_time_limit=300, # 5 minutes soft limit
|
||||
task_time_limit=600, # 10 minutes hard limit
|
||||
# Result expiration - 24 hours
|
||||
result_expires=86400,
|
||||
# Broker connection retry
|
||||
broker_connection_retry_on_startup=True,
|
||||
# Beat schedule for periodic tasks
|
||||
beat_schedule={
|
||||
# Cost aggregation every hour per ADR-012
|
||||
"aggregate-daily-costs": {
|
||||
"task": "app.tasks.cost.aggregate_daily_costs",
|
||||
"schedule": 3600.0, # 1 hour in seconds
|
||||
},
|
||||
# Reset daily budget counters at midnight UTC
|
||||
"reset-daily-budget-counters": {
|
||||
"task": "app.tasks.cost.reset_daily_budget_counters",
|
||||
"schedule": 86400.0, # 24 hours in seconds
|
||||
},
|
||||
# Check for stale workflows every 5 minutes
|
||||
"recover-stale-workflows": {
|
||||
"task": "app.tasks.workflow.recover_stale_workflows",
|
||||
"schedule": 300.0, # 5 minutes in seconds
|
||||
},
|
||||
# Incremental issue sync every minute per ADR-011
|
||||
"sync-issues-incremental": {
|
||||
"task": "app.tasks.sync.sync_issues_incremental",
|
||||
"schedule": 60.0, # 1 minute in seconds
|
||||
},
|
||||
# Full issue reconciliation every 15 minutes per ADR-011
|
||||
"sync-issues-full": {
|
||||
"task": "app.tasks.sync.sync_issues_full",
|
||||
"schedule": 900.0, # 15 minutes in seconds
|
||||
},
|
||||
},
|
||||
# Task execution settings
|
||||
task_acks_late=True, # Acknowledge tasks after execution
|
||||
task_reject_on_worker_lost=True, # Reject tasks if worker dies
|
||||
worker_prefetch_multiplier=1, # Fair task distribution
|
||||
)
|
||||
|
||||
# Auto-discover tasks from task modules
|
||||
celery_app.autodiscover_tasks(
|
||||
[
|
||||
"app.tasks.agent",
|
||||
"app.tasks.git",
|
||||
"app.tasks.sync",
|
||||
"app.tasks.workflow",
|
||||
"app.tasks.cost",
|
||||
]
|
||||
)
|
||||
@@ -39,6 +39,32 @@ class Settings(BaseSettings):
|
||||
db_pool_timeout: int = 30 # Seconds to wait for a connection
|
||||
db_pool_recycle: int = 3600 # Recycle connections after 1 hour
|
||||
|
||||
# Redis configuration (Syndarix: cache, pub/sub, Celery broker)
|
||||
REDIS_URL: str = Field(
|
||||
default="redis://localhost:6379/0",
|
||||
description="Redis URL for cache, pub/sub, and Celery broker",
|
||||
)
|
||||
|
||||
# Celery configuration (Syndarix: background task processing)
|
||||
CELERY_BROKER_URL: str | None = Field(
|
||||
default=None,
|
||||
description="Celery broker URL (defaults to REDIS_URL if not set)",
|
||||
)
|
||||
CELERY_RESULT_BACKEND: str | None = Field(
|
||||
default=None,
|
||||
description="Celery result backend URL (defaults to REDIS_URL if not set)",
|
||||
)
|
||||
|
||||
@property
|
||||
def celery_broker_url(self) -> str:
|
||||
"""Get Celery broker URL, defaulting to Redis."""
|
||||
return self.CELERY_BROKER_URL or self.REDIS_URL
|
||||
|
||||
@property
|
||||
def celery_result_backend(self) -> str:
|
||||
"""Get Celery result backend URL, defaulting to Redis."""
|
||||
return self.CELERY_RESULT_BACKEND or self.REDIS_URL
|
||||
|
||||
# SQL debugging (disable in production)
|
||||
sql_echo: bool = False # Log SQL statements
|
||||
sql_echo_pool: bool = False # Log connection pool events
|
||||
|
||||
476
backend/app/core/redis.py
Normal file
476
backend/app/core/redis.py
Normal file
@@ -0,0 +1,476 @@
|
||||
# app/core/redis.py
|
||||
"""
|
||||
Redis client configuration for caching and pub/sub.
|
||||
|
||||
This module provides async Redis connectivity with connection pooling
|
||||
for FastAPI endpoints and background tasks.
|
||||
|
||||
Features:
|
||||
- Connection pooling for efficient resource usage
|
||||
- Cache operations (get, set, delete, expire)
|
||||
- Pub/sub operations (publish, subscribe)
|
||||
- Health check for monitoring
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
from collections.abc import AsyncGenerator
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import Any
|
||||
|
||||
from redis.asyncio import ConnectionPool, Redis
|
||||
from redis.asyncio.client import PubSub
|
||||
from redis.exceptions import ConnectionError, RedisError, TimeoutError
|
||||
|
||||
from app.core.config import settings
|
||||
|
||||
# Configure logging
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Default TTL for cache entries (1 hour)
|
||||
DEFAULT_CACHE_TTL = 3600
|
||||
|
||||
# Connection pool settings
|
||||
POOL_MAX_CONNECTIONS = 50
|
||||
POOL_TIMEOUT = 10 # seconds
|
||||
|
||||
|
||||
class RedisClient:
|
||||
"""
|
||||
Async Redis client with connection pooling.
|
||||
|
||||
Provides high-level operations for caching and pub/sub
|
||||
with proper error handling and connection management.
|
||||
"""
|
||||
|
||||
def __init__(self, url: str | None = None) -> None:
|
||||
"""
|
||||
Initialize Redis client.
|
||||
|
||||
Args:
|
||||
url: Redis connection URL. Defaults to settings.REDIS_URL.
|
||||
"""
|
||||
self._url = url or settings.REDIS_URL
|
||||
self._pool: ConnectionPool | None = None
|
||||
self._client: Redis | None = None
|
||||
self._lock = asyncio.Lock()
|
||||
|
||||
async def _ensure_pool(self) -> ConnectionPool:
|
||||
"""Ensure connection pool is initialized (thread-safe)."""
|
||||
if self._pool is None:
|
||||
async with self._lock:
|
||||
# Double-check after acquiring lock
|
||||
if self._pool is None:
|
||||
self._pool = ConnectionPool.from_url(
|
||||
self._url,
|
||||
max_connections=POOL_MAX_CONNECTIONS,
|
||||
socket_timeout=POOL_TIMEOUT,
|
||||
socket_connect_timeout=POOL_TIMEOUT,
|
||||
decode_responses=True,
|
||||
health_check_interval=30,
|
||||
)
|
||||
logger.info("Redis connection pool initialized")
|
||||
return self._pool
|
||||
|
||||
async def _get_client(self) -> Redis:
|
||||
"""Get Redis client instance from pool."""
|
||||
pool = await self._ensure_pool()
|
||||
if self._client is None:
|
||||
self._client = Redis(connection_pool=pool)
|
||||
return self._client
|
||||
|
||||
# =========================================================================
|
||||
# Cache Operations
|
||||
# =========================================================================
|
||||
|
||||
async def cache_get(self, key: str) -> str | None:
|
||||
"""
|
||||
Get a value from cache.
|
||||
|
||||
Args:
|
||||
key: Cache key.
|
||||
|
||||
Returns:
|
||||
Cached value or None if not found.
|
||||
"""
|
||||
try:
|
||||
client = await self._get_client()
|
||||
value = await client.get(key)
|
||||
if value is not None:
|
||||
logger.debug(f"Cache hit for key: {key}")
|
||||
else:
|
||||
logger.debug(f"Cache miss for key: {key}")
|
||||
return value
|
||||
except (ConnectionError, TimeoutError) as e:
|
||||
logger.error(f"Redis cache_get failed for key '{key}': {e}")
|
||||
return None
|
||||
except RedisError as e:
|
||||
logger.error(f"Redis error in cache_get for key '{key}': {e}")
|
||||
return None
|
||||
|
||||
async def cache_get_json(self, key: str) -> Any | None:
|
||||
"""
|
||||
Get a JSON-serialized value from cache.
|
||||
|
||||
Args:
|
||||
key: Cache key.
|
||||
|
||||
Returns:
|
||||
Deserialized value or None if not found.
|
||||
"""
|
||||
value = await self.cache_get(key)
|
||||
if value is not None:
|
||||
try:
|
||||
return json.loads(value)
|
||||
except json.JSONDecodeError as e:
|
||||
logger.error(f"Failed to decode JSON for key '{key}': {e}")
|
||||
return None
|
||||
return None
|
||||
|
||||
async def cache_set(
|
||||
self,
|
||||
key: str,
|
||||
value: str,
|
||||
ttl: int | None = None,
|
||||
) -> bool:
|
||||
"""
|
||||
Set a value in cache.
|
||||
|
||||
Args:
|
||||
key: Cache key.
|
||||
value: Value to cache.
|
||||
ttl: Time-to-live in seconds. Defaults to DEFAULT_CACHE_TTL.
|
||||
|
||||
Returns:
|
||||
True if successful, False otherwise.
|
||||
"""
|
||||
try:
|
||||
client = await self._get_client()
|
||||
ttl = ttl if ttl is not None else DEFAULT_CACHE_TTL
|
||||
await client.set(key, value, ex=ttl)
|
||||
logger.debug(f"Cache set for key: {key} (TTL: {ttl}s)")
|
||||
return True
|
||||
except (ConnectionError, TimeoutError) as e:
|
||||
logger.error(f"Redis cache_set failed for key '{key}': {e}")
|
||||
return False
|
||||
except RedisError as e:
|
||||
logger.error(f"Redis error in cache_set for key '{key}': {e}")
|
||||
return False
|
||||
|
||||
async def cache_set_json(
|
||||
self,
|
||||
key: str,
|
||||
value: Any,
|
||||
ttl: int | None = None,
|
||||
) -> bool:
|
||||
"""
|
||||
Set a JSON-serialized value in cache.
|
||||
|
||||
Args:
|
||||
key: Cache key.
|
||||
value: Value to serialize and cache.
|
||||
ttl: Time-to-live in seconds.
|
||||
|
||||
Returns:
|
||||
True if successful, False otherwise.
|
||||
"""
|
||||
try:
|
||||
serialized = json.dumps(value)
|
||||
return await self.cache_set(key, serialized, ttl)
|
||||
except (TypeError, ValueError) as e:
|
||||
logger.error(f"Failed to serialize value for key '{key}': {e}")
|
||||
return False
|
||||
|
||||
async def cache_delete(self, key: str) -> bool:
|
||||
"""
|
||||
Delete a key from cache.
|
||||
|
||||
Args:
|
||||
key: Cache key to delete.
|
||||
|
||||
Returns:
|
||||
True if key was deleted, False otherwise.
|
||||
"""
|
||||
try:
|
||||
client = await self._get_client()
|
||||
result = await client.delete(key)
|
||||
logger.debug(f"Cache delete for key: {key} (deleted: {result > 0})")
|
||||
return result > 0
|
||||
except (ConnectionError, TimeoutError) as e:
|
||||
logger.error(f"Redis cache_delete failed for key '{key}': {e}")
|
||||
return False
|
||||
except RedisError as e:
|
||||
logger.error(f"Redis error in cache_delete for key '{key}': {e}")
|
||||
return False
|
||||
|
||||
async def cache_delete_pattern(self, pattern: str) -> int:
|
||||
"""
|
||||
Delete all keys matching a pattern.
|
||||
|
||||
Args:
|
||||
pattern: Glob-style pattern (e.g., "user:*").
|
||||
|
||||
Returns:
|
||||
Number of keys deleted.
|
||||
"""
|
||||
try:
|
||||
client = await self._get_client()
|
||||
deleted = 0
|
||||
async for key in client.scan_iter(pattern):
|
||||
await client.delete(key)
|
||||
deleted += 1
|
||||
logger.debug(f"Cache delete pattern '{pattern}': {deleted} keys deleted")
|
||||
return deleted
|
||||
except (ConnectionError, TimeoutError) as e:
|
||||
logger.error(f"Redis cache_delete_pattern failed for '{pattern}': {e}")
|
||||
return 0
|
||||
except RedisError as e:
|
||||
logger.error(f"Redis error in cache_delete_pattern for '{pattern}': {e}")
|
||||
return 0
|
||||
|
||||
async def cache_expire(self, key: str, ttl: int) -> bool:
|
||||
"""
|
||||
Set or update TTL for a key.
|
||||
|
||||
Args:
|
||||
key: Cache key.
|
||||
ttl: New TTL in seconds.
|
||||
|
||||
Returns:
|
||||
True if TTL was set, False if key doesn't exist.
|
||||
"""
|
||||
try:
|
||||
client = await self._get_client()
|
||||
result = await client.expire(key, ttl)
|
||||
logger.debug(f"Cache expire for key: {key} (TTL: {ttl}s, success: {result})")
|
||||
return result
|
||||
except (ConnectionError, TimeoutError) as e:
|
||||
logger.error(f"Redis cache_expire failed for key '{key}': {e}")
|
||||
return False
|
||||
except RedisError as e:
|
||||
logger.error(f"Redis error in cache_expire for key '{key}': {e}")
|
||||
return False
|
||||
|
||||
async def cache_exists(self, key: str) -> bool:
|
||||
"""
|
||||
Check if a key exists in cache.
|
||||
|
||||
Args:
|
||||
key: Cache key.
|
||||
|
||||
Returns:
|
||||
True if key exists, False otherwise.
|
||||
"""
|
||||
try:
|
||||
client = await self._get_client()
|
||||
result = await client.exists(key)
|
||||
return result > 0
|
||||
except (ConnectionError, TimeoutError) as e:
|
||||
logger.error(f"Redis cache_exists failed for key '{key}': {e}")
|
||||
return False
|
||||
except RedisError as e:
|
||||
logger.error(f"Redis error in cache_exists for key '{key}': {e}")
|
||||
return False
|
||||
|
||||
async def cache_ttl(self, key: str) -> int:
|
||||
"""
|
||||
Get remaining TTL for a key.
|
||||
|
||||
Args:
|
||||
key: Cache key.
|
||||
|
||||
Returns:
|
||||
TTL in seconds, -1 if no TTL, -2 if key doesn't exist.
|
||||
"""
|
||||
try:
|
||||
client = await self._get_client()
|
||||
return await client.ttl(key)
|
||||
except (ConnectionError, TimeoutError) as e:
|
||||
logger.error(f"Redis cache_ttl failed for key '{key}': {e}")
|
||||
return -2
|
||||
except RedisError as e:
|
||||
logger.error(f"Redis error in cache_ttl for key '{key}': {e}")
|
||||
return -2
|
||||
|
||||
# =========================================================================
|
||||
# Pub/Sub Operations
|
||||
# =========================================================================
|
||||
|
||||
async def publish(self, channel: str, message: str | dict) -> int:
|
||||
"""
|
||||
Publish a message to a channel.
|
||||
|
||||
Args:
|
||||
channel: Channel name.
|
||||
message: Message to publish (string or dict for JSON serialization).
|
||||
|
||||
Returns:
|
||||
Number of subscribers that received the message.
|
||||
"""
|
||||
try:
|
||||
client = await self._get_client()
|
||||
if isinstance(message, dict):
|
||||
message = json.dumps(message)
|
||||
result = await client.publish(channel, message)
|
||||
logger.debug(f"Published to channel '{channel}': {result} subscribers")
|
||||
return result
|
||||
except (ConnectionError, TimeoutError) as e:
|
||||
logger.error(f"Redis publish failed for channel '{channel}': {e}")
|
||||
return 0
|
||||
except RedisError as e:
|
||||
logger.error(f"Redis error in publish for channel '{channel}': {e}")
|
||||
return 0
|
||||
|
||||
@asynccontextmanager
|
||||
async def subscribe(
|
||||
self, *channels: str
|
||||
) -> AsyncGenerator[PubSub, None]:
|
||||
"""
|
||||
Subscribe to one or more channels.
|
||||
|
||||
Usage:
|
||||
async with redis_client.subscribe("channel1", "channel2") as pubsub:
|
||||
async for message in pubsub.listen():
|
||||
if message["type"] == "message":
|
||||
print(message["data"])
|
||||
|
||||
Args:
|
||||
channels: Channel names to subscribe to.
|
||||
|
||||
Yields:
|
||||
PubSub instance for receiving messages.
|
||||
"""
|
||||
client = await self._get_client()
|
||||
pubsub = client.pubsub()
|
||||
try:
|
||||
await pubsub.subscribe(*channels)
|
||||
logger.debug(f"Subscribed to channels: {channels}")
|
||||
yield pubsub
|
||||
finally:
|
||||
await pubsub.unsubscribe(*channels)
|
||||
await pubsub.close()
|
||||
logger.debug(f"Unsubscribed from channels: {channels}")
|
||||
|
||||
@asynccontextmanager
|
||||
async def psubscribe(
|
||||
self, *patterns: str
|
||||
) -> AsyncGenerator[PubSub, None]:
|
||||
"""
|
||||
Subscribe to channels matching patterns.
|
||||
|
||||
Usage:
|
||||
async with redis_client.psubscribe("user:*") as pubsub:
|
||||
async for message in pubsub.listen():
|
||||
if message["type"] == "pmessage":
|
||||
print(message["pattern"], message["channel"], message["data"])
|
||||
|
||||
Args:
|
||||
patterns: Glob-style patterns to subscribe to.
|
||||
|
||||
Yields:
|
||||
PubSub instance for receiving messages.
|
||||
"""
|
||||
client = await self._get_client()
|
||||
pubsub = client.pubsub()
|
||||
try:
|
||||
await pubsub.psubscribe(*patterns)
|
||||
logger.debug(f"Pattern subscribed: {patterns}")
|
||||
yield pubsub
|
||||
finally:
|
||||
await pubsub.punsubscribe(*patterns)
|
||||
await pubsub.close()
|
||||
logger.debug(f"Pattern unsubscribed: {patterns}")
|
||||
|
||||
# =========================================================================
|
||||
# Health & Connection Management
|
||||
# =========================================================================
|
||||
|
||||
async def health_check(self) -> bool:
|
||||
"""
|
||||
Check if Redis connection is healthy.
|
||||
|
||||
Returns:
|
||||
True if connection is successful, False otherwise.
|
||||
"""
|
||||
try:
|
||||
client = await self._get_client()
|
||||
result = await client.ping()
|
||||
return result is True
|
||||
except (ConnectionError, TimeoutError) as e:
|
||||
logger.error(f"Redis health check failed: {e}")
|
||||
return False
|
||||
except RedisError as e:
|
||||
logger.error(f"Redis health check error: {e}")
|
||||
return False
|
||||
|
||||
async def close(self) -> None:
|
||||
"""
|
||||
Close Redis connections and cleanup resources.
|
||||
|
||||
Should be called during application shutdown.
|
||||
"""
|
||||
if self._client:
|
||||
await self._client.close()
|
||||
self._client = None
|
||||
logger.debug("Redis client closed")
|
||||
|
||||
if self._pool:
|
||||
await self._pool.disconnect()
|
||||
self._pool = None
|
||||
logger.info("Redis connection pool closed")
|
||||
|
||||
async def get_pool_info(self) -> dict[str, Any]:
|
||||
"""
|
||||
Get connection pool statistics.
|
||||
|
||||
Returns:
|
||||
Dictionary with pool information.
|
||||
"""
|
||||
if self._pool is None:
|
||||
return {"status": "not_initialized"}
|
||||
|
||||
return {
|
||||
"status": "active",
|
||||
"max_connections": POOL_MAX_CONNECTIONS,
|
||||
"url": self._url.split("@")[-1] if "@" in self._url else self._url,
|
||||
}
|
||||
|
||||
|
||||
# Global Redis client instance
|
||||
redis_client = RedisClient()
|
||||
|
||||
|
||||
# FastAPI dependency for Redis client
|
||||
async def get_redis() -> AsyncGenerator[RedisClient, None]:
|
||||
"""
|
||||
FastAPI dependency that provides the Redis client.
|
||||
|
||||
Usage:
|
||||
@router.get("/cached-data")
|
||||
async def get_data(redis: RedisClient = Depends(get_redis)):
|
||||
cached = await redis.cache_get("my-key")
|
||||
...
|
||||
"""
|
||||
yield redis_client
|
||||
|
||||
|
||||
# Health check function for use in /health endpoint
|
||||
async def check_redis_health() -> bool:
|
||||
"""
|
||||
Check if Redis connection is healthy.
|
||||
|
||||
Returns:
|
||||
True if connection is successful, False otherwise.
|
||||
"""
|
||||
return await redis_client.health_check()
|
||||
|
||||
|
||||
# Cleanup function for application shutdown
|
||||
async def close_redis() -> None:
|
||||
"""
|
||||
Close Redis connections.
|
||||
|
||||
Should be called during application shutdown.
|
||||
"""
|
||||
await redis_client.close()
|
||||
20
backend/app/crud/syndarix/__init__.py
Normal file
20
backend/app/crud/syndarix/__init__.py
Normal file
@@ -0,0 +1,20 @@
|
||||
# app/crud/syndarix/__init__.py
|
||||
"""
|
||||
Syndarix CRUD operations.
|
||||
|
||||
This package contains CRUD operations for all Syndarix domain entities.
|
||||
"""
|
||||
|
||||
from .agent_instance import agent_instance
|
||||
from .agent_type import agent_type
|
||||
from .issue import issue
|
||||
from .project import project
|
||||
from .sprint import sprint
|
||||
|
||||
__all__ = [
|
||||
"agent_instance",
|
||||
"agent_type",
|
||||
"issue",
|
||||
"project",
|
||||
"sprint",
|
||||
]
|
||||
346
backend/app/crud/syndarix/agent_instance.py
Normal file
346
backend/app/crud/syndarix/agent_instance.py
Normal file
@@ -0,0 +1,346 @@
|
||||
# app/crud/syndarix/agent_instance.py
|
||||
"""Async CRUD operations for AgentInstance model using SQLAlchemy 2.0 patterns."""
|
||||
|
||||
import logging
|
||||
from datetime import UTC, datetime
|
||||
from decimal import Decimal
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import func, select, update
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import joinedload
|
||||
|
||||
from app.crud.base import CRUDBase
|
||||
from app.models.syndarix import AgentInstance, Issue
|
||||
from app.models.syndarix.enums import AgentStatus
|
||||
from app.schemas.syndarix import AgentInstanceCreate, AgentInstanceUpdate
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CRUDAgentInstance(CRUDBase[AgentInstance, AgentInstanceCreate, AgentInstanceUpdate]):
|
||||
"""Async CRUD operations for AgentInstance model."""
|
||||
|
||||
async def create(
|
||||
self, db: AsyncSession, *, obj_in: AgentInstanceCreate
|
||||
) -> AgentInstance:
|
||||
"""Create a new agent instance with error handling."""
|
||||
try:
|
||||
db_obj = AgentInstance(
|
||||
agent_type_id=obj_in.agent_type_id,
|
||||
project_id=obj_in.project_id,
|
||||
status=obj_in.status,
|
||||
current_task=obj_in.current_task,
|
||||
short_term_memory=obj_in.short_term_memory,
|
||||
long_term_memory_ref=obj_in.long_term_memory_ref,
|
||||
session_id=obj_in.session_id,
|
||||
)
|
||||
db.add(db_obj)
|
||||
await db.commit()
|
||||
await db.refresh(db_obj)
|
||||
return db_obj
|
||||
except IntegrityError as e:
|
||||
await db.rollback()
|
||||
error_msg = str(e.orig) if hasattr(e, "orig") else str(e)
|
||||
logger.error(f"Integrity error creating agent instance: {error_msg}")
|
||||
raise ValueError(f"Database integrity error: {error_msg}")
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
logger.error(
|
||||
f"Unexpected error creating agent instance: {e!s}", exc_info=True
|
||||
)
|
||||
raise
|
||||
|
||||
async def get_with_details(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
instance_id: UUID,
|
||||
) -> dict[str, Any] | None:
|
||||
"""
|
||||
Get an agent instance with full details including related entities.
|
||||
|
||||
Returns:
|
||||
Dictionary with instance and related entity details
|
||||
"""
|
||||
try:
|
||||
# Get instance with joined relationships
|
||||
result = await db.execute(
|
||||
select(AgentInstance)
|
||||
.options(
|
||||
joinedload(AgentInstance.agent_type),
|
||||
joinedload(AgentInstance.project),
|
||||
)
|
||||
.where(AgentInstance.id == instance_id)
|
||||
)
|
||||
instance = result.scalar_one_or_none()
|
||||
|
||||
if not instance:
|
||||
return None
|
||||
|
||||
# Get assigned issues count
|
||||
issues_count_result = await db.execute(
|
||||
select(func.count(Issue.id)).where(
|
||||
Issue.assigned_agent_id == instance_id
|
||||
)
|
||||
)
|
||||
assigned_issues_count = issues_count_result.scalar_one()
|
||||
|
||||
return {
|
||||
"instance": instance,
|
||||
"agent_type_name": instance.agent_type.name if instance.agent_type else None,
|
||||
"agent_type_slug": instance.agent_type.slug if instance.agent_type else None,
|
||||
"project_name": instance.project.name if instance.project else None,
|
||||
"project_slug": instance.project.slug if instance.project else None,
|
||||
"assigned_issues_count": assigned_issues_count,
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error getting agent instance with details {instance_id}: {e!s}",
|
||||
exc_info=True,
|
||||
)
|
||||
raise
|
||||
|
||||
async def get_by_project(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
project_id: UUID,
|
||||
status: AgentStatus | None = None,
|
||||
skip: int = 0,
|
||||
limit: int = 100,
|
||||
) -> tuple[list[AgentInstance], int]:
|
||||
"""Get agent instances for a specific project."""
|
||||
try:
|
||||
query = select(AgentInstance).where(
|
||||
AgentInstance.project_id == project_id
|
||||
)
|
||||
|
||||
if status is not None:
|
||||
query = query.where(AgentInstance.status == status)
|
||||
|
||||
# Get total count
|
||||
count_query = select(func.count()).select_from(query.alias())
|
||||
count_result = await db.execute(count_query)
|
||||
total = count_result.scalar_one()
|
||||
|
||||
# Apply pagination
|
||||
query = query.order_by(AgentInstance.created_at.desc())
|
||||
query = query.offset(skip).limit(limit)
|
||||
result = await db.execute(query)
|
||||
instances = list(result.scalars().all())
|
||||
|
||||
return instances, total
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error getting instances by project {project_id}: {e!s}",
|
||||
exc_info=True,
|
||||
)
|
||||
raise
|
||||
|
||||
async def get_by_agent_type(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
agent_type_id: UUID,
|
||||
status: AgentStatus | None = None,
|
||||
) -> list[AgentInstance]:
|
||||
"""Get all instances of a specific agent type."""
|
||||
try:
|
||||
query = select(AgentInstance).where(
|
||||
AgentInstance.agent_type_id == agent_type_id
|
||||
)
|
||||
|
||||
if status is not None:
|
||||
query = query.where(AgentInstance.status == status)
|
||||
|
||||
query = query.order_by(AgentInstance.created_at.desc())
|
||||
result = await db.execute(query)
|
||||
return list(result.scalars().all())
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error getting instances by agent type {agent_type_id}: {e!s}",
|
||||
exc_info=True,
|
||||
)
|
||||
raise
|
||||
|
||||
async def update_status(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
instance_id: UUID,
|
||||
status: AgentStatus,
|
||||
current_task: str | None = None,
|
||||
) -> AgentInstance | None:
|
||||
"""Update the status of an agent instance."""
|
||||
try:
|
||||
result = await db.execute(
|
||||
select(AgentInstance).where(AgentInstance.id == instance_id)
|
||||
)
|
||||
instance = result.scalar_one_or_none()
|
||||
|
||||
if not instance:
|
||||
return None
|
||||
|
||||
instance.status = status
|
||||
instance.last_activity_at = datetime.now(UTC)
|
||||
if current_task is not None:
|
||||
instance.current_task = current_task
|
||||
|
||||
await db.commit()
|
||||
await db.refresh(instance)
|
||||
return instance
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
logger.error(
|
||||
f"Error updating instance status {instance_id}: {e!s}", exc_info=True
|
||||
)
|
||||
raise
|
||||
|
||||
async def terminate(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
instance_id: UUID,
|
||||
) -> AgentInstance | None:
|
||||
"""Terminate an agent instance."""
|
||||
try:
|
||||
result = await db.execute(
|
||||
select(AgentInstance).where(AgentInstance.id == instance_id)
|
||||
)
|
||||
instance = result.scalar_one_or_none()
|
||||
|
||||
if not instance:
|
||||
return None
|
||||
|
||||
instance.status = AgentStatus.TERMINATED
|
||||
instance.terminated_at = datetime.now(UTC)
|
||||
instance.current_task = None
|
||||
instance.session_id = None
|
||||
|
||||
await db.commit()
|
||||
await db.refresh(instance)
|
||||
return instance
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
logger.error(
|
||||
f"Error terminating instance {instance_id}: {e!s}", exc_info=True
|
||||
)
|
||||
raise
|
||||
|
||||
async def record_task_completion(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
instance_id: UUID,
|
||||
tokens_used: int,
|
||||
cost_incurred: Decimal,
|
||||
) -> AgentInstance | None:
|
||||
"""Record a completed task and update metrics."""
|
||||
try:
|
||||
result = await db.execute(
|
||||
select(AgentInstance).where(AgentInstance.id == instance_id)
|
||||
)
|
||||
instance = result.scalar_one_or_none()
|
||||
|
||||
if not instance:
|
||||
return None
|
||||
|
||||
instance.tasks_completed += 1
|
||||
instance.tokens_used += tokens_used
|
||||
instance.cost_incurred += cost_incurred
|
||||
instance.last_activity_at = datetime.now(UTC)
|
||||
|
||||
await db.commit()
|
||||
await db.refresh(instance)
|
||||
return instance
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
logger.error(
|
||||
f"Error recording task completion {instance_id}: {e!s}", exc_info=True
|
||||
)
|
||||
raise
|
||||
|
||||
async def get_project_metrics(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
project_id: UUID,
|
||||
) -> dict[str, Any]:
|
||||
"""Get aggregated metrics for all agents in a project."""
|
||||
try:
|
||||
result = await db.execute(
|
||||
select(
|
||||
func.count(AgentInstance.id).label("total_instances"),
|
||||
func.count(AgentInstance.id)
|
||||
.filter(AgentInstance.status == AgentStatus.WORKING)
|
||||
.label("active_instances"),
|
||||
func.count(AgentInstance.id)
|
||||
.filter(AgentInstance.status == AgentStatus.IDLE)
|
||||
.label("idle_instances"),
|
||||
func.sum(AgentInstance.tasks_completed).label("total_tasks"),
|
||||
func.sum(AgentInstance.tokens_used).label("total_tokens"),
|
||||
func.sum(AgentInstance.cost_incurred).label("total_cost"),
|
||||
).where(AgentInstance.project_id == project_id)
|
||||
)
|
||||
row = result.one()
|
||||
|
||||
return {
|
||||
"total_instances": row.total_instances or 0,
|
||||
"active_instances": row.active_instances or 0,
|
||||
"idle_instances": row.idle_instances or 0,
|
||||
"total_tasks_completed": row.total_tasks or 0,
|
||||
"total_tokens_used": row.total_tokens or 0,
|
||||
"total_cost_incurred": row.total_cost or Decimal("0.0000"),
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error getting project metrics {project_id}: {e!s}", exc_info=True
|
||||
)
|
||||
raise
|
||||
|
||||
async def bulk_terminate_by_project(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
project_id: UUID,
|
||||
) -> int:
|
||||
"""Terminate all active instances in a project."""
|
||||
try:
|
||||
now = datetime.now(UTC)
|
||||
stmt = (
|
||||
update(AgentInstance)
|
||||
.where(
|
||||
AgentInstance.project_id == project_id,
|
||||
AgentInstance.status != AgentStatus.TERMINATED,
|
||||
)
|
||||
.values(
|
||||
status=AgentStatus.TERMINATED,
|
||||
terminated_at=now,
|
||||
current_task=None,
|
||||
session_id=None,
|
||||
updated_at=now,
|
||||
)
|
||||
)
|
||||
|
||||
result = await db.execute(stmt)
|
||||
await db.commit()
|
||||
|
||||
terminated_count = result.rowcount
|
||||
logger.info(
|
||||
f"Bulk terminated {terminated_count} instances in project {project_id}"
|
||||
)
|
||||
return terminated_count
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
logger.error(
|
||||
f"Error bulk terminating instances for project {project_id}: {e!s}",
|
||||
exc_info=True,
|
||||
)
|
||||
raise
|
||||
|
||||
|
||||
# Create a singleton instance for use across the application
|
||||
agent_instance = CRUDAgentInstance(AgentInstance)
|
||||
275
backend/app/crud/syndarix/agent_type.py
Normal file
275
backend/app/crud/syndarix/agent_type.py
Normal file
@@ -0,0 +1,275 @@
|
||||
# app/crud/syndarix/agent_type.py
|
||||
"""Async CRUD operations for AgentType model using SQLAlchemy 2.0 patterns."""
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import func, or_, select
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.crud.base import CRUDBase
|
||||
from app.models.syndarix import AgentInstance, AgentType
|
||||
from app.schemas.syndarix import AgentTypeCreate, AgentTypeUpdate
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CRUDAgentType(CRUDBase[AgentType, AgentTypeCreate, AgentTypeUpdate]):
|
||||
"""Async CRUD operations for AgentType model."""
|
||||
|
||||
async def get_by_slug(self, db: AsyncSession, *, slug: str) -> AgentType | None:
|
||||
"""Get agent type by slug."""
|
||||
try:
|
||||
result = await db.execute(
|
||||
select(AgentType).where(AgentType.slug == slug)
|
||||
)
|
||||
return result.scalar_one_or_none()
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting agent type by slug {slug}: {e!s}")
|
||||
raise
|
||||
|
||||
async def create(
|
||||
self, db: AsyncSession, *, obj_in: AgentTypeCreate
|
||||
) -> AgentType:
|
||||
"""Create a new agent type with error handling."""
|
||||
try:
|
||||
db_obj = AgentType(
|
||||
name=obj_in.name,
|
||||
slug=obj_in.slug,
|
||||
description=obj_in.description,
|
||||
expertise=obj_in.expertise,
|
||||
personality_prompt=obj_in.personality_prompt,
|
||||
primary_model=obj_in.primary_model,
|
||||
fallback_models=obj_in.fallback_models,
|
||||
model_params=obj_in.model_params,
|
||||
mcp_servers=obj_in.mcp_servers,
|
||||
tool_permissions=obj_in.tool_permissions,
|
||||
is_active=obj_in.is_active,
|
||||
)
|
||||
db.add(db_obj)
|
||||
await db.commit()
|
||||
await db.refresh(db_obj)
|
||||
return db_obj
|
||||
except IntegrityError as e:
|
||||
await db.rollback()
|
||||
error_msg = str(e.orig) if hasattr(e, "orig") else str(e)
|
||||
if "slug" in error_msg.lower():
|
||||
logger.warning(f"Duplicate slug attempted: {obj_in.slug}")
|
||||
raise ValueError(
|
||||
f"Agent type with slug '{obj_in.slug}' already exists"
|
||||
)
|
||||
logger.error(f"Integrity error creating agent type: {error_msg}")
|
||||
raise ValueError(f"Database integrity error: {error_msg}")
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
logger.error(
|
||||
f"Unexpected error creating agent type: {e!s}", exc_info=True
|
||||
)
|
||||
raise
|
||||
|
||||
async def get_multi_with_filters(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
skip: int = 0,
|
||||
limit: int = 100,
|
||||
is_active: bool | None = None,
|
||||
search: str | None = None,
|
||||
sort_by: str = "created_at",
|
||||
sort_order: str = "desc",
|
||||
) -> tuple[list[AgentType], int]:
|
||||
"""
|
||||
Get multiple agent types with filtering, searching, and sorting.
|
||||
|
||||
Returns:
|
||||
Tuple of (agent types list, total count)
|
||||
"""
|
||||
try:
|
||||
query = select(AgentType)
|
||||
|
||||
# Apply filters
|
||||
if is_active is not None:
|
||||
query = query.where(AgentType.is_active == is_active)
|
||||
|
||||
if search:
|
||||
search_filter = or_(
|
||||
AgentType.name.ilike(f"%{search}%"),
|
||||
AgentType.slug.ilike(f"%{search}%"),
|
||||
AgentType.description.ilike(f"%{search}%"),
|
||||
)
|
||||
query = query.where(search_filter)
|
||||
|
||||
# Get total count before pagination
|
||||
count_query = select(func.count()).select_from(query.alias())
|
||||
count_result = await db.execute(count_query)
|
||||
total = count_result.scalar_one()
|
||||
|
||||
# Apply sorting
|
||||
sort_column = getattr(AgentType, sort_by, AgentType.created_at)
|
||||
if sort_order == "desc":
|
||||
query = query.order_by(sort_column.desc())
|
||||
else:
|
||||
query = query.order_by(sort_column.asc())
|
||||
|
||||
# Apply pagination
|
||||
query = query.offset(skip).limit(limit)
|
||||
result = await db.execute(query)
|
||||
agent_types = list(result.scalars().all())
|
||||
|
||||
return agent_types, total
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting agent types with filters: {e!s}")
|
||||
raise
|
||||
|
||||
async def get_with_instance_count(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
agent_type_id: UUID,
|
||||
) -> dict[str, Any] | None:
|
||||
"""
|
||||
Get a single agent type with its instance count.
|
||||
|
||||
Returns:
|
||||
Dictionary with agent_type and instance_count
|
||||
"""
|
||||
try:
|
||||
result = await db.execute(
|
||||
select(AgentType).where(AgentType.id == agent_type_id)
|
||||
)
|
||||
agent_type = result.scalar_one_or_none()
|
||||
|
||||
if not agent_type:
|
||||
return None
|
||||
|
||||
# Get instance count
|
||||
count_result = await db.execute(
|
||||
select(func.count(AgentInstance.id)).where(
|
||||
AgentInstance.agent_type_id == agent_type_id
|
||||
)
|
||||
)
|
||||
instance_count = count_result.scalar_one()
|
||||
|
||||
return {
|
||||
"agent_type": agent_type,
|
||||
"instance_count": instance_count,
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error getting agent type with count {agent_type_id}: {e!s}",
|
||||
exc_info=True,
|
||||
)
|
||||
raise
|
||||
|
||||
async def get_multi_with_instance_counts(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
skip: int = 0,
|
||||
limit: int = 100,
|
||||
is_active: bool | None = None,
|
||||
search: str | None = None,
|
||||
) -> tuple[list[dict[str, Any]], int]:
|
||||
"""
|
||||
Get agent types with instance counts in optimized queries.
|
||||
|
||||
Returns:
|
||||
Tuple of (list of dicts with agent_type and instance_count, total count)
|
||||
"""
|
||||
try:
|
||||
# Get filtered agent types
|
||||
agent_types, total = await self.get_multi_with_filters(
|
||||
db,
|
||||
skip=skip,
|
||||
limit=limit,
|
||||
is_active=is_active,
|
||||
search=search,
|
||||
)
|
||||
|
||||
if not agent_types:
|
||||
return [], 0
|
||||
|
||||
agent_type_ids = [at.id for at in agent_types]
|
||||
|
||||
# Get instance counts in bulk
|
||||
counts_result = await db.execute(
|
||||
select(
|
||||
AgentInstance.agent_type_id,
|
||||
func.count(AgentInstance.id).label("count"),
|
||||
)
|
||||
.where(AgentInstance.agent_type_id.in_(agent_type_ids))
|
||||
.group_by(AgentInstance.agent_type_id)
|
||||
)
|
||||
counts = {row.agent_type_id: row.count for row in counts_result}
|
||||
|
||||
# Combine results
|
||||
results = [
|
||||
{
|
||||
"agent_type": agent_type,
|
||||
"instance_count": counts.get(agent_type.id, 0),
|
||||
}
|
||||
for agent_type in agent_types
|
||||
]
|
||||
|
||||
return results, total
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error getting agent types with counts: {e!s}", exc_info=True
|
||||
)
|
||||
raise
|
||||
|
||||
async def get_by_expertise(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
expertise: str,
|
||||
is_active: bool = True,
|
||||
) -> list[AgentType]:
|
||||
"""Get agent types that have a specific expertise."""
|
||||
try:
|
||||
# Use PostgreSQL JSONB contains operator
|
||||
query = select(AgentType).where(
|
||||
AgentType.expertise.contains([expertise.lower()]),
|
||||
AgentType.is_active == is_active,
|
||||
)
|
||||
result = await db.execute(query)
|
||||
return list(result.scalars().all())
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error getting agent types by expertise {expertise}: {e!s}",
|
||||
exc_info=True,
|
||||
)
|
||||
raise
|
||||
|
||||
async def deactivate(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
agent_type_id: UUID,
|
||||
) -> AgentType | None:
|
||||
"""Deactivate an agent type (soft delete)."""
|
||||
try:
|
||||
result = await db.execute(
|
||||
select(AgentType).where(AgentType.id == agent_type_id)
|
||||
)
|
||||
agent_type = result.scalar_one_or_none()
|
||||
|
||||
if not agent_type:
|
||||
return None
|
||||
|
||||
agent_type.is_active = False
|
||||
await db.commit()
|
||||
await db.refresh(agent_type)
|
||||
return agent_type
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
logger.error(
|
||||
f"Error deactivating agent type {agent_type_id}: {e!s}", exc_info=True
|
||||
)
|
||||
raise
|
||||
|
||||
|
||||
# Create a singleton instance for use across the application
|
||||
agent_type = CRUDAgentType(AgentType)
|
||||
437
backend/app/crud/syndarix/issue.py
Normal file
437
backend/app/crud/syndarix/issue.py
Normal file
@@ -0,0 +1,437 @@
|
||||
# app/crud/syndarix/issue.py
|
||||
"""Async CRUD operations for Issue model using SQLAlchemy 2.0 patterns."""
|
||||
|
||||
import logging
|
||||
from datetime import UTC, datetime
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import func, or_, select
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import joinedload
|
||||
|
||||
from app.crud.base import CRUDBase
|
||||
from app.models.syndarix import AgentInstance, Issue
|
||||
from app.models.syndarix.enums import IssuePriority, IssueStatus, SyncStatus
|
||||
from app.schemas.syndarix import IssueCreate, IssueUpdate
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CRUDIssue(CRUDBase[Issue, IssueCreate, IssueUpdate]):
|
||||
"""Async CRUD operations for Issue model."""
|
||||
|
||||
async def create(self, db: AsyncSession, *, obj_in: IssueCreate) -> Issue:
|
||||
"""Create a new issue with error handling."""
|
||||
try:
|
||||
db_obj = Issue(
|
||||
project_id=obj_in.project_id,
|
||||
title=obj_in.title,
|
||||
body=obj_in.body,
|
||||
status=obj_in.status,
|
||||
priority=obj_in.priority,
|
||||
labels=obj_in.labels,
|
||||
assigned_agent_id=obj_in.assigned_agent_id,
|
||||
human_assignee=obj_in.human_assignee,
|
||||
sprint_id=obj_in.sprint_id,
|
||||
story_points=obj_in.story_points,
|
||||
external_tracker=obj_in.external_tracker,
|
||||
external_id=obj_in.external_id,
|
||||
external_url=obj_in.external_url,
|
||||
external_number=obj_in.external_number,
|
||||
sync_status=SyncStatus.SYNCED,
|
||||
)
|
||||
db.add(db_obj)
|
||||
await db.commit()
|
||||
await db.refresh(db_obj)
|
||||
return db_obj
|
||||
except IntegrityError as e:
|
||||
await db.rollback()
|
||||
error_msg = str(e.orig) if hasattr(e, "orig") else str(e)
|
||||
logger.error(f"Integrity error creating issue: {error_msg}")
|
||||
raise ValueError(f"Database integrity error: {error_msg}")
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
logger.error(f"Unexpected error creating issue: {e!s}", exc_info=True)
|
||||
raise
|
||||
|
||||
async def get_with_details(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
issue_id: UUID,
|
||||
) -> dict[str, Any] | None:
|
||||
"""
|
||||
Get an issue with full details including related entity names.
|
||||
|
||||
Returns:
|
||||
Dictionary with issue and related entity details
|
||||
"""
|
||||
try:
|
||||
# Get issue with joined relationships
|
||||
result = await db.execute(
|
||||
select(Issue)
|
||||
.options(
|
||||
joinedload(Issue.project),
|
||||
joinedload(Issue.sprint),
|
||||
joinedload(Issue.assigned_agent).joinedload(AgentInstance.agent_type),
|
||||
)
|
||||
.where(Issue.id == issue_id)
|
||||
)
|
||||
issue = result.scalar_one_or_none()
|
||||
|
||||
if not issue:
|
||||
return None
|
||||
|
||||
return {
|
||||
"issue": issue,
|
||||
"project_name": issue.project.name if issue.project else None,
|
||||
"project_slug": issue.project.slug if issue.project else None,
|
||||
"sprint_name": issue.sprint.name if issue.sprint else None,
|
||||
"assigned_agent_type_name": (
|
||||
issue.assigned_agent.agent_type.name
|
||||
if issue.assigned_agent and issue.assigned_agent.agent_type
|
||||
else None
|
||||
),
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error getting issue with details {issue_id}: {e!s}", exc_info=True
|
||||
)
|
||||
raise
|
||||
|
||||
async def get_by_project(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
project_id: UUID,
|
||||
status: IssueStatus | None = None,
|
||||
priority: IssuePriority | None = None,
|
||||
sprint_id: UUID | None = None,
|
||||
assigned_agent_id: UUID | None = None,
|
||||
labels: list[str] | None = None,
|
||||
search: str | None = None,
|
||||
skip: int = 0,
|
||||
limit: int = 100,
|
||||
sort_by: str = "created_at",
|
||||
sort_order: str = "desc",
|
||||
) -> tuple[list[Issue], int]:
|
||||
"""Get issues for a specific project with filters."""
|
||||
try:
|
||||
query = select(Issue).where(Issue.project_id == project_id)
|
||||
|
||||
# Apply filters
|
||||
if status is not None:
|
||||
query = query.where(Issue.status == status)
|
||||
|
||||
if priority is not None:
|
||||
query = query.where(Issue.priority == priority)
|
||||
|
||||
if sprint_id is not None:
|
||||
query = query.where(Issue.sprint_id == sprint_id)
|
||||
|
||||
if assigned_agent_id is not None:
|
||||
query = query.where(Issue.assigned_agent_id == assigned_agent_id)
|
||||
|
||||
if labels:
|
||||
# Match any of the provided labels
|
||||
for label in labels:
|
||||
query = query.where(Issue.labels.contains([label.lower()]))
|
||||
|
||||
if search:
|
||||
search_filter = or_(
|
||||
Issue.title.ilike(f"%{search}%"),
|
||||
Issue.body.ilike(f"%{search}%"),
|
||||
)
|
||||
query = query.where(search_filter)
|
||||
|
||||
# Get total count
|
||||
count_query = select(func.count()).select_from(query.alias())
|
||||
count_result = await db.execute(count_query)
|
||||
total = count_result.scalar_one()
|
||||
|
||||
# Apply sorting
|
||||
sort_column = getattr(Issue, sort_by, Issue.created_at)
|
||||
if sort_order == "desc":
|
||||
query = query.order_by(sort_column.desc())
|
||||
else:
|
||||
query = query.order_by(sort_column.asc())
|
||||
|
||||
# Apply pagination
|
||||
query = query.offset(skip).limit(limit)
|
||||
result = await db.execute(query)
|
||||
issues = list(result.scalars().all())
|
||||
|
||||
return issues, total
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error getting issues by project {project_id}: {e!s}", exc_info=True
|
||||
)
|
||||
raise
|
||||
|
||||
async def get_by_sprint(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
sprint_id: UUID,
|
||||
status: IssueStatus | None = None,
|
||||
) -> list[Issue]:
|
||||
"""Get all issues in a sprint."""
|
||||
try:
|
||||
query = select(Issue).where(Issue.sprint_id == sprint_id)
|
||||
|
||||
if status is not None:
|
||||
query = query.where(Issue.status == status)
|
||||
|
||||
query = query.order_by(Issue.priority.desc(), Issue.created_at.asc())
|
||||
result = await db.execute(query)
|
||||
return list(result.scalars().all())
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error getting issues by sprint {sprint_id}: {e!s}", exc_info=True
|
||||
)
|
||||
raise
|
||||
|
||||
async def assign_to_agent(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
issue_id: UUID,
|
||||
agent_id: UUID | None,
|
||||
) -> Issue | None:
|
||||
"""Assign an issue to an agent (or unassign if agent_id is None)."""
|
||||
try:
|
||||
result = await db.execute(select(Issue).where(Issue.id == issue_id))
|
||||
issue = result.scalar_one_or_none()
|
||||
|
||||
if not issue:
|
||||
return None
|
||||
|
||||
issue.assigned_agent_id = agent_id
|
||||
issue.human_assignee = None # Clear human assignee when assigning to agent
|
||||
await db.commit()
|
||||
await db.refresh(issue)
|
||||
return issue
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
logger.error(
|
||||
f"Error assigning issue {issue_id} to agent {agent_id}: {e!s}",
|
||||
exc_info=True,
|
||||
)
|
||||
raise
|
||||
|
||||
async def assign_to_human(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
issue_id: UUID,
|
||||
human_assignee: str | None,
|
||||
) -> Issue | None:
|
||||
"""Assign an issue to a human (or unassign if human_assignee is None)."""
|
||||
try:
|
||||
result = await db.execute(select(Issue).where(Issue.id == issue_id))
|
||||
issue = result.scalar_one_or_none()
|
||||
|
||||
if not issue:
|
||||
return None
|
||||
|
||||
issue.human_assignee = human_assignee
|
||||
issue.assigned_agent_id = None # Clear agent when assigning to human
|
||||
await db.commit()
|
||||
await db.refresh(issue)
|
||||
return issue
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
logger.error(
|
||||
f"Error assigning issue {issue_id} to human {human_assignee}: {e!s}",
|
||||
exc_info=True,
|
||||
)
|
||||
raise
|
||||
|
||||
async def close_issue(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
issue_id: UUID,
|
||||
) -> Issue | None:
|
||||
"""Close an issue by setting status and closed_at timestamp."""
|
||||
try:
|
||||
result = await db.execute(select(Issue).where(Issue.id == issue_id))
|
||||
issue = result.scalar_one_or_none()
|
||||
|
||||
if not issue:
|
||||
return None
|
||||
|
||||
issue.status = IssueStatus.CLOSED
|
||||
issue.closed_at = datetime.now(UTC)
|
||||
await db.commit()
|
||||
await db.refresh(issue)
|
||||
return issue
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
logger.error(f"Error closing issue {issue_id}: {e!s}", exc_info=True)
|
||||
raise
|
||||
|
||||
async def reopen_issue(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
issue_id: UUID,
|
||||
) -> Issue | None:
|
||||
"""Reopen a closed issue."""
|
||||
try:
|
||||
result = await db.execute(select(Issue).where(Issue.id == issue_id))
|
||||
issue = result.scalar_one_or_none()
|
||||
|
||||
if not issue:
|
||||
return None
|
||||
|
||||
issue.status = IssueStatus.OPEN
|
||||
issue.closed_at = None
|
||||
await db.commit()
|
||||
await db.refresh(issue)
|
||||
return issue
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
logger.error(f"Error reopening issue {issue_id}: {e!s}", exc_info=True)
|
||||
raise
|
||||
|
||||
async def update_sync_status(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
issue_id: UUID,
|
||||
sync_status: SyncStatus,
|
||||
last_synced_at: datetime | None = None,
|
||||
external_updated_at: datetime | None = None,
|
||||
) -> Issue | None:
|
||||
"""Update the sync status of an issue."""
|
||||
try:
|
||||
result = await db.execute(select(Issue).where(Issue.id == issue_id))
|
||||
issue = result.scalar_one_or_none()
|
||||
|
||||
if not issue:
|
||||
return None
|
||||
|
||||
issue.sync_status = sync_status
|
||||
if last_synced_at:
|
||||
issue.last_synced_at = last_synced_at
|
||||
if external_updated_at:
|
||||
issue.external_updated_at = external_updated_at
|
||||
|
||||
await db.commit()
|
||||
await db.refresh(issue)
|
||||
return issue
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
logger.error(
|
||||
f"Error updating sync status for issue {issue_id}: {e!s}", exc_info=True
|
||||
)
|
||||
raise
|
||||
|
||||
async def get_project_stats(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
project_id: UUID,
|
||||
) -> dict[str, Any]:
|
||||
"""Get issue statistics for a project."""
|
||||
try:
|
||||
# Get counts by status
|
||||
status_counts = await db.execute(
|
||||
select(Issue.status, func.count(Issue.id).label("count"))
|
||||
.where(Issue.project_id == project_id)
|
||||
.group_by(Issue.status)
|
||||
)
|
||||
by_status = {row.status.value: row.count for row in status_counts}
|
||||
|
||||
# Get counts by priority
|
||||
priority_counts = await db.execute(
|
||||
select(Issue.priority, func.count(Issue.id).label("count"))
|
||||
.where(Issue.project_id == project_id)
|
||||
.group_by(Issue.priority)
|
||||
)
|
||||
by_priority = {row.priority.value: row.count for row in priority_counts}
|
||||
|
||||
# Get story points
|
||||
points_result = await db.execute(
|
||||
select(
|
||||
func.sum(Issue.story_points).label("total"),
|
||||
func.sum(Issue.story_points)
|
||||
.filter(Issue.status == IssueStatus.CLOSED)
|
||||
.label("completed"),
|
||||
).where(Issue.project_id == project_id)
|
||||
)
|
||||
points_row = points_result.one()
|
||||
|
||||
total_issues = sum(by_status.values())
|
||||
|
||||
return {
|
||||
"total": total_issues,
|
||||
"open": by_status.get("open", 0),
|
||||
"in_progress": by_status.get("in_progress", 0),
|
||||
"in_review": by_status.get("in_review", 0),
|
||||
"blocked": by_status.get("blocked", 0),
|
||||
"closed": by_status.get("closed", 0),
|
||||
"by_priority": by_priority,
|
||||
"total_story_points": points_row.total,
|
||||
"completed_story_points": points_row.completed,
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error getting issue stats for project {project_id}: {e!s}",
|
||||
exc_info=True,
|
||||
)
|
||||
raise
|
||||
|
||||
async def get_by_external_id(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
external_tracker: str,
|
||||
external_id: str,
|
||||
) -> Issue | None:
|
||||
"""Get an issue by its external tracker ID."""
|
||||
try:
|
||||
result = await db.execute(
|
||||
select(Issue).where(
|
||||
Issue.external_tracker == external_tracker,
|
||||
Issue.external_id == external_id,
|
||||
)
|
||||
)
|
||||
return result.scalar_one_or_none()
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error getting issue by external ID {external_tracker}:{external_id}: {e!s}",
|
||||
exc_info=True,
|
||||
)
|
||||
raise
|
||||
|
||||
async def get_pending_sync(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
project_id: UUID | None = None,
|
||||
limit: int = 100,
|
||||
) -> list[Issue]:
|
||||
"""Get issues that need to be synced with external tracker."""
|
||||
try:
|
||||
query = select(Issue).where(
|
||||
Issue.external_tracker.isnot(None),
|
||||
Issue.sync_status.in_([SyncStatus.PENDING, SyncStatus.ERROR]),
|
||||
)
|
||||
|
||||
if project_id:
|
||||
query = query.where(Issue.project_id == project_id)
|
||||
|
||||
query = query.order_by(Issue.updated_at.asc()).limit(limit)
|
||||
result = await db.execute(query)
|
||||
return list(result.scalars().all())
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting pending sync issues: {e!s}", exc_info=True)
|
||||
raise
|
||||
|
||||
|
||||
# Create a singleton instance for use across the application
|
||||
issue = CRUDIssue(Issue)
|
||||
309
backend/app/crud/syndarix/project.py
Normal file
309
backend/app/crud/syndarix/project.py
Normal file
@@ -0,0 +1,309 @@
|
||||
# app/crud/syndarix/project.py
|
||||
"""Async CRUD operations for Project model using SQLAlchemy 2.0 patterns."""
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import func, or_, select
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.crud.base import CRUDBase
|
||||
from app.models.syndarix import AgentInstance, Issue, Project, Sprint
|
||||
from app.models.syndarix.enums import ProjectStatus, SprintStatus
|
||||
from app.schemas.syndarix import ProjectCreate, ProjectUpdate
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CRUDProject(CRUDBase[Project, ProjectCreate, ProjectUpdate]):
|
||||
"""Async CRUD operations for Project model."""
|
||||
|
||||
async def get_by_slug(self, db: AsyncSession, *, slug: str) -> Project | None:
|
||||
"""Get project by slug."""
|
||||
try:
|
||||
result = await db.execute(select(Project).where(Project.slug == slug))
|
||||
return result.scalar_one_or_none()
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting project by slug {slug}: {e!s}")
|
||||
raise
|
||||
|
||||
async def create(self, db: AsyncSession, *, obj_in: ProjectCreate) -> Project:
|
||||
"""Create a new project with error handling."""
|
||||
try:
|
||||
db_obj = Project(
|
||||
name=obj_in.name,
|
||||
slug=obj_in.slug,
|
||||
description=obj_in.description,
|
||||
autonomy_level=obj_in.autonomy_level,
|
||||
status=obj_in.status,
|
||||
settings=obj_in.settings or {},
|
||||
owner_id=obj_in.owner_id,
|
||||
)
|
||||
db.add(db_obj)
|
||||
await db.commit()
|
||||
await db.refresh(db_obj)
|
||||
return db_obj
|
||||
except IntegrityError as e:
|
||||
await db.rollback()
|
||||
error_msg = str(e.orig) if hasattr(e, "orig") else str(e)
|
||||
if "slug" in error_msg.lower():
|
||||
logger.warning(f"Duplicate slug attempted: {obj_in.slug}")
|
||||
raise ValueError(f"Project with slug '{obj_in.slug}' already exists")
|
||||
logger.error(f"Integrity error creating project: {error_msg}")
|
||||
raise ValueError(f"Database integrity error: {error_msg}")
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
logger.error(f"Unexpected error creating project: {e!s}", exc_info=True)
|
||||
raise
|
||||
|
||||
async def get_multi_with_filters(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
skip: int = 0,
|
||||
limit: int = 100,
|
||||
status: ProjectStatus | None = None,
|
||||
owner_id: UUID | None = None,
|
||||
search: str | None = None,
|
||||
sort_by: str = "created_at",
|
||||
sort_order: str = "desc",
|
||||
) -> tuple[list[Project], int]:
|
||||
"""
|
||||
Get multiple projects with filtering, searching, and sorting.
|
||||
|
||||
Returns:
|
||||
Tuple of (projects list, total count)
|
||||
"""
|
||||
try:
|
||||
query = select(Project)
|
||||
|
||||
# Apply filters
|
||||
if status is not None:
|
||||
query = query.where(Project.status == status)
|
||||
|
||||
if owner_id is not None:
|
||||
query = query.where(Project.owner_id == owner_id)
|
||||
|
||||
if search:
|
||||
search_filter = or_(
|
||||
Project.name.ilike(f"%{search}%"),
|
||||
Project.slug.ilike(f"%{search}%"),
|
||||
Project.description.ilike(f"%{search}%"),
|
||||
)
|
||||
query = query.where(search_filter)
|
||||
|
||||
# Get total count before pagination
|
||||
count_query = select(func.count()).select_from(query.alias())
|
||||
count_result = await db.execute(count_query)
|
||||
total = count_result.scalar_one()
|
||||
|
||||
# Apply sorting
|
||||
sort_column = getattr(Project, sort_by, Project.created_at)
|
||||
if sort_order == "desc":
|
||||
query = query.order_by(sort_column.desc())
|
||||
else:
|
||||
query = query.order_by(sort_column.asc())
|
||||
|
||||
# Apply pagination
|
||||
query = query.offset(skip).limit(limit)
|
||||
result = await db.execute(query)
|
||||
projects = list(result.scalars().all())
|
||||
|
||||
return projects, total
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting projects with filters: {e!s}")
|
||||
raise
|
||||
|
||||
async def get_with_counts(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
project_id: UUID,
|
||||
) -> dict[str, Any] | None:
|
||||
"""
|
||||
Get a single project with agent and issue counts.
|
||||
|
||||
Returns:
|
||||
Dictionary with project, agent_count, issue_count, active_sprint_name
|
||||
"""
|
||||
try:
|
||||
# Get project
|
||||
result = await db.execute(select(Project).where(Project.id == project_id))
|
||||
project = result.scalar_one_or_none()
|
||||
|
||||
if not project:
|
||||
return None
|
||||
|
||||
# Get agent count
|
||||
agent_count_result = await db.execute(
|
||||
select(func.count(AgentInstance.id)).where(
|
||||
AgentInstance.project_id == project_id
|
||||
)
|
||||
)
|
||||
agent_count = agent_count_result.scalar_one()
|
||||
|
||||
# Get issue count
|
||||
issue_count_result = await db.execute(
|
||||
select(func.count(Issue.id)).where(Issue.project_id == project_id)
|
||||
)
|
||||
issue_count = issue_count_result.scalar_one()
|
||||
|
||||
# Get active sprint name
|
||||
active_sprint_result = await db.execute(
|
||||
select(Sprint.name).where(
|
||||
Sprint.project_id == project_id,
|
||||
Sprint.status == SprintStatus.ACTIVE,
|
||||
)
|
||||
)
|
||||
active_sprint_name = active_sprint_result.scalar_one_or_none()
|
||||
|
||||
return {
|
||||
"project": project,
|
||||
"agent_count": agent_count,
|
||||
"issue_count": issue_count,
|
||||
"active_sprint_name": active_sprint_name,
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error getting project with counts {project_id}: {e!s}", exc_info=True
|
||||
)
|
||||
raise
|
||||
|
||||
async def get_multi_with_counts(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
skip: int = 0,
|
||||
limit: int = 100,
|
||||
status: ProjectStatus | None = None,
|
||||
owner_id: UUID | None = None,
|
||||
search: str | None = None,
|
||||
) -> tuple[list[dict[str, Any]], int]:
|
||||
"""
|
||||
Get projects with agent/issue counts in optimized queries.
|
||||
|
||||
Returns:
|
||||
Tuple of (list of dicts with project and counts, total count)
|
||||
"""
|
||||
try:
|
||||
# Get filtered projects
|
||||
projects, total = await self.get_multi_with_filters(
|
||||
db,
|
||||
skip=skip,
|
||||
limit=limit,
|
||||
status=status,
|
||||
owner_id=owner_id,
|
||||
search=search,
|
||||
)
|
||||
|
||||
if not projects:
|
||||
return [], 0
|
||||
|
||||
project_ids = [p.id for p in projects]
|
||||
|
||||
# Get agent counts in bulk
|
||||
agent_counts_result = await db.execute(
|
||||
select(
|
||||
AgentInstance.project_id,
|
||||
func.count(AgentInstance.id).label("count"),
|
||||
)
|
||||
.where(AgentInstance.project_id.in_(project_ids))
|
||||
.group_by(AgentInstance.project_id)
|
||||
)
|
||||
agent_counts = {row.project_id: row.count for row in agent_counts_result}
|
||||
|
||||
# Get issue counts in bulk
|
||||
issue_counts_result = await db.execute(
|
||||
select(
|
||||
Issue.project_id,
|
||||
func.count(Issue.id).label("count"),
|
||||
)
|
||||
.where(Issue.project_id.in_(project_ids))
|
||||
.group_by(Issue.project_id)
|
||||
)
|
||||
issue_counts = {row.project_id: row.count for row in issue_counts_result}
|
||||
|
||||
# Get active sprint names
|
||||
active_sprints_result = await db.execute(
|
||||
select(Sprint.project_id, Sprint.name).where(
|
||||
Sprint.project_id.in_(project_ids),
|
||||
Sprint.status == SprintStatus.ACTIVE,
|
||||
)
|
||||
)
|
||||
active_sprints = {
|
||||
row.project_id: row.name for row in active_sprints_result
|
||||
}
|
||||
|
||||
# Combine results
|
||||
results = [
|
||||
{
|
||||
"project": project,
|
||||
"agent_count": agent_counts.get(project.id, 0),
|
||||
"issue_count": issue_counts.get(project.id, 0),
|
||||
"active_sprint_name": active_sprints.get(project.id),
|
||||
}
|
||||
for project in projects
|
||||
]
|
||||
|
||||
return results, total
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error getting projects with counts: {e!s}", exc_info=True
|
||||
)
|
||||
raise
|
||||
|
||||
async def get_projects_by_owner(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
owner_id: UUID,
|
||||
status: ProjectStatus | None = None,
|
||||
) -> list[Project]:
|
||||
"""Get all projects owned by a specific user."""
|
||||
try:
|
||||
query = select(Project).where(Project.owner_id == owner_id)
|
||||
|
||||
if status is not None:
|
||||
query = query.where(Project.status == status)
|
||||
|
||||
query = query.order_by(Project.created_at.desc())
|
||||
result = await db.execute(query)
|
||||
return list(result.scalars().all())
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error getting projects by owner {owner_id}: {e!s}", exc_info=True
|
||||
)
|
||||
raise
|
||||
|
||||
async def archive_project(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
project_id: UUID,
|
||||
) -> Project | None:
|
||||
"""Archive a project by setting status to ARCHIVED."""
|
||||
try:
|
||||
result = await db.execute(
|
||||
select(Project).where(Project.id == project_id)
|
||||
)
|
||||
project = result.scalar_one_or_none()
|
||||
|
||||
if not project:
|
||||
return None
|
||||
|
||||
project.status = ProjectStatus.ARCHIVED
|
||||
await db.commit()
|
||||
await db.refresh(project)
|
||||
return project
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
logger.error(
|
||||
f"Error archiving project {project_id}: {e!s}", exc_info=True
|
||||
)
|
||||
raise
|
||||
|
||||
|
||||
# Create a singleton instance for use across the application
|
||||
project = CRUDProject(Project)
|
||||
406
backend/app/crud/syndarix/sprint.py
Normal file
406
backend/app/crud/syndarix/sprint.py
Normal file
@@ -0,0 +1,406 @@
|
||||
# app/crud/syndarix/sprint.py
|
||||
"""Async CRUD operations for Sprint model using SQLAlchemy 2.0 patterns."""
|
||||
|
||||
import logging
|
||||
from datetime import date
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import func, select
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import joinedload
|
||||
|
||||
from app.crud.base import CRUDBase
|
||||
from app.models.syndarix import Issue, Sprint
|
||||
from app.models.syndarix.enums import IssueStatus, SprintStatus
|
||||
from app.schemas.syndarix import SprintCreate, SprintUpdate
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CRUDSprint(CRUDBase[Sprint, SprintCreate, SprintUpdate]):
|
||||
"""Async CRUD operations for Sprint model."""
|
||||
|
||||
async def create(self, db: AsyncSession, *, obj_in: SprintCreate) -> Sprint:
|
||||
"""Create a new sprint with error handling."""
|
||||
try:
|
||||
db_obj = Sprint(
|
||||
project_id=obj_in.project_id,
|
||||
name=obj_in.name,
|
||||
number=obj_in.number,
|
||||
goal=obj_in.goal,
|
||||
start_date=obj_in.start_date,
|
||||
end_date=obj_in.end_date,
|
||||
status=obj_in.status,
|
||||
planned_points=obj_in.planned_points,
|
||||
completed_points=obj_in.completed_points,
|
||||
)
|
||||
db.add(db_obj)
|
||||
await db.commit()
|
||||
await db.refresh(db_obj)
|
||||
return db_obj
|
||||
except IntegrityError as e:
|
||||
await db.rollback()
|
||||
error_msg = str(e.orig) if hasattr(e, "orig") else str(e)
|
||||
logger.error(f"Integrity error creating sprint: {error_msg}")
|
||||
raise ValueError(f"Database integrity error: {error_msg}")
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
logger.error(f"Unexpected error creating sprint: {e!s}", exc_info=True)
|
||||
raise
|
||||
|
||||
async def get_with_details(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
sprint_id: UUID,
|
||||
) -> dict[str, Any] | None:
|
||||
"""
|
||||
Get a sprint with full details including issue counts.
|
||||
|
||||
Returns:
|
||||
Dictionary with sprint and related details
|
||||
"""
|
||||
try:
|
||||
# Get sprint with joined project
|
||||
result = await db.execute(
|
||||
select(Sprint)
|
||||
.options(joinedload(Sprint.project))
|
||||
.where(Sprint.id == sprint_id)
|
||||
)
|
||||
sprint = result.scalar_one_or_none()
|
||||
|
||||
if not sprint:
|
||||
return None
|
||||
|
||||
# Get issue counts
|
||||
issue_counts = await db.execute(
|
||||
select(
|
||||
func.count(Issue.id).label("total"),
|
||||
func.count(Issue.id)
|
||||
.filter(Issue.status == IssueStatus.OPEN)
|
||||
.label("open"),
|
||||
func.count(Issue.id)
|
||||
.filter(Issue.status == IssueStatus.CLOSED)
|
||||
.label("completed"),
|
||||
).where(Issue.sprint_id == sprint_id)
|
||||
)
|
||||
counts = issue_counts.one()
|
||||
|
||||
return {
|
||||
"sprint": sprint,
|
||||
"project_name": sprint.project.name if sprint.project else None,
|
||||
"project_slug": sprint.project.slug if sprint.project else None,
|
||||
"issue_count": counts.total,
|
||||
"open_issues": counts.open,
|
||||
"completed_issues": counts.completed,
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error getting sprint with details {sprint_id}: {e!s}", exc_info=True
|
||||
)
|
||||
raise
|
||||
|
||||
async def get_by_project(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
project_id: UUID,
|
||||
status: SprintStatus | None = None,
|
||||
skip: int = 0,
|
||||
limit: int = 100,
|
||||
) -> tuple[list[Sprint], int]:
|
||||
"""Get sprints for a specific project."""
|
||||
try:
|
||||
query = select(Sprint).where(Sprint.project_id == project_id)
|
||||
|
||||
if status is not None:
|
||||
query = query.where(Sprint.status == status)
|
||||
|
||||
# Get total count
|
||||
count_query = select(func.count()).select_from(query.alias())
|
||||
count_result = await db.execute(count_query)
|
||||
total = count_result.scalar_one()
|
||||
|
||||
# Apply sorting (by number descending - newest first)
|
||||
query = query.order_by(Sprint.number.desc())
|
||||
query = query.offset(skip).limit(limit)
|
||||
result = await db.execute(query)
|
||||
sprints = list(result.scalars().all())
|
||||
|
||||
return sprints, total
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error getting sprints by project {project_id}: {e!s}", exc_info=True
|
||||
)
|
||||
raise
|
||||
|
||||
async def get_active_sprint(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
project_id: UUID,
|
||||
) -> Sprint | None:
|
||||
"""Get the currently active sprint for a project."""
|
||||
try:
|
||||
result = await db.execute(
|
||||
select(Sprint).where(
|
||||
Sprint.project_id == project_id,
|
||||
Sprint.status == SprintStatus.ACTIVE,
|
||||
)
|
||||
)
|
||||
return result.scalar_one_or_none()
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error getting active sprint for project {project_id}: {e!s}",
|
||||
exc_info=True,
|
||||
)
|
||||
raise
|
||||
|
||||
async def get_next_sprint_number(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
project_id: UUID,
|
||||
) -> int:
|
||||
"""Get the next sprint number for a project."""
|
||||
try:
|
||||
result = await db.execute(
|
||||
select(func.max(Sprint.number)).where(Sprint.project_id == project_id)
|
||||
)
|
||||
max_number = result.scalar_one_or_none()
|
||||
return (max_number or 0) + 1
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error getting next sprint number for project {project_id}: {e!s}",
|
||||
exc_info=True,
|
||||
)
|
||||
raise
|
||||
|
||||
async def start_sprint(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
sprint_id: UUID,
|
||||
start_date: date | None = None,
|
||||
) -> Sprint | None:
|
||||
"""Start a planned sprint."""
|
||||
try:
|
||||
result = await db.execute(select(Sprint).where(Sprint.id == sprint_id))
|
||||
sprint = result.scalar_one_or_none()
|
||||
|
||||
if not sprint:
|
||||
return None
|
||||
|
||||
if sprint.status != SprintStatus.PLANNED:
|
||||
raise ValueError(
|
||||
f"Cannot start sprint with status {sprint.status.value}"
|
||||
)
|
||||
|
||||
# Check for existing active sprint in project
|
||||
active_sprint = await self.get_active_sprint(db, project_id=sprint.project_id)
|
||||
if active_sprint:
|
||||
raise ValueError(
|
||||
f"Project already has an active sprint: {active_sprint.name}"
|
||||
)
|
||||
|
||||
sprint.status = SprintStatus.ACTIVE
|
||||
if start_date:
|
||||
sprint.start_date = start_date
|
||||
|
||||
# Calculate planned points from issues
|
||||
points_result = await db.execute(
|
||||
select(func.sum(Issue.story_points)).where(Issue.sprint_id == sprint_id)
|
||||
)
|
||||
sprint.planned_points = points_result.scalar_one_or_none() or 0
|
||||
|
||||
await db.commit()
|
||||
await db.refresh(sprint)
|
||||
return sprint
|
||||
except ValueError:
|
||||
raise
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
logger.error(f"Error starting sprint {sprint_id}: {e!s}", exc_info=True)
|
||||
raise
|
||||
|
||||
async def complete_sprint(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
sprint_id: UUID,
|
||||
) -> Sprint | None:
|
||||
"""Complete an active sprint and calculate completed points."""
|
||||
try:
|
||||
result = await db.execute(select(Sprint).where(Sprint.id == sprint_id))
|
||||
sprint = result.scalar_one_or_none()
|
||||
|
||||
if not sprint:
|
||||
return None
|
||||
|
||||
if sprint.status != SprintStatus.ACTIVE:
|
||||
raise ValueError(
|
||||
f"Cannot complete sprint with status {sprint.status.value}"
|
||||
)
|
||||
|
||||
sprint.status = SprintStatus.COMPLETED
|
||||
|
||||
# Calculate completed points from closed issues
|
||||
points_result = await db.execute(
|
||||
select(func.sum(Issue.story_points)).where(
|
||||
Issue.sprint_id == sprint_id,
|
||||
Issue.status == IssueStatus.CLOSED,
|
||||
)
|
||||
)
|
||||
sprint.completed_points = points_result.scalar_one_or_none() or 0
|
||||
|
||||
await db.commit()
|
||||
await db.refresh(sprint)
|
||||
return sprint
|
||||
except ValueError:
|
||||
raise
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
logger.error(f"Error completing sprint {sprint_id}: {e!s}", exc_info=True)
|
||||
raise
|
||||
|
||||
async def cancel_sprint(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
sprint_id: UUID,
|
||||
) -> Sprint | None:
|
||||
"""Cancel a sprint (only PLANNED or ACTIVE sprints can be cancelled)."""
|
||||
try:
|
||||
result = await db.execute(select(Sprint).where(Sprint.id == sprint_id))
|
||||
sprint = result.scalar_one_or_none()
|
||||
|
||||
if not sprint:
|
||||
return None
|
||||
|
||||
if sprint.status not in [SprintStatus.PLANNED, SprintStatus.ACTIVE]:
|
||||
raise ValueError(
|
||||
f"Cannot cancel sprint with status {sprint.status.value}"
|
||||
)
|
||||
|
||||
sprint.status = SprintStatus.CANCELLED
|
||||
await db.commit()
|
||||
await db.refresh(sprint)
|
||||
return sprint
|
||||
except ValueError:
|
||||
raise
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
logger.error(f"Error cancelling sprint {sprint_id}: {e!s}", exc_info=True)
|
||||
raise
|
||||
|
||||
async def get_velocity(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
project_id: UUID,
|
||||
limit: int = 5,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Get velocity data for completed sprints."""
|
||||
try:
|
||||
result = await db.execute(
|
||||
select(Sprint)
|
||||
.where(
|
||||
Sprint.project_id == project_id,
|
||||
Sprint.status == SprintStatus.COMPLETED,
|
||||
)
|
||||
.order_by(Sprint.number.desc())
|
||||
.limit(limit)
|
||||
)
|
||||
sprints = list(result.scalars().all())
|
||||
|
||||
velocity_data = []
|
||||
for sprint in reversed(sprints): # Return in chronological order
|
||||
velocity = None
|
||||
if sprint.planned_points and sprint.planned_points > 0:
|
||||
velocity = (sprint.completed_points or 0) / sprint.planned_points
|
||||
velocity_data.append(
|
||||
{
|
||||
"sprint_number": sprint.number,
|
||||
"sprint_name": sprint.name,
|
||||
"planned_points": sprint.planned_points,
|
||||
"completed_points": sprint.completed_points,
|
||||
"velocity": velocity,
|
||||
}
|
||||
)
|
||||
|
||||
return velocity_data
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error getting velocity for project {project_id}: {e!s}",
|
||||
exc_info=True,
|
||||
)
|
||||
raise
|
||||
|
||||
async def get_sprints_with_issue_counts(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
project_id: UUID,
|
||||
skip: int = 0,
|
||||
limit: int = 100,
|
||||
) -> tuple[list[dict[str, Any]], int]:
|
||||
"""Get sprints with issue counts in optimized queries."""
|
||||
try:
|
||||
# Get sprints
|
||||
sprints, total = await self.get_by_project(
|
||||
db, project_id=project_id, skip=skip, limit=limit
|
||||
)
|
||||
|
||||
if not sprints:
|
||||
return [], 0
|
||||
|
||||
sprint_ids = [s.id for s in sprints]
|
||||
|
||||
# Get issue counts in bulk
|
||||
issue_counts = await db.execute(
|
||||
select(
|
||||
Issue.sprint_id,
|
||||
func.count(Issue.id).label("total"),
|
||||
func.count(Issue.id)
|
||||
.filter(Issue.status == IssueStatus.OPEN)
|
||||
.label("open"),
|
||||
func.count(Issue.id)
|
||||
.filter(Issue.status == IssueStatus.CLOSED)
|
||||
.label("completed"),
|
||||
)
|
||||
.where(Issue.sprint_id.in_(sprint_ids))
|
||||
.group_by(Issue.sprint_id)
|
||||
)
|
||||
counts_map = {
|
||||
row.sprint_id: {
|
||||
"issue_count": row.total,
|
||||
"open_issues": row.open,
|
||||
"completed_issues": row.completed,
|
||||
}
|
||||
for row in issue_counts
|
||||
}
|
||||
|
||||
# Combine results
|
||||
results = [
|
||||
{
|
||||
"sprint": sprint,
|
||||
**counts_map.get(
|
||||
sprint.id, {"issue_count": 0, "open_issues": 0, "completed_issues": 0}
|
||||
),
|
||||
}
|
||||
for sprint in sprints
|
||||
]
|
||||
|
||||
return results, total
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error getting sprints with counts for project {project_id}: {e!s}",
|
||||
exc_info=True,
|
||||
)
|
||||
raise
|
||||
|
||||
|
||||
# Create a singleton instance for use across the application
|
||||
sprint = CRUDSprint(Sprint)
|
||||
@@ -23,6 +23,15 @@ from .user import User
|
||||
from .user_organization import OrganizationRole, UserOrganization
|
||||
from .user_session import UserSession
|
||||
|
||||
# Syndarix domain models
|
||||
from .syndarix import (
|
||||
AgentInstance,
|
||||
AgentType,
|
||||
Issue,
|
||||
Project,
|
||||
Sprint,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"Base",
|
||||
"OAuthAccount",
|
||||
@@ -38,4 +47,10 @@ __all__ = [
|
||||
"User",
|
||||
"UserOrganization",
|
||||
"UserSession",
|
||||
# Syndarix models
|
||||
"AgentInstance",
|
||||
"AgentType",
|
||||
"Issue",
|
||||
"Project",
|
||||
"Sprint",
|
||||
]
|
||||
|
||||
41
backend/app/models/syndarix/__init__.py
Normal file
41
backend/app/models/syndarix/__init__.py
Normal file
@@ -0,0 +1,41 @@
|
||||
# app/models/syndarix/__init__.py
|
||||
"""
|
||||
Syndarix domain models.
|
||||
|
||||
This package contains all the core entities for the Syndarix AI consulting platform:
|
||||
- Project: Client engagements with autonomy settings
|
||||
- AgentType: Templates for AI agent capabilities
|
||||
- AgentInstance: Spawned agents working on projects
|
||||
- Issue: Units of work with external tracker sync
|
||||
- Sprint: Time-boxed iterations for organizing work
|
||||
"""
|
||||
|
||||
from .agent_instance import AgentInstance
|
||||
from .agent_type import AgentType
|
||||
from .enums import (
|
||||
AgentStatus,
|
||||
AutonomyLevel,
|
||||
IssuePriority,
|
||||
IssueStatus,
|
||||
ProjectStatus,
|
||||
SprintStatus,
|
||||
SyncStatus,
|
||||
)
|
||||
from .issue import Issue
|
||||
from .project import Project
|
||||
from .sprint import Sprint
|
||||
|
||||
__all__ = [
|
||||
"AgentInstance",
|
||||
"AgentStatus",
|
||||
"AgentType",
|
||||
"AutonomyLevel",
|
||||
"Issue",
|
||||
"IssuePriority",
|
||||
"IssueStatus",
|
||||
"Project",
|
||||
"ProjectStatus",
|
||||
"Sprint",
|
||||
"SprintStatus",
|
||||
"SyncStatus",
|
||||
]
|
||||
108
backend/app/models/syndarix/agent_instance.py
Normal file
108
backend/app/models/syndarix/agent_instance.py
Normal file
@@ -0,0 +1,108 @@
|
||||
# app/models/syndarix/agent_instance.py
|
||||
"""
|
||||
AgentInstance model for Syndarix AI consulting platform.
|
||||
|
||||
An AgentInstance is a spawned instance of an AgentType, assigned to a
|
||||
specific project to perform work.
|
||||
"""
|
||||
|
||||
from sqlalchemy import (
|
||||
BigInteger,
|
||||
Column,
|
||||
DateTime,
|
||||
Enum,
|
||||
ForeignKey,
|
||||
Index,
|
||||
Integer,
|
||||
Numeric,
|
||||
String,
|
||||
Text,
|
||||
)
|
||||
from sqlalchemy.dialects.postgresql import (
|
||||
JSONB,
|
||||
UUID as PGUUID,
|
||||
)
|
||||
from sqlalchemy.orm import relationship
|
||||
|
||||
from app.models.base import Base, TimestampMixin, UUIDMixin
|
||||
|
||||
from .enums import AgentStatus
|
||||
|
||||
|
||||
class AgentInstance(Base, UUIDMixin, TimestampMixin):
|
||||
"""
|
||||
AgentInstance model representing a spawned agent working on a project.
|
||||
|
||||
Tracks:
|
||||
- Current status and task
|
||||
- Memory (short-term in DB, long-term reference to vector store)
|
||||
- Session information for MCP connections
|
||||
- Usage metrics (tasks completed, tokens, cost)
|
||||
"""
|
||||
|
||||
__tablename__ = "agent_instances"
|
||||
|
||||
# Foreign keys
|
||||
agent_type_id = Column(
|
||||
PGUUID(as_uuid=True),
|
||||
ForeignKey("agent_types.id", ondelete="RESTRICT"),
|
||||
nullable=False,
|
||||
index=True,
|
||||
)
|
||||
|
||||
project_id = Column(
|
||||
PGUUID(as_uuid=True),
|
||||
ForeignKey("projects.id", ondelete="CASCADE"),
|
||||
nullable=False,
|
||||
index=True,
|
||||
)
|
||||
|
||||
# Status tracking
|
||||
status: Column[AgentStatus] = Column(
|
||||
Enum(AgentStatus),
|
||||
default=AgentStatus.IDLE,
|
||||
nullable=False,
|
||||
index=True,
|
||||
)
|
||||
|
||||
# Current task description (brief summary of what agent is doing)
|
||||
current_task = Column(Text, nullable=True)
|
||||
|
||||
# Short-term memory stored in database (conversation context, recent decisions)
|
||||
short_term_memory = Column(JSONB, default=dict, nullable=False)
|
||||
|
||||
# Reference to long-term memory in vector store (e.g., "project-123/agent-456")
|
||||
long_term_memory_ref = Column(String(500), nullable=True)
|
||||
|
||||
# Session ID for active MCP connections
|
||||
session_id = Column(String(255), nullable=True, index=True)
|
||||
|
||||
# Activity tracking
|
||||
last_activity_at = Column(DateTime(timezone=True), nullable=True, index=True)
|
||||
terminated_at = Column(DateTime(timezone=True), nullable=True, index=True)
|
||||
|
||||
# Usage metrics
|
||||
tasks_completed = Column(Integer, default=0, nullable=False)
|
||||
tokens_used = Column(BigInteger, default=0, nullable=False)
|
||||
cost_incurred = Column(Numeric(precision=10, scale=4), default=0, nullable=False)
|
||||
|
||||
# Relationships
|
||||
agent_type = relationship("AgentType", back_populates="instances")
|
||||
project = relationship("Project", back_populates="agent_instances")
|
||||
assigned_issues = relationship(
|
||||
"Issue",
|
||||
back_populates="assigned_agent",
|
||||
foreign_keys="Issue.assigned_agent_id",
|
||||
)
|
||||
|
||||
__table_args__ = (
|
||||
Index("ix_agent_instances_project_status", "project_id", "status"),
|
||||
Index("ix_agent_instances_type_status", "agent_type_id", "status"),
|
||||
Index("ix_agent_instances_project_type", "project_id", "agent_type_id"),
|
||||
)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (
|
||||
f"<AgentInstance {self.id} type={self.agent_type_id} "
|
||||
f"project={self.project_id} status={self.status.value}>"
|
||||
)
|
||||
72
backend/app/models/syndarix/agent_type.py
Normal file
72
backend/app/models/syndarix/agent_type.py
Normal file
@@ -0,0 +1,72 @@
|
||||
# app/models/syndarix/agent_type.py
|
||||
"""
|
||||
AgentType model for Syndarix AI consulting platform.
|
||||
|
||||
An AgentType is a template that defines the capabilities, personality,
|
||||
and model configuration for agent instances.
|
||||
"""
|
||||
|
||||
from sqlalchemy import Boolean, Column, Index, String, Text
|
||||
from sqlalchemy.dialects.postgresql import JSONB
|
||||
from sqlalchemy.orm import relationship
|
||||
|
||||
from app.models.base import Base, TimestampMixin, UUIDMixin
|
||||
|
||||
|
||||
class AgentType(Base, UUIDMixin, TimestampMixin):
|
||||
"""
|
||||
AgentType model representing a template for agent instances.
|
||||
|
||||
Each agent type defines:
|
||||
- Expertise areas and personality prompt
|
||||
- Model configuration (primary, fallback, parameters)
|
||||
- MCP server access and tool permissions
|
||||
|
||||
Examples: ProductOwner, Architect, BackendEngineer, QAEngineer
|
||||
"""
|
||||
|
||||
__tablename__ = "agent_types"
|
||||
|
||||
name = Column(String(255), nullable=False, index=True)
|
||||
slug = Column(String(255), unique=True, nullable=False, index=True)
|
||||
description = Column(Text, nullable=True)
|
||||
|
||||
# Areas of expertise for this agent type (e.g., ["python", "fastapi", "databases"])
|
||||
expertise = Column(JSONB, default=list, nullable=False)
|
||||
|
||||
# System prompt defining the agent's personality and behavior
|
||||
personality_prompt = Column(Text, nullable=False)
|
||||
|
||||
# Primary LLM model to use (e.g., "claude-opus-4-5-20251101")
|
||||
primary_model = Column(String(100), nullable=False)
|
||||
|
||||
# Fallback models in order of preference
|
||||
fallback_models = Column(JSONB, default=list, nullable=False)
|
||||
|
||||
# Model parameters (temperature, max_tokens, etc.)
|
||||
model_params = Column(JSONB, default=dict, nullable=False)
|
||||
|
||||
# List of MCP servers this agent can connect to
|
||||
mcp_servers = Column(JSONB, default=list, nullable=False)
|
||||
|
||||
# Tool permissions configuration
|
||||
# Structure: {"allowed": ["*"], "denied": [], "require_approval": ["gitea:create_pr"]}
|
||||
tool_permissions = Column(JSONB, default=dict, nullable=False)
|
||||
|
||||
# Whether this agent type is available for new instances
|
||||
is_active = Column(Boolean, default=True, nullable=False, index=True)
|
||||
|
||||
# Relationships
|
||||
instances = relationship(
|
||||
"AgentInstance",
|
||||
back_populates="agent_type",
|
||||
cascade="all, delete-orphan",
|
||||
)
|
||||
|
||||
__table_args__ = (
|
||||
Index("ix_agent_types_slug_active", "slug", "is_active"),
|
||||
Index("ix_agent_types_name_active", "name", "is_active"),
|
||||
)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"<AgentType {self.name} ({self.slug}) active={self.is_active}>"
|
||||
123
backend/app/models/syndarix/enums.py
Normal file
123
backend/app/models/syndarix/enums.py
Normal file
@@ -0,0 +1,123 @@
|
||||
# app/models/syndarix/enums.py
|
||||
"""
|
||||
Enums for Syndarix domain models.
|
||||
|
||||
These enums represent the core state machines and categorizations
|
||||
used throughout the Syndarix AI consulting platform.
|
||||
"""
|
||||
|
||||
from enum import Enum as PyEnum
|
||||
|
||||
|
||||
class AutonomyLevel(str, PyEnum):
|
||||
"""
|
||||
Defines how much control the human has over agent actions.
|
||||
|
||||
FULL_CONTROL: Human must approve every agent action
|
||||
MILESTONE: Human approves at sprint boundaries and major decisions
|
||||
AUTONOMOUS: Agents work independently, only escalating critical issues
|
||||
"""
|
||||
|
||||
FULL_CONTROL = "full_control"
|
||||
MILESTONE = "milestone"
|
||||
AUTONOMOUS = "autonomous"
|
||||
|
||||
|
||||
class ProjectStatus(str, PyEnum):
|
||||
"""
|
||||
Project lifecycle status.
|
||||
|
||||
ACTIVE: Project is actively being worked on
|
||||
PAUSED: Project is temporarily on hold
|
||||
COMPLETED: Project has been delivered successfully
|
||||
ARCHIVED: Project is no longer accessible for work
|
||||
"""
|
||||
|
||||
ACTIVE = "active"
|
||||
PAUSED = "paused"
|
||||
COMPLETED = "completed"
|
||||
ARCHIVED = "archived"
|
||||
|
||||
|
||||
class AgentStatus(str, PyEnum):
|
||||
"""
|
||||
Current operational status of an agent instance.
|
||||
|
||||
IDLE: Agent is available but not currently working
|
||||
WORKING: Agent is actively processing a task
|
||||
WAITING: Agent is waiting for external input or approval
|
||||
PAUSED: Agent has been manually paused
|
||||
TERMINATED: Agent instance has been shut down
|
||||
"""
|
||||
|
||||
IDLE = "idle"
|
||||
WORKING = "working"
|
||||
WAITING = "waiting"
|
||||
PAUSED = "paused"
|
||||
TERMINATED = "terminated"
|
||||
|
||||
|
||||
class IssueStatus(str, PyEnum):
|
||||
"""
|
||||
Issue workflow status.
|
||||
|
||||
OPEN: Issue is ready to be worked on
|
||||
IN_PROGRESS: Agent or human is actively working on the issue
|
||||
IN_REVIEW: Work is complete, awaiting review
|
||||
BLOCKED: Issue cannot proceed due to dependencies or blockers
|
||||
CLOSED: Issue has been completed or cancelled
|
||||
"""
|
||||
|
||||
OPEN = "open"
|
||||
IN_PROGRESS = "in_progress"
|
||||
IN_REVIEW = "in_review"
|
||||
BLOCKED = "blocked"
|
||||
CLOSED = "closed"
|
||||
|
||||
|
||||
class IssuePriority(str, PyEnum):
|
||||
"""
|
||||
Issue priority levels.
|
||||
|
||||
LOW: Nice to have, can be deferred
|
||||
MEDIUM: Standard priority, should be done
|
||||
HIGH: Important, should be prioritized
|
||||
CRITICAL: Must be done immediately, blocking other work
|
||||
"""
|
||||
|
||||
LOW = "low"
|
||||
MEDIUM = "medium"
|
||||
HIGH = "high"
|
||||
CRITICAL = "critical"
|
||||
|
||||
|
||||
class SyncStatus(str, PyEnum):
|
||||
"""
|
||||
External issue tracker synchronization status.
|
||||
|
||||
SYNCED: Local and remote are in sync
|
||||
PENDING: Local changes waiting to be pushed
|
||||
CONFLICT: Merge conflict between local and remote
|
||||
ERROR: Synchronization failed due to an error
|
||||
"""
|
||||
|
||||
SYNCED = "synced"
|
||||
PENDING = "pending"
|
||||
CONFLICT = "conflict"
|
||||
ERROR = "error"
|
||||
|
||||
|
||||
class SprintStatus(str, PyEnum):
|
||||
"""
|
||||
Sprint lifecycle status.
|
||||
|
||||
PLANNED: Sprint has been created but not started
|
||||
ACTIVE: Sprint is currently in progress
|
||||
COMPLETED: Sprint has been finished successfully
|
||||
CANCELLED: Sprint was cancelled before completion
|
||||
"""
|
||||
|
||||
PLANNED = "planned"
|
||||
ACTIVE = "active"
|
||||
COMPLETED = "completed"
|
||||
CANCELLED = "cancelled"
|
||||
133
backend/app/models/syndarix/issue.py
Normal file
133
backend/app/models/syndarix/issue.py
Normal file
@@ -0,0 +1,133 @@
|
||||
# app/models/syndarix/issue.py
|
||||
"""
|
||||
Issue model for Syndarix AI consulting platform.
|
||||
|
||||
An Issue represents a unit of work that can be assigned to agents or humans,
|
||||
with optional synchronization to external issue trackers (Gitea, GitHub, GitLab).
|
||||
"""
|
||||
|
||||
from sqlalchemy import Column, DateTime, Enum, ForeignKey, Index, Integer, String, Text
|
||||
from sqlalchemy.dialects.postgresql import (
|
||||
JSONB,
|
||||
UUID as PGUUID,
|
||||
)
|
||||
from sqlalchemy.orm import relationship
|
||||
|
||||
from app.models.base import Base, TimestampMixin, UUIDMixin
|
||||
|
||||
from .enums import IssuePriority, IssueStatus, SyncStatus
|
||||
|
||||
|
||||
class Issue(Base, UUIDMixin, TimestampMixin):
|
||||
"""
|
||||
Issue model representing a unit of work in a project.
|
||||
|
||||
Features:
|
||||
- Standard issue fields (title, body, status, priority)
|
||||
- Assignment to agent instances or human assignees
|
||||
- Sprint association for backlog management
|
||||
- External tracker synchronization (Gitea, GitHub, GitLab)
|
||||
"""
|
||||
|
||||
__tablename__ = "issues"
|
||||
|
||||
# Foreign key to project
|
||||
project_id = Column(
|
||||
PGUUID(as_uuid=True),
|
||||
ForeignKey("projects.id", ondelete="CASCADE"),
|
||||
nullable=False,
|
||||
index=True,
|
||||
)
|
||||
|
||||
# Issue content
|
||||
title = Column(String(500), nullable=False)
|
||||
body = Column(Text, nullable=False, default="")
|
||||
|
||||
# Status and priority
|
||||
status: Column[IssueStatus] = Column(
|
||||
Enum(IssueStatus),
|
||||
default=IssueStatus.OPEN,
|
||||
nullable=False,
|
||||
index=True,
|
||||
)
|
||||
|
||||
priority: Column[IssuePriority] = Column(
|
||||
Enum(IssuePriority),
|
||||
default=IssuePriority.MEDIUM,
|
||||
nullable=False,
|
||||
index=True,
|
||||
)
|
||||
|
||||
# Labels for categorization (e.g., ["bug", "frontend", "urgent"])
|
||||
labels = Column(JSONB, default=list, nullable=False)
|
||||
|
||||
# Assignment - either to an agent or a human (mutually exclusive)
|
||||
assigned_agent_id = Column(
|
||||
PGUUID(as_uuid=True),
|
||||
ForeignKey("agent_instances.id", ondelete="SET NULL"),
|
||||
nullable=True,
|
||||
index=True,
|
||||
)
|
||||
|
||||
# Human assignee (username or email, not a FK to allow external users)
|
||||
human_assignee = Column(String(255), nullable=True, index=True)
|
||||
|
||||
# Sprint association
|
||||
sprint_id = Column(
|
||||
PGUUID(as_uuid=True),
|
||||
ForeignKey("sprints.id", ondelete="SET NULL"),
|
||||
nullable=True,
|
||||
index=True,
|
||||
)
|
||||
|
||||
# Story points for estimation
|
||||
story_points = Column(Integer, nullable=True)
|
||||
|
||||
# External tracker integration
|
||||
external_tracker = Column(
|
||||
String(50),
|
||||
nullable=True,
|
||||
index=True,
|
||||
) # 'gitea', 'github', 'gitlab'
|
||||
|
||||
external_id = Column(String(255), nullable=True) # External system's ID
|
||||
external_url = Column(String(1000), nullable=True) # Link to external issue
|
||||
external_number = Column(Integer, nullable=True) # Issue number (e.g., #123)
|
||||
|
||||
# Sync status with external tracker
|
||||
sync_status: Column[SyncStatus] = Column(
|
||||
Enum(SyncStatus),
|
||||
default=SyncStatus.SYNCED,
|
||||
nullable=False,
|
||||
# Note: Index defined in __table_args__ as ix_issues_sync_status
|
||||
)
|
||||
|
||||
last_synced_at = Column(DateTime(timezone=True), nullable=True)
|
||||
external_updated_at = Column(DateTime(timezone=True), nullable=True)
|
||||
|
||||
# Lifecycle timestamp
|
||||
closed_at = Column(DateTime(timezone=True), nullable=True, index=True)
|
||||
|
||||
# Relationships
|
||||
project = relationship("Project", back_populates="issues")
|
||||
assigned_agent = relationship(
|
||||
"AgentInstance",
|
||||
back_populates="assigned_issues",
|
||||
foreign_keys=[assigned_agent_id],
|
||||
)
|
||||
sprint = relationship("Sprint", back_populates="issues")
|
||||
|
||||
__table_args__ = (
|
||||
Index("ix_issues_project_status", "project_id", "status"),
|
||||
Index("ix_issues_project_priority", "project_id", "priority"),
|
||||
Index("ix_issues_project_sprint", "project_id", "sprint_id"),
|
||||
Index("ix_issues_external_tracker_id", "external_tracker", "external_id"),
|
||||
Index("ix_issues_sync_status", "sync_status"),
|
||||
Index("ix_issues_project_agent", "project_id", "assigned_agent_id"),
|
||||
)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (
|
||||
f"<Issue {self.id} title='{self.title[:30]}...' "
|
||||
f"status={self.status.value} priority={self.priority.value}>"
|
||||
)
|
||||
88
backend/app/models/syndarix/project.py
Normal file
88
backend/app/models/syndarix/project.py
Normal file
@@ -0,0 +1,88 @@
|
||||
# app/models/syndarix/project.py
|
||||
"""
|
||||
Project model for Syndarix AI consulting platform.
|
||||
|
||||
A Project represents a client engagement where AI agents collaborate
|
||||
to deliver software solutions.
|
||||
"""
|
||||
|
||||
from sqlalchemy import Column, Enum, ForeignKey, Index, String, Text
|
||||
from sqlalchemy.dialects.postgresql import (
|
||||
JSONB,
|
||||
UUID as PGUUID,
|
||||
)
|
||||
from sqlalchemy.orm import relationship
|
||||
|
||||
from app.models.base import Base, TimestampMixin, UUIDMixin
|
||||
|
||||
from .enums import AutonomyLevel, ProjectStatus
|
||||
|
||||
|
||||
class Project(Base, UUIDMixin, TimestampMixin):
|
||||
"""
|
||||
Project model representing a client engagement.
|
||||
|
||||
A project contains:
|
||||
- Configuration for how autonomous agents should operate
|
||||
- Settings for MCP server integrations
|
||||
- Relationship to assigned agents, issues, and sprints
|
||||
"""
|
||||
|
||||
__tablename__ = "projects"
|
||||
|
||||
name = Column(String(255), nullable=False, index=True)
|
||||
slug = Column(String(255), unique=True, nullable=False, index=True)
|
||||
description = Column(Text, nullable=True)
|
||||
|
||||
autonomy_level: Column[AutonomyLevel] = Column(
|
||||
Enum(AutonomyLevel),
|
||||
default=AutonomyLevel.MILESTONE,
|
||||
nullable=False,
|
||||
index=True,
|
||||
)
|
||||
|
||||
status: Column[ProjectStatus] = Column(
|
||||
Enum(ProjectStatus),
|
||||
default=ProjectStatus.ACTIVE,
|
||||
nullable=False,
|
||||
index=True,
|
||||
)
|
||||
|
||||
# JSON field for flexible project configuration
|
||||
# Can include: mcp_servers, webhook_urls, notification_settings, etc.
|
||||
settings = Column(JSONB, default=dict, nullable=False)
|
||||
|
||||
# Foreign key to the User who owns this project
|
||||
owner_id = Column(
|
||||
PGUUID(as_uuid=True),
|
||||
ForeignKey("users.id", ondelete="SET NULL"),
|
||||
nullable=True,
|
||||
index=True,
|
||||
)
|
||||
|
||||
# Relationships
|
||||
owner = relationship("User", foreign_keys=[owner_id])
|
||||
agent_instances = relationship(
|
||||
"AgentInstance",
|
||||
back_populates="project",
|
||||
cascade="all, delete-orphan",
|
||||
)
|
||||
issues = relationship(
|
||||
"Issue",
|
||||
back_populates="project",
|
||||
cascade="all, delete-orphan",
|
||||
)
|
||||
sprints = relationship(
|
||||
"Sprint",
|
||||
back_populates="project",
|
||||
cascade="all, delete-orphan",
|
||||
)
|
||||
|
||||
__table_args__ = (
|
||||
Index("ix_projects_slug_status", "slug", "status"),
|
||||
Index("ix_projects_owner_status", "owner_id", "status"),
|
||||
Index("ix_projects_autonomy_status", "autonomy_level", "status"),
|
||||
)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"<Project {self.name} ({self.slug}) status={self.status.value}>"
|
||||
74
backend/app/models/syndarix/sprint.py
Normal file
74
backend/app/models/syndarix/sprint.py
Normal file
@@ -0,0 +1,74 @@
|
||||
# app/models/syndarix/sprint.py
|
||||
"""
|
||||
Sprint model for Syndarix AI consulting platform.
|
||||
|
||||
A Sprint represents a time-boxed iteration for organizing and delivering work.
|
||||
"""
|
||||
|
||||
from sqlalchemy import Column, Date, Enum, ForeignKey, Index, Integer, String, Text
|
||||
from sqlalchemy.dialects.postgresql import UUID as PGUUID
|
||||
from sqlalchemy.orm import relationship
|
||||
|
||||
from app.models.base import Base, TimestampMixin, UUIDMixin
|
||||
|
||||
from .enums import SprintStatus
|
||||
|
||||
|
||||
class Sprint(Base, UUIDMixin, TimestampMixin):
|
||||
"""
|
||||
Sprint model representing a time-boxed iteration.
|
||||
|
||||
Tracks:
|
||||
- Sprint metadata (name, number, goal)
|
||||
- Date range (start/end)
|
||||
- Progress metrics (planned vs completed points)
|
||||
"""
|
||||
|
||||
__tablename__ = "sprints"
|
||||
|
||||
# Foreign key to project
|
||||
project_id = Column(
|
||||
PGUUID(as_uuid=True),
|
||||
ForeignKey("projects.id", ondelete="CASCADE"),
|
||||
nullable=False,
|
||||
index=True,
|
||||
)
|
||||
|
||||
# Sprint identification
|
||||
name = Column(String(255), nullable=False)
|
||||
number = Column(Integer, nullable=False) # Sprint number within project
|
||||
|
||||
# Sprint goal (what we aim to achieve)
|
||||
goal = Column(Text, nullable=True)
|
||||
|
||||
# Date range
|
||||
start_date = Column(Date, nullable=False, index=True)
|
||||
end_date = Column(Date, nullable=False, index=True)
|
||||
|
||||
# Status
|
||||
status: Column[SprintStatus] = Column(
|
||||
Enum(SprintStatus),
|
||||
default=SprintStatus.PLANNED,
|
||||
nullable=False,
|
||||
index=True,
|
||||
)
|
||||
|
||||
# Progress metrics
|
||||
planned_points = Column(Integer, nullable=True) # Sum of story points at start
|
||||
completed_points = Column(Integer, nullable=True) # Sum of completed story points
|
||||
|
||||
# Relationships
|
||||
project = relationship("Project", back_populates="sprints")
|
||||
issues = relationship("Issue", back_populates="sprint")
|
||||
|
||||
__table_args__ = (
|
||||
Index("ix_sprints_project_status", "project_id", "status"),
|
||||
Index("ix_sprints_project_number", "project_id", "number"),
|
||||
Index("ix_sprints_date_range", "start_date", "end_date"),
|
||||
)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (
|
||||
f"<Sprint {self.name} (#{self.number}) "
|
||||
f"project={self.project_id} status={self.status.value}>"
|
||||
)
|
||||
275
backend/app/schemas/events.py
Normal file
275
backend/app/schemas/events.py
Normal file
@@ -0,0 +1,275 @@
|
||||
"""
|
||||
Event schemas for the Syndarix EventBus (Redis Pub/Sub).
|
||||
|
||||
This module defines event types and payload schemas for real-time communication
|
||||
between services, agents, and the frontend.
|
||||
"""
|
||||
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import Literal
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class EventType(str, Enum):
|
||||
"""
|
||||
Event types for the EventBus.
|
||||
|
||||
Naming convention: {domain}.{action}
|
||||
"""
|
||||
|
||||
# Agent Events
|
||||
AGENT_SPAWNED = "agent.spawned"
|
||||
AGENT_STATUS_CHANGED = "agent.status_changed"
|
||||
AGENT_MESSAGE = "agent.message"
|
||||
AGENT_TERMINATED = "agent.terminated"
|
||||
|
||||
# Issue Events
|
||||
ISSUE_CREATED = "issue.created"
|
||||
ISSUE_UPDATED = "issue.updated"
|
||||
ISSUE_ASSIGNED = "issue.assigned"
|
||||
ISSUE_CLOSED = "issue.closed"
|
||||
|
||||
# Sprint Events
|
||||
SPRINT_STARTED = "sprint.started"
|
||||
SPRINT_COMPLETED = "sprint.completed"
|
||||
|
||||
# Approval Events
|
||||
APPROVAL_REQUESTED = "approval.requested"
|
||||
APPROVAL_GRANTED = "approval.granted"
|
||||
APPROVAL_DENIED = "approval.denied"
|
||||
|
||||
# Project Events
|
||||
PROJECT_CREATED = "project.created"
|
||||
PROJECT_UPDATED = "project.updated"
|
||||
PROJECT_ARCHIVED = "project.archived"
|
||||
|
||||
# Workflow Events
|
||||
WORKFLOW_STARTED = "workflow.started"
|
||||
WORKFLOW_STEP_COMPLETED = "workflow.step_completed"
|
||||
WORKFLOW_COMPLETED = "workflow.completed"
|
||||
WORKFLOW_FAILED = "workflow.failed"
|
||||
|
||||
|
||||
ActorType = Literal["agent", "user", "system"]
|
||||
|
||||
|
||||
class Event(BaseModel):
|
||||
"""
|
||||
Base event schema for the EventBus.
|
||||
|
||||
All events published to the EventBus must conform to this schema.
|
||||
"""
|
||||
|
||||
id: str = Field(
|
||||
...,
|
||||
description="Unique event identifier (UUID string)",
|
||||
examples=["550e8400-e29b-41d4-a716-446655440000"],
|
||||
)
|
||||
type: EventType = Field(
|
||||
...,
|
||||
description="Event type enum value",
|
||||
examples=[EventType.AGENT_MESSAGE],
|
||||
)
|
||||
timestamp: datetime = Field(
|
||||
...,
|
||||
description="When the event occurred (UTC)",
|
||||
examples=["2024-01-15T10:30:00Z"],
|
||||
)
|
||||
project_id: UUID = Field(
|
||||
...,
|
||||
description="Project this event belongs to",
|
||||
examples=["550e8400-e29b-41d4-a716-446655440001"],
|
||||
)
|
||||
actor_id: UUID | None = Field(
|
||||
default=None,
|
||||
description="ID of the agent or user who triggered the event",
|
||||
examples=["550e8400-e29b-41d4-a716-446655440002"],
|
||||
)
|
||||
actor_type: ActorType = Field(
|
||||
...,
|
||||
description="Type of actor: 'agent', 'user', or 'system'",
|
||||
examples=["agent"],
|
||||
)
|
||||
payload: dict = Field(
|
||||
default_factory=dict,
|
||||
description="Event-specific payload data",
|
||||
)
|
||||
|
||||
model_config = {
|
||||
"json_schema_extra": {
|
||||
"example": {
|
||||
"id": "550e8400-e29b-41d4-a716-446655440000",
|
||||
"type": "agent.message",
|
||||
"timestamp": "2024-01-15T10:30:00Z",
|
||||
"project_id": "550e8400-e29b-41d4-a716-446655440001",
|
||||
"actor_id": "550e8400-e29b-41d4-a716-446655440002",
|
||||
"actor_type": "agent",
|
||||
"payload": {"message": "Processing task...", "progress": 50},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
# Specific payload schemas for type safety
|
||||
|
||||
|
||||
class AgentSpawnedPayload(BaseModel):
|
||||
"""Payload for AGENT_SPAWNED events."""
|
||||
|
||||
agent_instance_id: UUID = Field(..., description="ID of the spawned agent instance")
|
||||
agent_type_id: UUID = Field(..., description="ID of the agent type")
|
||||
agent_name: str = Field(..., description="Human-readable name of the agent")
|
||||
role: str = Field(..., description="Agent role (e.g., 'product_owner', 'engineer')")
|
||||
|
||||
|
||||
class AgentStatusChangedPayload(BaseModel):
|
||||
"""Payload for AGENT_STATUS_CHANGED events."""
|
||||
|
||||
agent_instance_id: UUID = Field(..., description="ID of the agent instance")
|
||||
previous_status: str = Field(..., description="Previous status")
|
||||
new_status: str = Field(..., description="New status")
|
||||
reason: str | None = Field(default=None, description="Reason for status change")
|
||||
|
||||
|
||||
class AgentMessagePayload(BaseModel):
|
||||
"""Payload for AGENT_MESSAGE events."""
|
||||
|
||||
agent_instance_id: UUID = Field(..., description="ID of the agent instance")
|
||||
message: str = Field(..., description="Message content")
|
||||
message_type: str = Field(
|
||||
default="info",
|
||||
description="Message type: 'info', 'warning', 'error', 'debug'",
|
||||
)
|
||||
metadata: dict = Field(
|
||||
default_factory=dict,
|
||||
description="Additional metadata (e.g., token usage, model info)",
|
||||
)
|
||||
|
||||
|
||||
class AgentTerminatedPayload(BaseModel):
|
||||
"""Payload for AGENT_TERMINATED events."""
|
||||
|
||||
agent_instance_id: UUID = Field(..., description="ID of the agent instance")
|
||||
termination_reason: str = Field(..., description="Reason for termination")
|
||||
final_status: str = Field(..., description="Final status at termination")
|
||||
|
||||
|
||||
class IssueCreatedPayload(BaseModel):
|
||||
"""Payload for ISSUE_CREATED events."""
|
||||
|
||||
issue_id: str = Field(..., description="Issue ID (from external tracker)")
|
||||
title: str = Field(..., description="Issue title")
|
||||
priority: str | None = Field(default=None, description="Issue priority")
|
||||
labels: list[str] = Field(default_factory=list, description="Issue labels")
|
||||
|
||||
|
||||
class IssueUpdatedPayload(BaseModel):
|
||||
"""Payload for ISSUE_UPDATED events."""
|
||||
|
||||
issue_id: str = Field(..., description="Issue ID (from external tracker)")
|
||||
changes: dict = Field(..., description="Dictionary of field changes")
|
||||
|
||||
|
||||
class IssueAssignedPayload(BaseModel):
|
||||
"""Payload for ISSUE_ASSIGNED events."""
|
||||
|
||||
issue_id: str = Field(..., description="Issue ID (from external tracker)")
|
||||
assignee_id: UUID | None = Field(
|
||||
default=None, description="Agent or user assigned to"
|
||||
)
|
||||
assignee_name: str | None = Field(default=None, description="Assignee name")
|
||||
|
||||
|
||||
class IssueClosedPayload(BaseModel):
|
||||
"""Payload for ISSUE_CLOSED events."""
|
||||
|
||||
issue_id: str = Field(..., description="Issue ID (from external tracker)")
|
||||
resolution: str = Field(..., description="Resolution status")
|
||||
|
||||
|
||||
class SprintStartedPayload(BaseModel):
|
||||
"""Payload for SPRINT_STARTED events."""
|
||||
|
||||
sprint_id: UUID = Field(..., description="Sprint ID")
|
||||
sprint_name: str = Field(..., description="Sprint name")
|
||||
goal: str | None = Field(default=None, description="Sprint goal")
|
||||
issue_count: int = Field(default=0, description="Number of issues in sprint")
|
||||
|
||||
|
||||
class SprintCompletedPayload(BaseModel):
|
||||
"""Payload for SPRINT_COMPLETED events."""
|
||||
|
||||
sprint_id: UUID = Field(..., description="Sprint ID")
|
||||
sprint_name: str = Field(..., description="Sprint name")
|
||||
completed_issues: int = Field(default=0, description="Number of completed issues")
|
||||
incomplete_issues: int = Field(
|
||||
default=0, description="Number of incomplete issues"
|
||||
)
|
||||
|
||||
|
||||
class ApprovalRequestedPayload(BaseModel):
|
||||
"""Payload for APPROVAL_REQUESTED events."""
|
||||
|
||||
approval_id: UUID = Field(..., description="Approval request ID")
|
||||
approval_type: str = Field(..., description="Type of approval needed")
|
||||
description: str = Field(..., description="Description of what needs approval")
|
||||
requested_by: UUID | None = Field(
|
||||
default=None, description="Agent/user requesting approval"
|
||||
)
|
||||
timeout_minutes: int | None = Field(
|
||||
default=None, description="Minutes before auto-escalation"
|
||||
)
|
||||
|
||||
|
||||
class ApprovalGrantedPayload(BaseModel):
|
||||
"""Payload for APPROVAL_GRANTED events."""
|
||||
|
||||
approval_id: UUID = Field(..., description="Approval request ID")
|
||||
approved_by: UUID = Field(..., description="User who granted approval")
|
||||
comments: str | None = Field(default=None, description="Approval comments")
|
||||
|
||||
|
||||
class ApprovalDeniedPayload(BaseModel):
|
||||
"""Payload for APPROVAL_DENIED events."""
|
||||
|
||||
approval_id: UUID = Field(..., description="Approval request ID")
|
||||
denied_by: UUID = Field(..., description="User who denied approval")
|
||||
reason: str = Field(..., description="Reason for denial")
|
||||
|
||||
|
||||
class WorkflowStartedPayload(BaseModel):
|
||||
"""Payload for WORKFLOW_STARTED events."""
|
||||
|
||||
workflow_id: UUID = Field(..., description="Workflow execution ID")
|
||||
workflow_type: str = Field(..., description="Type of workflow")
|
||||
total_steps: int = Field(default=0, description="Total number of steps")
|
||||
|
||||
|
||||
class WorkflowStepCompletedPayload(BaseModel):
|
||||
"""Payload for WORKFLOW_STEP_COMPLETED events."""
|
||||
|
||||
workflow_id: UUID = Field(..., description="Workflow execution ID")
|
||||
step_name: str = Field(..., description="Name of completed step")
|
||||
step_number: int = Field(..., description="Step number (1-indexed)")
|
||||
total_steps: int = Field(..., description="Total number of steps")
|
||||
result: dict = Field(default_factory=dict, description="Step result data")
|
||||
|
||||
|
||||
class WorkflowCompletedPayload(BaseModel):
|
||||
"""Payload for WORKFLOW_COMPLETED events."""
|
||||
|
||||
workflow_id: UUID = Field(..., description="Workflow execution ID")
|
||||
duration_seconds: float = Field(..., description="Total execution duration")
|
||||
result: dict = Field(default_factory=dict, description="Workflow result data")
|
||||
|
||||
|
||||
class WorkflowFailedPayload(BaseModel):
|
||||
"""Payload for WORKFLOW_FAILED events."""
|
||||
|
||||
workflow_id: UUID = Field(..., description="Workflow execution ID")
|
||||
error_message: str = Field(..., description="Error message")
|
||||
failed_step: str | None = Field(default=None, description="Step that failed")
|
||||
recoverable: bool = Field(default=False, description="Whether error is recoverable")
|
||||
113
backend/app/schemas/syndarix/__init__.py
Normal file
113
backend/app/schemas/syndarix/__init__.py
Normal file
@@ -0,0 +1,113 @@
|
||||
# app/schemas/syndarix/__init__.py
|
||||
"""
|
||||
Syndarix domain schemas.
|
||||
|
||||
This package contains Pydantic schemas for validating and serializing
|
||||
Syndarix domain entities.
|
||||
"""
|
||||
|
||||
from .agent_instance import (
|
||||
AgentInstanceCreate,
|
||||
AgentInstanceInDB,
|
||||
AgentInstanceListResponse,
|
||||
AgentInstanceMetrics,
|
||||
AgentInstanceResponse,
|
||||
AgentInstanceTerminate,
|
||||
AgentInstanceUpdate,
|
||||
)
|
||||
from .agent_type import (
|
||||
AgentTypeCreate,
|
||||
AgentTypeInDB,
|
||||
AgentTypeListResponse,
|
||||
AgentTypeResponse,
|
||||
AgentTypeUpdate,
|
||||
)
|
||||
from .enums import (
|
||||
AgentStatus,
|
||||
AutonomyLevel,
|
||||
IssuePriority,
|
||||
IssueStatus,
|
||||
ProjectStatus,
|
||||
SprintStatus,
|
||||
SyncStatus,
|
||||
)
|
||||
from .issue import (
|
||||
IssueAssign,
|
||||
IssueClose,
|
||||
IssueCreate,
|
||||
IssueInDB,
|
||||
IssueListResponse,
|
||||
IssueResponse,
|
||||
IssueStats,
|
||||
IssueSyncUpdate,
|
||||
IssueUpdate,
|
||||
)
|
||||
from .project import (
|
||||
ProjectCreate,
|
||||
ProjectInDB,
|
||||
ProjectListResponse,
|
||||
ProjectResponse,
|
||||
ProjectUpdate,
|
||||
)
|
||||
from .sprint import (
|
||||
SprintBurndown,
|
||||
SprintComplete,
|
||||
SprintCreate,
|
||||
SprintInDB,
|
||||
SprintListResponse,
|
||||
SprintResponse,
|
||||
SprintStart,
|
||||
SprintUpdate,
|
||||
SprintVelocity,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
# AgentInstance schemas
|
||||
"AgentInstanceCreate",
|
||||
"AgentInstanceInDB",
|
||||
"AgentInstanceListResponse",
|
||||
"AgentInstanceMetrics",
|
||||
"AgentInstanceResponse",
|
||||
"AgentInstanceTerminate",
|
||||
"AgentInstanceUpdate",
|
||||
# Enums
|
||||
"AgentStatus",
|
||||
# AgentType schemas
|
||||
"AgentTypeCreate",
|
||||
"AgentTypeInDB",
|
||||
"AgentTypeListResponse",
|
||||
"AgentTypeResponse",
|
||||
"AgentTypeUpdate",
|
||||
"AutonomyLevel",
|
||||
# Issue schemas
|
||||
"IssueAssign",
|
||||
"IssueClose",
|
||||
"IssueCreate",
|
||||
"IssueInDB",
|
||||
"IssueListResponse",
|
||||
"IssuePriority",
|
||||
"IssueResponse",
|
||||
"IssueStats",
|
||||
"IssueStatus",
|
||||
"IssueSyncUpdate",
|
||||
"IssueUpdate",
|
||||
# Project schemas
|
||||
"ProjectCreate",
|
||||
"ProjectInDB",
|
||||
"ProjectListResponse",
|
||||
"ProjectResponse",
|
||||
"ProjectStatus",
|
||||
"ProjectUpdate",
|
||||
# Sprint schemas
|
||||
"SprintBurndown",
|
||||
"SprintComplete",
|
||||
"SprintCreate",
|
||||
"SprintInDB",
|
||||
"SprintListResponse",
|
||||
"SprintResponse",
|
||||
"SprintStart",
|
||||
"SprintStatus",
|
||||
"SprintUpdate",
|
||||
"SprintVelocity",
|
||||
"SyncStatus",
|
||||
]
|
||||
122
backend/app/schemas/syndarix/agent_instance.py
Normal file
122
backend/app/schemas/syndarix/agent_instance.py
Normal file
@@ -0,0 +1,122 @@
|
||||
# app/schemas/syndarix/agent_instance.py
|
||||
"""
|
||||
Pydantic schemas for AgentInstance entity.
|
||||
"""
|
||||
|
||||
from datetime import datetime
|
||||
from decimal import Decimal
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
from .enums import AgentStatus
|
||||
|
||||
|
||||
class AgentInstanceBase(BaseModel):
|
||||
"""Base agent instance schema with common fields."""
|
||||
|
||||
agent_type_id: UUID
|
||||
project_id: UUID
|
||||
status: AgentStatus = AgentStatus.IDLE
|
||||
current_task: str | None = None
|
||||
short_term_memory: dict[str, Any] = Field(default_factory=dict)
|
||||
long_term_memory_ref: str | None = Field(None, max_length=500)
|
||||
session_id: str | None = Field(None, max_length=255)
|
||||
|
||||
|
||||
class AgentInstanceCreate(BaseModel):
|
||||
"""Schema for creating a new agent instance."""
|
||||
|
||||
agent_type_id: UUID
|
||||
project_id: UUID
|
||||
status: AgentStatus = AgentStatus.IDLE
|
||||
current_task: str | None = None
|
||||
short_term_memory: dict[str, Any] = Field(default_factory=dict)
|
||||
long_term_memory_ref: str | None = Field(None, max_length=500)
|
||||
session_id: str | None = Field(None, max_length=255)
|
||||
|
||||
|
||||
class AgentInstanceUpdate(BaseModel):
|
||||
"""Schema for updating an agent instance."""
|
||||
|
||||
status: AgentStatus | None = None
|
||||
current_task: str | None = None
|
||||
short_term_memory: dict[str, Any] | None = None
|
||||
long_term_memory_ref: str | None = None
|
||||
session_id: str | None = None
|
||||
last_activity_at: datetime | None = None
|
||||
tasks_completed: int | None = Field(None, ge=0)
|
||||
tokens_used: int | None = Field(None, ge=0)
|
||||
cost_incurred: Decimal | None = Field(None, ge=0)
|
||||
|
||||
|
||||
class AgentInstanceTerminate(BaseModel):
|
||||
"""Schema for terminating an agent instance."""
|
||||
|
||||
reason: str | None = None
|
||||
|
||||
|
||||
class AgentInstanceInDB(AgentInstanceBase):
|
||||
"""Schema for agent instance in database."""
|
||||
|
||||
id: UUID
|
||||
last_activity_at: datetime | None = None
|
||||
terminated_at: datetime | None = None
|
||||
tasks_completed: int = 0
|
||||
tokens_used: int = 0
|
||||
cost_incurred: Decimal = Decimal("0.0000")
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
class AgentInstanceResponse(BaseModel):
|
||||
"""Schema for agent instance API responses."""
|
||||
|
||||
id: UUID
|
||||
agent_type_id: UUID
|
||||
project_id: UUID
|
||||
status: AgentStatus
|
||||
current_task: str | None = None
|
||||
short_term_memory: dict[str, Any] = Field(default_factory=dict)
|
||||
long_term_memory_ref: str | None = None
|
||||
session_id: str | None = None
|
||||
last_activity_at: datetime | None = None
|
||||
terminated_at: datetime | None = None
|
||||
tasks_completed: int = 0
|
||||
tokens_used: int = 0
|
||||
cost_incurred: Decimal = Decimal("0.0000")
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
# Expanded fields from relationships
|
||||
agent_type_name: str | None = None
|
||||
agent_type_slug: str | None = None
|
||||
project_name: str | None = None
|
||||
project_slug: str | None = None
|
||||
assigned_issues_count: int | None = 0
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
class AgentInstanceListResponse(BaseModel):
|
||||
"""Schema for paginated agent instance list responses."""
|
||||
|
||||
agent_instances: list[AgentInstanceResponse]
|
||||
total: int
|
||||
page: int
|
||||
page_size: int
|
||||
pages: int
|
||||
|
||||
|
||||
class AgentInstanceMetrics(BaseModel):
|
||||
"""Schema for agent instance metrics summary."""
|
||||
|
||||
total_instances: int
|
||||
active_instances: int
|
||||
idle_instances: int
|
||||
total_tasks_completed: int
|
||||
total_tokens_used: int
|
||||
total_cost_incurred: Decimal
|
||||
151
backend/app/schemas/syndarix/agent_type.py
Normal file
151
backend/app/schemas/syndarix/agent_type.py
Normal file
@@ -0,0 +1,151 @@
|
||||
# app/schemas/syndarix/agent_type.py
|
||||
"""
|
||||
Pydantic schemas for AgentType entity.
|
||||
"""
|
||||
|
||||
import re
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field, field_validator
|
||||
|
||||
|
||||
class AgentTypeBase(BaseModel):
|
||||
"""Base agent type schema with common fields."""
|
||||
|
||||
name: str = Field(..., min_length=1, max_length=255)
|
||||
slug: str | None = Field(None, min_length=1, max_length=255)
|
||||
description: str | None = None
|
||||
expertise: list[str] = Field(default_factory=list)
|
||||
personality_prompt: str = Field(..., min_length=1)
|
||||
primary_model: str = Field(..., min_length=1, max_length=100)
|
||||
fallback_models: list[str] = Field(default_factory=list)
|
||||
model_params: dict[str, Any] = Field(default_factory=dict)
|
||||
mcp_servers: list[str] = Field(default_factory=list)
|
||||
tool_permissions: dict[str, Any] = Field(default_factory=dict)
|
||||
is_active: bool = True
|
||||
|
||||
@field_validator("slug")
|
||||
@classmethod
|
||||
def validate_slug(cls, v: str | None) -> str | None:
|
||||
"""Validate slug format: lowercase, alphanumeric, hyphens only."""
|
||||
if v is None:
|
||||
return v
|
||||
if not re.match(r"^[a-z0-9-]+$", v):
|
||||
raise ValueError(
|
||||
"Slug must contain only lowercase letters, numbers, and hyphens"
|
||||
)
|
||||
if v.startswith("-") or v.endswith("-"):
|
||||
raise ValueError("Slug cannot start or end with a hyphen")
|
||||
if "--" in v:
|
||||
raise ValueError("Slug cannot contain consecutive hyphens")
|
||||
return v
|
||||
|
||||
@field_validator("name")
|
||||
@classmethod
|
||||
def validate_name(cls, v: str) -> str:
|
||||
"""Validate agent type name."""
|
||||
if not v or v.strip() == "":
|
||||
raise ValueError("Agent type name cannot be empty")
|
||||
return v.strip()
|
||||
|
||||
@field_validator("expertise")
|
||||
@classmethod
|
||||
def validate_expertise(cls, v: list[str]) -> list[str]:
|
||||
"""Validate and normalize expertise list."""
|
||||
return [e.strip().lower() for e in v if e.strip()]
|
||||
|
||||
@field_validator("mcp_servers")
|
||||
@classmethod
|
||||
def validate_mcp_servers(cls, v: list[str]) -> list[str]:
|
||||
"""Validate MCP server list."""
|
||||
return [s.strip() for s in v if s.strip()]
|
||||
|
||||
|
||||
class AgentTypeCreate(AgentTypeBase):
|
||||
"""Schema for creating a new agent type."""
|
||||
|
||||
name: str = Field(..., min_length=1, max_length=255)
|
||||
slug: str = Field(..., min_length=1, max_length=255)
|
||||
personality_prompt: str = Field(..., min_length=1)
|
||||
primary_model: str = Field(..., min_length=1, max_length=100)
|
||||
|
||||
|
||||
class AgentTypeUpdate(BaseModel):
|
||||
"""Schema for updating an agent type."""
|
||||
|
||||
name: str | None = Field(None, min_length=1, max_length=255)
|
||||
slug: str | None = Field(None, min_length=1, max_length=255)
|
||||
description: str | None = None
|
||||
expertise: list[str] | None = None
|
||||
personality_prompt: str | None = None
|
||||
primary_model: str | None = Field(None, min_length=1, max_length=100)
|
||||
fallback_models: list[str] | None = None
|
||||
model_params: dict[str, Any] | None = None
|
||||
mcp_servers: list[str] | None = None
|
||||
tool_permissions: dict[str, Any] | None = None
|
||||
is_active: bool | None = None
|
||||
|
||||
@field_validator("slug")
|
||||
@classmethod
|
||||
def validate_slug(cls, v: str | None) -> str | None:
|
||||
"""Validate slug format."""
|
||||
if v is None:
|
||||
return v
|
||||
if not re.match(r"^[a-z0-9-]+$", v):
|
||||
raise ValueError(
|
||||
"Slug must contain only lowercase letters, numbers, and hyphens"
|
||||
)
|
||||
if v.startswith("-") or v.endswith("-"):
|
||||
raise ValueError("Slug cannot start or end with a hyphen")
|
||||
if "--" in v:
|
||||
raise ValueError("Slug cannot contain consecutive hyphens")
|
||||
return v
|
||||
|
||||
@field_validator("name")
|
||||
@classmethod
|
||||
def validate_name(cls, v: str | None) -> str | None:
|
||||
"""Validate agent type name."""
|
||||
if v is not None and (not v or v.strip() == ""):
|
||||
raise ValueError("Agent type name cannot be empty")
|
||||
return v.strip() if v else v
|
||||
|
||||
@field_validator("expertise")
|
||||
@classmethod
|
||||
def validate_expertise(cls, v: list[str] | None) -> list[str] | None:
|
||||
"""Validate and normalize expertise list."""
|
||||
if v is None:
|
||||
return v
|
||||
return [e.strip().lower() for e in v if e.strip()]
|
||||
|
||||
|
||||
class AgentTypeInDB(AgentTypeBase):
|
||||
"""Schema for agent type in database."""
|
||||
|
||||
id: UUID
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
class AgentTypeResponse(AgentTypeBase):
|
||||
"""Schema for agent type API responses."""
|
||||
|
||||
id: UUID
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
instance_count: int | None = 0
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
class AgentTypeListResponse(BaseModel):
|
||||
"""Schema for paginated agent type list responses."""
|
||||
|
||||
agent_types: list[AgentTypeResponse]
|
||||
total: int
|
||||
page: int
|
||||
page_size: int
|
||||
pages: int
|
||||
26
backend/app/schemas/syndarix/enums.py
Normal file
26
backend/app/schemas/syndarix/enums.py
Normal file
@@ -0,0 +1,26 @@
|
||||
# app/schemas/syndarix/enums.py
|
||||
"""
|
||||
Re-export enums from models for use in schemas.
|
||||
|
||||
This allows schemas to import enums without depending on SQLAlchemy models directly.
|
||||
"""
|
||||
|
||||
from app.models.syndarix.enums import (
|
||||
AgentStatus,
|
||||
AutonomyLevel,
|
||||
IssuePriority,
|
||||
IssueStatus,
|
||||
ProjectStatus,
|
||||
SprintStatus,
|
||||
SyncStatus,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"AgentStatus",
|
||||
"AutonomyLevel",
|
||||
"IssuePriority",
|
||||
"IssueStatus",
|
||||
"ProjectStatus",
|
||||
"SprintStatus",
|
||||
"SyncStatus",
|
||||
]
|
||||
193
backend/app/schemas/syndarix/issue.py
Normal file
193
backend/app/schemas/syndarix/issue.py
Normal file
@@ -0,0 +1,193 @@
|
||||
# app/schemas/syndarix/issue.py
|
||||
"""
|
||||
Pydantic schemas for Issue entity.
|
||||
"""
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Literal
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
|
||||
|
||||
from .enums import IssuePriority, IssueStatus, SyncStatus
|
||||
|
||||
|
||||
class IssueBase(BaseModel):
|
||||
"""Base issue schema with common fields."""
|
||||
|
||||
title: str = Field(..., min_length=1, max_length=500)
|
||||
body: str = ""
|
||||
status: IssueStatus = IssueStatus.OPEN
|
||||
priority: IssuePriority = IssuePriority.MEDIUM
|
||||
labels: list[str] = Field(default_factory=list)
|
||||
story_points: int | None = Field(None, ge=0, le=100)
|
||||
|
||||
@field_validator("title")
|
||||
@classmethod
|
||||
def validate_title(cls, v: str) -> str:
|
||||
"""Validate issue title."""
|
||||
if not v or v.strip() == "":
|
||||
raise ValueError("Issue title cannot be empty")
|
||||
return v.strip()
|
||||
|
||||
@field_validator("labels")
|
||||
@classmethod
|
||||
def validate_labels(cls, v: list[str]) -> list[str]:
|
||||
"""Validate and normalize labels."""
|
||||
return [label.strip().lower() for label in v if label.strip()]
|
||||
|
||||
|
||||
class IssueCreate(IssueBase):
|
||||
"""Schema for creating a new issue."""
|
||||
|
||||
project_id: UUID
|
||||
assigned_agent_id: UUID | None = None
|
||||
human_assignee: str | None = Field(None, max_length=255)
|
||||
sprint_id: UUID | None = None
|
||||
|
||||
# External tracker fields (optional, for importing from external systems)
|
||||
external_tracker: Literal["gitea", "github", "gitlab"] | None = None
|
||||
external_id: str | None = Field(None, max_length=255)
|
||||
external_url: str | None = Field(None, max_length=1000)
|
||||
external_number: int | None = None
|
||||
|
||||
|
||||
class IssueUpdate(BaseModel):
|
||||
"""Schema for updating an issue."""
|
||||
|
||||
title: str | None = Field(None, min_length=1, max_length=500)
|
||||
body: str | None = None
|
||||
status: IssueStatus | None = None
|
||||
priority: IssuePriority | None = None
|
||||
labels: list[str] | None = None
|
||||
assigned_agent_id: UUID | None = None
|
||||
human_assignee: str | None = Field(None, max_length=255)
|
||||
sprint_id: UUID | None = None
|
||||
story_points: int | None = Field(None, ge=0, le=100)
|
||||
sync_status: SyncStatus | None = None
|
||||
|
||||
@field_validator("title")
|
||||
@classmethod
|
||||
def validate_title(cls, v: str | None) -> str | None:
|
||||
"""Validate issue title."""
|
||||
if v is not None and (not v or v.strip() == ""):
|
||||
raise ValueError("Issue title cannot be empty")
|
||||
return v.strip() if v else v
|
||||
|
||||
@field_validator("labels")
|
||||
@classmethod
|
||||
def validate_labels(cls, v: list[str] | None) -> list[str] | None:
|
||||
"""Validate and normalize labels."""
|
||||
if v is None:
|
||||
return v
|
||||
return [label.strip().lower() for label in v if label.strip()]
|
||||
|
||||
|
||||
class IssueClose(BaseModel):
|
||||
"""Schema for closing an issue."""
|
||||
|
||||
resolution: str | None = None # Optional resolution note
|
||||
|
||||
|
||||
class IssueAssign(BaseModel):
|
||||
"""Schema for assigning an issue."""
|
||||
|
||||
assigned_agent_id: UUID | None = None
|
||||
human_assignee: str | None = Field(None, max_length=255)
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_assignment(self) -> "IssueAssign":
|
||||
"""Ensure only one type of assignee is set."""
|
||||
if self.assigned_agent_id and self.human_assignee:
|
||||
raise ValueError(
|
||||
"Cannot assign to both an agent and a human. Choose one."
|
||||
)
|
||||
return self
|
||||
|
||||
|
||||
class IssueSyncUpdate(BaseModel):
|
||||
"""Schema for updating sync-related fields."""
|
||||
|
||||
sync_status: SyncStatus
|
||||
last_synced_at: datetime | None = None
|
||||
external_updated_at: datetime | None = None
|
||||
|
||||
|
||||
class IssueInDB(IssueBase):
|
||||
"""Schema for issue in database."""
|
||||
|
||||
id: UUID
|
||||
project_id: UUID
|
||||
assigned_agent_id: UUID | None = None
|
||||
human_assignee: str | None = None
|
||||
sprint_id: UUID | None = None
|
||||
external_tracker: str | None = None
|
||||
external_id: str | None = None
|
||||
external_url: str | None = None
|
||||
external_number: int | None = None
|
||||
sync_status: SyncStatus = SyncStatus.SYNCED
|
||||
last_synced_at: datetime | None = None
|
||||
external_updated_at: datetime | None = None
|
||||
closed_at: datetime | None = None
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
class IssueResponse(BaseModel):
|
||||
"""Schema for issue API responses."""
|
||||
|
||||
id: UUID
|
||||
project_id: UUID
|
||||
title: str
|
||||
body: str
|
||||
status: IssueStatus
|
||||
priority: IssuePriority
|
||||
labels: list[str] = Field(default_factory=list)
|
||||
assigned_agent_id: UUID | None = None
|
||||
human_assignee: str | None = None
|
||||
sprint_id: UUID | None = None
|
||||
story_points: int | None = None
|
||||
external_tracker: str | None = None
|
||||
external_id: str | None = None
|
||||
external_url: str | None = None
|
||||
external_number: int | None = None
|
||||
sync_status: SyncStatus = SyncStatus.SYNCED
|
||||
last_synced_at: datetime | None = None
|
||||
external_updated_at: datetime | None = None
|
||||
closed_at: datetime | None = None
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
# Expanded fields from relationships
|
||||
project_name: str | None = None
|
||||
project_slug: str | None = None
|
||||
sprint_name: str | None = None
|
||||
assigned_agent_type_name: str | None = None
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
class IssueListResponse(BaseModel):
|
||||
"""Schema for paginated issue list responses."""
|
||||
|
||||
issues: list[IssueResponse]
|
||||
total: int
|
||||
page: int
|
||||
page_size: int
|
||||
pages: int
|
||||
|
||||
|
||||
class IssueStats(BaseModel):
|
||||
"""Schema for issue statistics."""
|
||||
|
||||
total: int
|
||||
open: int
|
||||
in_progress: int
|
||||
in_review: int
|
||||
blocked: int
|
||||
closed: int
|
||||
by_priority: dict[str, int]
|
||||
total_story_points: int | None = None
|
||||
completed_story_points: int | None = None
|
||||
127
backend/app/schemas/syndarix/project.py
Normal file
127
backend/app/schemas/syndarix/project.py
Normal file
@@ -0,0 +1,127 @@
|
||||
# app/schemas/syndarix/project.py
|
||||
"""
|
||||
Pydantic schemas for Project entity.
|
||||
"""
|
||||
|
||||
import re
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field, field_validator
|
||||
|
||||
from .enums import AutonomyLevel, ProjectStatus
|
||||
|
||||
|
||||
class ProjectBase(BaseModel):
|
||||
"""Base project schema with common fields."""
|
||||
|
||||
name: str = Field(..., min_length=1, max_length=255)
|
||||
slug: str | None = Field(None, min_length=1, max_length=255)
|
||||
description: str | None = None
|
||||
autonomy_level: AutonomyLevel = AutonomyLevel.MILESTONE
|
||||
status: ProjectStatus = ProjectStatus.ACTIVE
|
||||
settings: dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
@field_validator("slug")
|
||||
@classmethod
|
||||
def validate_slug(cls, v: str | None) -> str | None:
|
||||
"""Validate slug format: lowercase, alphanumeric, hyphens only."""
|
||||
if v is None:
|
||||
return v
|
||||
if not re.match(r"^[a-z0-9-]+$", v):
|
||||
raise ValueError(
|
||||
"Slug must contain only lowercase letters, numbers, and hyphens"
|
||||
)
|
||||
if v.startswith("-") or v.endswith("-"):
|
||||
raise ValueError("Slug cannot start or end with a hyphen")
|
||||
if "--" in v:
|
||||
raise ValueError("Slug cannot contain consecutive hyphens")
|
||||
return v
|
||||
|
||||
@field_validator("name")
|
||||
@classmethod
|
||||
def validate_name(cls, v: str) -> str:
|
||||
"""Validate project name."""
|
||||
if not v or v.strip() == "":
|
||||
raise ValueError("Project name cannot be empty")
|
||||
return v.strip()
|
||||
|
||||
|
||||
class ProjectCreate(ProjectBase):
|
||||
"""Schema for creating a new project."""
|
||||
|
||||
name: str = Field(..., min_length=1, max_length=255)
|
||||
slug: str = Field(..., min_length=1, max_length=255)
|
||||
owner_id: UUID | None = None
|
||||
|
||||
|
||||
class ProjectUpdate(BaseModel):
|
||||
"""Schema for updating a project."""
|
||||
|
||||
name: str | None = Field(None, min_length=1, max_length=255)
|
||||
slug: str | None = Field(None, min_length=1, max_length=255)
|
||||
description: str | None = None
|
||||
autonomy_level: AutonomyLevel | None = None
|
||||
status: ProjectStatus | None = None
|
||||
settings: dict[str, Any] | None = None
|
||||
owner_id: UUID | None = None
|
||||
|
||||
@field_validator("slug")
|
||||
@classmethod
|
||||
def validate_slug(cls, v: str | None) -> str | None:
|
||||
"""Validate slug format."""
|
||||
if v is None:
|
||||
return v
|
||||
if not re.match(r"^[a-z0-9-]+$", v):
|
||||
raise ValueError(
|
||||
"Slug must contain only lowercase letters, numbers, and hyphens"
|
||||
)
|
||||
if v.startswith("-") or v.endswith("-"):
|
||||
raise ValueError("Slug cannot start or end with a hyphen")
|
||||
if "--" in v:
|
||||
raise ValueError("Slug cannot contain consecutive hyphens")
|
||||
return v
|
||||
|
||||
@field_validator("name")
|
||||
@classmethod
|
||||
def validate_name(cls, v: str | None) -> str | None:
|
||||
"""Validate project name."""
|
||||
if v is not None and (not v or v.strip() == ""):
|
||||
raise ValueError("Project name cannot be empty")
|
||||
return v.strip() if v else v
|
||||
|
||||
|
||||
class ProjectInDB(ProjectBase):
|
||||
"""Schema for project in database."""
|
||||
|
||||
id: UUID
|
||||
owner_id: UUID | None = None
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
class ProjectResponse(ProjectBase):
|
||||
"""Schema for project API responses."""
|
||||
|
||||
id: UUID
|
||||
owner_id: UUID | None = None
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
agent_count: int | None = 0
|
||||
issue_count: int | None = 0
|
||||
active_sprint_name: str | None = None
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
class ProjectListResponse(BaseModel):
|
||||
"""Schema for paginated project list responses."""
|
||||
|
||||
projects: list[ProjectResponse]
|
||||
total: int
|
||||
page: int
|
||||
page_size: int
|
||||
pages: int
|
||||
135
backend/app/schemas/syndarix/sprint.py
Normal file
135
backend/app/schemas/syndarix/sprint.py
Normal file
@@ -0,0 +1,135 @@
|
||||
# app/schemas/syndarix/sprint.py
|
||||
"""
|
||||
Pydantic schemas for Sprint entity.
|
||||
"""
|
||||
|
||||
from datetime import date, datetime
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
|
||||
|
||||
from .enums import SprintStatus
|
||||
|
||||
|
||||
class SprintBase(BaseModel):
|
||||
"""Base sprint schema with common fields."""
|
||||
|
||||
name: str = Field(..., min_length=1, max_length=255)
|
||||
number: int = Field(..., ge=1)
|
||||
goal: str | None = None
|
||||
start_date: date
|
||||
end_date: date
|
||||
status: SprintStatus = SprintStatus.PLANNED
|
||||
planned_points: int | None = Field(None, ge=0)
|
||||
completed_points: int | None = Field(None, ge=0)
|
||||
|
||||
@field_validator("name")
|
||||
@classmethod
|
||||
def validate_name(cls, v: str) -> str:
|
||||
"""Validate sprint name."""
|
||||
if not v or v.strip() == "":
|
||||
raise ValueError("Sprint name cannot be empty")
|
||||
return v.strip()
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_dates(self) -> "SprintBase":
|
||||
"""Validate that end_date is after start_date."""
|
||||
if self.end_date < self.start_date:
|
||||
raise ValueError("End date must be after or equal to start date")
|
||||
return self
|
||||
|
||||
|
||||
class SprintCreate(SprintBase):
|
||||
"""Schema for creating a new sprint."""
|
||||
|
||||
project_id: UUID
|
||||
|
||||
|
||||
class SprintUpdate(BaseModel):
|
||||
"""Schema for updating a sprint."""
|
||||
|
||||
name: str | None = Field(None, min_length=1, max_length=255)
|
||||
goal: str | None = None
|
||||
start_date: date | None = None
|
||||
end_date: date | None = None
|
||||
status: SprintStatus | None = None
|
||||
planned_points: int | None = Field(None, ge=0)
|
||||
completed_points: int | None = Field(None, ge=0)
|
||||
|
||||
@field_validator("name")
|
||||
@classmethod
|
||||
def validate_name(cls, v: str | None) -> str | None:
|
||||
"""Validate sprint name."""
|
||||
if v is not None and (not v or v.strip() == ""):
|
||||
raise ValueError("Sprint name cannot be empty")
|
||||
return v.strip() if v else v
|
||||
|
||||
|
||||
class SprintStart(BaseModel):
|
||||
"""Schema for starting a sprint."""
|
||||
|
||||
start_date: date | None = None # Optionally override start date
|
||||
|
||||
|
||||
class SprintComplete(BaseModel):
|
||||
"""Schema for completing a sprint."""
|
||||
|
||||
completed_points: int | None = Field(None, ge=0)
|
||||
notes: str | None = None
|
||||
|
||||
|
||||
class SprintInDB(SprintBase):
|
||||
"""Schema for sprint in database."""
|
||||
|
||||
id: UUID
|
||||
project_id: UUID
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
class SprintResponse(SprintBase):
|
||||
"""Schema for sprint API responses."""
|
||||
|
||||
id: UUID
|
||||
project_id: UUID
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
# Expanded fields from relationships
|
||||
project_name: str | None = None
|
||||
project_slug: str | None = None
|
||||
issue_count: int | None = 0
|
||||
open_issues: int | None = 0
|
||||
completed_issues: int | None = 0
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
class SprintListResponse(BaseModel):
|
||||
"""Schema for paginated sprint list responses."""
|
||||
|
||||
sprints: list[SprintResponse]
|
||||
total: int
|
||||
page: int
|
||||
page_size: int
|
||||
pages: int
|
||||
|
||||
|
||||
class SprintVelocity(BaseModel):
|
||||
"""Schema for sprint velocity metrics."""
|
||||
|
||||
sprint_number: int
|
||||
sprint_name: str
|
||||
planned_points: int | None
|
||||
completed_points: int | None
|
||||
velocity: float | None # completed/planned ratio
|
||||
|
||||
|
||||
class SprintBurndown(BaseModel):
|
||||
"""Schema for sprint burndown data point."""
|
||||
|
||||
date: date
|
||||
remaining_points: int
|
||||
ideal_remaining: float
|
||||
622
backend/app/services/event_bus.py
Normal file
622
backend/app/services/event_bus.py
Normal file
@@ -0,0 +1,622 @@
|
||||
"""
|
||||
EventBus service for Redis Pub/Sub communication.
|
||||
|
||||
This module provides a centralized event bus for publishing and subscribing to
|
||||
events across the Syndarix platform. It uses Redis Pub/Sub for real-time
|
||||
message delivery between services, agents, and the frontend.
|
||||
|
||||
Architecture:
|
||||
- Publishers emit events to project/agent-specific Redis channels
|
||||
- SSE endpoints subscribe to channels and stream events to clients
|
||||
- Events include metadata for reconnection support (Last-Event-ID)
|
||||
- Events are typed with the EventType enum for consistency
|
||||
|
||||
Usage:
|
||||
# Publishing events
|
||||
event_bus = EventBus()
|
||||
await event_bus.connect()
|
||||
|
||||
event = event_bus.create_event(
|
||||
event_type=EventType.AGENT_MESSAGE,
|
||||
project_id=project_id,
|
||||
actor_type="agent",
|
||||
payload={"message": "Processing..."}
|
||||
)
|
||||
await event_bus.publish(event_bus.get_project_channel(project_id), event)
|
||||
|
||||
# Subscribing to events
|
||||
async for event in event_bus.subscribe(["project:123", "agent:456"]):
|
||||
handle_event(event)
|
||||
|
||||
# Cleanup
|
||||
await event_bus.disconnect()
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
from collections.abc import AsyncGenerator, AsyncIterator
|
||||
from contextlib import asynccontextmanager
|
||||
from datetime import UTC, datetime
|
||||
from typing import Any
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
import redis.asyncio as redis
|
||||
from pydantic import ValidationError
|
||||
|
||||
from app.core.config import settings
|
||||
from app.schemas.events import ActorType, Event, EventType
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class EventBusError(Exception):
|
||||
"""Base exception for EventBus errors."""
|
||||
|
||||
|
||||
|
||||
class EventBusConnectionError(EventBusError):
|
||||
"""Raised when connection to Redis fails."""
|
||||
|
||||
|
||||
|
||||
class EventBusPublishError(EventBusError):
|
||||
"""Raised when publishing an event fails."""
|
||||
|
||||
|
||||
|
||||
class EventBusSubscriptionError(EventBusError):
|
||||
"""Raised when subscribing to channels fails."""
|
||||
|
||||
|
||||
|
||||
class EventBus:
|
||||
"""
|
||||
EventBus for Redis Pub/Sub communication.
|
||||
|
||||
Provides methods to publish events to channels and subscribe to events
|
||||
from multiple channels. Handles connection management, serialization,
|
||||
and error recovery.
|
||||
|
||||
This class provides:
|
||||
- Event publishing to project/agent-specific channels
|
||||
- Subscription management for SSE endpoints
|
||||
- Reconnection support via event IDs and sequence numbers
|
||||
- Keepalive messages for connection health
|
||||
- Type-safe event creation with the Event schema
|
||||
|
||||
Attributes:
|
||||
redis_url: Redis connection URL
|
||||
redis_client: Async Redis client instance
|
||||
pubsub: Redis PubSub instance for subscriptions
|
||||
"""
|
||||
|
||||
# Channel prefixes for different entity types
|
||||
PROJECT_CHANNEL_PREFIX = "project"
|
||||
AGENT_CHANNEL_PREFIX = "agent"
|
||||
USER_CHANNEL_PREFIX = "user"
|
||||
GLOBAL_CHANNEL = "syndarix:global"
|
||||
|
||||
def __init__(self, redis_url: str | None = None) -> None:
|
||||
"""
|
||||
Initialize the EventBus.
|
||||
|
||||
Args:
|
||||
redis_url: Redis connection URL. Defaults to settings.REDIS_URL.
|
||||
"""
|
||||
self.redis_url = redis_url or settings.REDIS_URL
|
||||
self._redis_client: redis.Redis | None = None
|
||||
self._pubsub: redis.client.PubSub | None = None
|
||||
self._connected = False
|
||||
self._sequence_counters: dict[str, int] = {}
|
||||
|
||||
@property
|
||||
def redis_client(self) -> redis.Redis:
|
||||
"""Get the Redis client, raising if not connected."""
|
||||
if self._redis_client is None:
|
||||
raise EventBusConnectionError(
|
||||
"EventBus not connected. Call connect() first."
|
||||
)
|
||||
return self._redis_client
|
||||
|
||||
@property
|
||||
def pubsub(self) -> redis.client.PubSub:
|
||||
"""Get the PubSub instance, raising if not connected."""
|
||||
if self._pubsub is None:
|
||||
raise EventBusConnectionError(
|
||||
"EventBus not connected. Call connect() first."
|
||||
)
|
||||
return self._pubsub
|
||||
|
||||
@property
|
||||
def is_connected(self) -> bool:
|
||||
"""Check if the EventBus is connected to Redis."""
|
||||
return self._connected and self._redis_client is not None
|
||||
|
||||
async def connect(self) -> None:
|
||||
"""
|
||||
Connect to Redis and initialize the PubSub client.
|
||||
|
||||
Raises:
|
||||
EventBusConnectionError: If connection to Redis fails.
|
||||
"""
|
||||
if self._connected:
|
||||
logger.debug("EventBus already connected")
|
||||
return
|
||||
|
||||
try:
|
||||
self._redis_client = redis.from_url(
|
||||
self.redis_url,
|
||||
encoding="utf-8",
|
||||
decode_responses=True,
|
||||
)
|
||||
# Test connection - ping() returns a coroutine for async Redis
|
||||
ping_result = self._redis_client.ping()
|
||||
if hasattr(ping_result, "__await__"):
|
||||
await ping_result
|
||||
self._pubsub = self._redis_client.pubsub()
|
||||
self._connected = True
|
||||
logger.info("EventBus connected to Redis")
|
||||
except redis.ConnectionError as e:
|
||||
logger.error(f"Failed to connect to Redis: {e}", exc_info=True)
|
||||
raise EventBusConnectionError(f"Failed to connect to Redis: {e}") from e
|
||||
except redis.RedisError as e:
|
||||
logger.error(f"Redis error during connection: {e}", exc_info=True)
|
||||
raise EventBusConnectionError(f"Redis error: {e}") from e
|
||||
|
||||
async def disconnect(self) -> None:
|
||||
"""
|
||||
Disconnect from Redis and cleanup resources.
|
||||
"""
|
||||
if self._pubsub:
|
||||
try:
|
||||
await self._pubsub.unsubscribe()
|
||||
await self._pubsub.close()
|
||||
except redis.RedisError as e:
|
||||
logger.warning(f"Error closing PubSub: {e}")
|
||||
finally:
|
||||
self._pubsub = None
|
||||
|
||||
if self._redis_client:
|
||||
try:
|
||||
await self._redis_client.aclose()
|
||||
except redis.RedisError as e:
|
||||
logger.warning(f"Error closing Redis client: {e}")
|
||||
finally:
|
||||
self._redis_client = None
|
||||
|
||||
self._connected = False
|
||||
logger.info("EventBus disconnected from Redis")
|
||||
|
||||
@asynccontextmanager
|
||||
async def connection(self) -> AsyncIterator["EventBus"]:
|
||||
"""
|
||||
Context manager for automatic connection handling.
|
||||
|
||||
Usage:
|
||||
async with event_bus.connection() as bus:
|
||||
await bus.publish(channel, event)
|
||||
"""
|
||||
await self.connect()
|
||||
try:
|
||||
yield self
|
||||
finally:
|
||||
await self.disconnect()
|
||||
|
||||
def get_project_channel(self, project_id: UUID | str) -> str:
|
||||
"""
|
||||
Get the channel name for a project.
|
||||
|
||||
Args:
|
||||
project_id: The project UUID or string
|
||||
|
||||
Returns:
|
||||
Channel name string in format "project:{uuid}"
|
||||
"""
|
||||
return f"{self.PROJECT_CHANNEL_PREFIX}:{project_id}"
|
||||
|
||||
def get_agent_channel(self, agent_id: UUID | str) -> str:
|
||||
"""
|
||||
Get the channel name for an agent instance.
|
||||
|
||||
Args:
|
||||
agent_id: The agent instance UUID or string
|
||||
|
||||
Returns:
|
||||
Channel name string in format "agent:{uuid}"
|
||||
"""
|
||||
return f"{self.AGENT_CHANNEL_PREFIX}:{agent_id}"
|
||||
|
||||
def get_user_channel(self, user_id: UUID | str) -> str:
|
||||
"""
|
||||
Get the channel name for a user (personal notifications).
|
||||
|
||||
Args:
|
||||
user_id: The user UUID or string
|
||||
|
||||
Returns:
|
||||
Channel name string in format "user:{uuid}"
|
||||
"""
|
||||
return f"{self.USER_CHANNEL_PREFIX}:{user_id}"
|
||||
|
||||
def _get_next_sequence(self, channel: str) -> int:
|
||||
"""Get the next sequence number for a channel's events."""
|
||||
current = self._sequence_counters.get(channel, 0)
|
||||
self._sequence_counters[channel] = current + 1
|
||||
return current + 1
|
||||
|
||||
@staticmethod
|
||||
def create_event(
|
||||
event_type: EventType,
|
||||
project_id: UUID,
|
||||
actor_type: ActorType,
|
||||
payload: dict | None = None,
|
||||
actor_id: UUID | None = None,
|
||||
event_id: str | None = None,
|
||||
timestamp: datetime | None = None,
|
||||
) -> Event:
|
||||
"""
|
||||
Factory method to create a new Event.
|
||||
|
||||
Args:
|
||||
event_type: The type of event
|
||||
project_id: The project this event belongs to
|
||||
actor_type: Type of actor ('agent', 'user', or 'system')
|
||||
payload: Event-specific payload data
|
||||
actor_id: ID of the agent or user who triggered the event
|
||||
event_id: Optional custom event ID (UUID string)
|
||||
timestamp: Optional custom timestamp (defaults to now UTC)
|
||||
|
||||
Returns:
|
||||
A new Event instance
|
||||
"""
|
||||
return Event(
|
||||
id=event_id or str(uuid4()),
|
||||
type=event_type,
|
||||
timestamp=timestamp or datetime.now(UTC),
|
||||
project_id=project_id,
|
||||
actor_id=actor_id,
|
||||
actor_type=actor_type,
|
||||
payload=payload or {},
|
||||
)
|
||||
|
||||
def _serialize_event(self, event: Event) -> str:
|
||||
"""
|
||||
Serialize an event to JSON string.
|
||||
|
||||
Args:
|
||||
event: The Event to serialize
|
||||
|
||||
Returns:
|
||||
JSON string representation of the event
|
||||
"""
|
||||
return event.model_dump_json()
|
||||
|
||||
def _deserialize_event(self, data: str) -> Event:
|
||||
"""
|
||||
Deserialize a JSON string to an Event.
|
||||
|
||||
Args:
|
||||
data: JSON string to deserialize
|
||||
|
||||
Returns:
|
||||
Deserialized Event instance
|
||||
|
||||
Raises:
|
||||
ValidationError: If the data doesn't match the Event schema
|
||||
"""
|
||||
return Event.model_validate_json(data)
|
||||
|
||||
async def publish(self, channel: str, event: Event) -> int:
|
||||
"""
|
||||
Publish an event to a channel.
|
||||
|
||||
Args:
|
||||
channel: The channel name to publish to
|
||||
event: The Event to publish
|
||||
|
||||
Returns:
|
||||
Number of subscribers that received the message
|
||||
|
||||
Raises:
|
||||
EventBusConnectionError: If not connected to Redis
|
||||
EventBusPublishError: If publishing fails
|
||||
"""
|
||||
if not self.is_connected:
|
||||
raise EventBusConnectionError("EventBus not connected")
|
||||
|
||||
try:
|
||||
message = self._serialize_event(event)
|
||||
subscriber_count = await self.redis_client.publish(channel, message)
|
||||
logger.debug(
|
||||
f"Published event {event.type} to {channel} "
|
||||
f"(received by {subscriber_count} subscribers)"
|
||||
)
|
||||
return subscriber_count
|
||||
except redis.RedisError as e:
|
||||
logger.error(f"Failed to publish event to {channel}: {e}", exc_info=True)
|
||||
raise EventBusPublishError(f"Failed to publish event: {e}") from e
|
||||
|
||||
async def publish_to_project(self, event: Event) -> int:
|
||||
"""
|
||||
Publish an event to the project's channel.
|
||||
|
||||
Convenience method that publishes to the project channel based on
|
||||
the event's project_id.
|
||||
|
||||
Args:
|
||||
event: The Event to publish (must have project_id set)
|
||||
|
||||
Returns:
|
||||
Number of subscribers that received the message
|
||||
"""
|
||||
channel = self.get_project_channel(event.project_id)
|
||||
return await self.publish(channel, event)
|
||||
|
||||
async def publish_multi(self, channels: list[str], event: Event) -> dict[str, int]:
|
||||
"""
|
||||
Publish an event to multiple channels.
|
||||
|
||||
Args:
|
||||
channels: List of channel names to publish to
|
||||
event: The Event to publish
|
||||
|
||||
Returns:
|
||||
Dictionary mapping channel names to subscriber counts
|
||||
"""
|
||||
results = {}
|
||||
for channel in channels:
|
||||
try:
|
||||
results[channel] = await self.publish(channel, event)
|
||||
except EventBusPublishError as e:
|
||||
logger.warning(f"Failed to publish to {channel}: {e}")
|
||||
results[channel] = 0
|
||||
return results
|
||||
|
||||
async def subscribe(
|
||||
self, channels: list[str], *, max_wait: float | None = None
|
||||
) -> AsyncIterator[Event]:
|
||||
"""
|
||||
Subscribe to one or more channels and yield events.
|
||||
|
||||
This is an async generator that yields Event objects as they arrive.
|
||||
Use max_wait to limit how long to wait for messages.
|
||||
|
||||
Args:
|
||||
channels: List of channel names to subscribe to
|
||||
max_wait: Optional maximum wait time in seconds for each message.
|
||||
If None, waits indefinitely.
|
||||
|
||||
Yields:
|
||||
Event objects received from subscribed channels
|
||||
|
||||
Raises:
|
||||
EventBusConnectionError: If not connected to Redis
|
||||
EventBusSubscriptionError: If subscription fails
|
||||
|
||||
Example:
|
||||
async for event in event_bus.subscribe(["project:123"], max_wait=30):
|
||||
print(f"Received: {event.type}")
|
||||
"""
|
||||
if not self.is_connected:
|
||||
raise EventBusConnectionError("EventBus not connected")
|
||||
|
||||
# Create a new pubsub for this subscription
|
||||
subscription_pubsub = self.redis_client.pubsub()
|
||||
|
||||
try:
|
||||
await subscription_pubsub.subscribe(*channels)
|
||||
logger.info(f"Subscribed to channels: {channels}")
|
||||
except redis.RedisError as e:
|
||||
logger.error(f"Failed to subscribe to channels: {e}", exc_info=True)
|
||||
await subscription_pubsub.close()
|
||||
raise EventBusSubscriptionError(f"Failed to subscribe: {e}") from e
|
||||
|
||||
try:
|
||||
while True:
|
||||
try:
|
||||
if max_wait is not None:
|
||||
async with asyncio.timeout(max_wait):
|
||||
message = await subscription_pubsub.get_message(
|
||||
ignore_subscribe_messages=True, timeout=1.0
|
||||
)
|
||||
else:
|
||||
message = await subscription_pubsub.get_message(
|
||||
ignore_subscribe_messages=True, timeout=1.0
|
||||
)
|
||||
except TimeoutError:
|
||||
# Timeout reached, stop iteration
|
||||
return
|
||||
|
||||
if message is None:
|
||||
continue
|
||||
|
||||
if message["type"] == "message":
|
||||
try:
|
||||
event = self._deserialize_event(message["data"])
|
||||
yield event
|
||||
except ValidationError as e:
|
||||
logger.warning(
|
||||
f"Invalid event data received: {e}",
|
||||
extra={"channel": message.get("channel")},
|
||||
)
|
||||
continue
|
||||
except json.JSONDecodeError as e:
|
||||
logger.warning(
|
||||
f"Failed to decode event JSON: {e}",
|
||||
extra={"channel": message.get("channel")},
|
||||
)
|
||||
continue
|
||||
finally:
|
||||
try:
|
||||
await subscription_pubsub.unsubscribe(*channels)
|
||||
await subscription_pubsub.close()
|
||||
logger.debug(f"Unsubscribed from channels: {channels}")
|
||||
except redis.RedisError as e:
|
||||
logger.warning(f"Error unsubscribing from channels: {e}")
|
||||
|
||||
async def subscribe_sse(
|
||||
self,
|
||||
project_id: str | UUID,
|
||||
last_event_id: str | None = None,
|
||||
keepalive_interval: int = 30,
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""
|
||||
Subscribe to events for a project in SSE format.
|
||||
|
||||
This is an async generator that yields SSE-formatted event strings.
|
||||
It includes keepalive messages at the specified interval.
|
||||
|
||||
Args:
|
||||
project_id: The project to subscribe to
|
||||
last_event_id: Optional last received event ID for reconnection
|
||||
keepalive_interval: Seconds between keepalive messages (default 30)
|
||||
|
||||
Yields:
|
||||
SSE-formatted event strings (ready to send to client)
|
||||
"""
|
||||
if not self.is_connected:
|
||||
raise EventBusConnectionError("EventBus not connected")
|
||||
|
||||
project_id_str = str(project_id)
|
||||
channel = self.get_project_channel(project_id_str)
|
||||
|
||||
subscription_pubsub = self.redis_client.pubsub()
|
||||
await subscription_pubsub.subscribe(channel)
|
||||
|
||||
logger.info(
|
||||
f"Subscribed to SSE events for project {project_id_str} "
|
||||
f"(last_event_id={last_event_id})"
|
||||
)
|
||||
|
||||
try:
|
||||
while True:
|
||||
try:
|
||||
# Wait for messages with a timeout for keepalive
|
||||
message = await asyncio.wait_for(
|
||||
subscription_pubsub.get_message(ignore_subscribe_messages=True),
|
||||
timeout=keepalive_interval,
|
||||
)
|
||||
|
||||
if message is not None and message["type"] == "message":
|
||||
event_data = message["data"]
|
||||
|
||||
# If reconnecting, check if we should skip this event
|
||||
if last_event_id:
|
||||
try:
|
||||
event_dict = json.loads(event_data)
|
||||
if event_dict.get("id") == last_event_id:
|
||||
# Found the last event, start yielding from next
|
||||
last_event_id = None
|
||||
continue
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
|
||||
yield event_data
|
||||
|
||||
except TimeoutError:
|
||||
# Send keepalive comment
|
||||
yield "" # Empty string signals keepalive
|
||||
|
||||
except asyncio.CancelledError:
|
||||
logger.info(f"SSE subscription cancelled for project {project_id_str}")
|
||||
raise
|
||||
finally:
|
||||
await subscription_pubsub.unsubscribe(channel)
|
||||
await subscription_pubsub.close()
|
||||
logger.info(f"Unsubscribed SSE from project {project_id_str}")
|
||||
|
||||
async def subscribe_with_callback(
|
||||
self,
|
||||
channels: list[str],
|
||||
callback: Any, # Callable[[Event], Awaitable[None]]
|
||||
stop_event: asyncio.Event | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Subscribe to channels and process events with a callback.
|
||||
|
||||
This method runs until stop_event is set or an unrecoverable error occurs.
|
||||
|
||||
Args:
|
||||
channels: List of channel names to subscribe to
|
||||
callback: Async function to call for each event
|
||||
stop_event: Optional asyncio.Event to signal stop
|
||||
|
||||
Example:
|
||||
async def handle_event(event: Event):
|
||||
print(f"Handling: {event.type}")
|
||||
|
||||
stop = asyncio.Event()
|
||||
asyncio.create_task(
|
||||
event_bus.subscribe_with_callback(["project:123"], handle_event, stop)
|
||||
)
|
||||
# Later...
|
||||
stop.set()
|
||||
"""
|
||||
if stop_event is None:
|
||||
stop_event = asyncio.Event()
|
||||
|
||||
try:
|
||||
async for event in self.subscribe(channels):
|
||||
if stop_event.is_set():
|
||||
break
|
||||
try:
|
||||
await callback(event)
|
||||
except Exception as e:
|
||||
logger.error(f"Error in event callback: {e}", exc_info=True)
|
||||
except EventBusSubscriptionError:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error in subscription loop: {e}", exc_info=True)
|
||||
raise
|
||||
|
||||
|
||||
# Singleton instance for application-wide use
|
||||
_event_bus: EventBus | None = None
|
||||
|
||||
|
||||
def get_event_bus() -> EventBus:
|
||||
"""
|
||||
Get the singleton EventBus instance.
|
||||
|
||||
Creates a new instance if one doesn't exist. Note that you still need
|
||||
to call connect() before using the EventBus.
|
||||
|
||||
Returns:
|
||||
The singleton EventBus instance
|
||||
"""
|
||||
global _event_bus
|
||||
if _event_bus is None:
|
||||
_event_bus = EventBus()
|
||||
return _event_bus
|
||||
|
||||
|
||||
async def get_connected_event_bus() -> EventBus:
|
||||
"""
|
||||
Get a connected EventBus instance.
|
||||
|
||||
Ensures the EventBus is connected before returning. For use in
|
||||
FastAPI dependency injection.
|
||||
|
||||
Returns:
|
||||
A connected EventBus instance
|
||||
|
||||
Raises:
|
||||
EventBusConnectionError: If connection fails
|
||||
"""
|
||||
event_bus = get_event_bus()
|
||||
if not event_bus.is_connected:
|
||||
await event_bus.connect()
|
||||
return event_bus
|
||||
|
||||
|
||||
async def close_event_bus() -> None:
|
||||
"""
|
||||
Close the global EventBus instance.
|
||||
|
||||
Should be called during application shutdown.
|
||||
"""
|
||||
global _event_bus
|
||||
if _event_bus is not None:
|
||||
await _event_bus.disconnect()
|
||||
_event_bus = None
|
||||
23
backend/app/tasks/__init__.py
Normal file
23
backend/app/tasks/__init__.py
Normal file
@@ -0,0 +1,23 @@
|
||||
# app/tasks/__init__.py
|
||||
"""
|
||||
Celery background tasks for Syndarix.
|
||||
|
||||
This package contains all Celery tasks organized by domain:
|
||||
|
||||
Modules:
|
||||
agent: Agent execution tasks (run_agent_step, spawn_agent, terminate_agent)
|
||||
git: Git operation tasks (clone, commit, branch, push, PR)
|
||||
sync: Issue synchronization tasks (incremental/full sync, webhooks)
|
||||
workflow: Workflow state management tasks
|
||||
cost: Cost tracking and budget monitoring tasks
|
||||
"""
|
||||
|
||||
from app.tasks import agent, cost, git, sync, workflow
|
||||
|
||||
__all__ = [
|
||||
"agent",
|
||||
"cost",
|
||||
"git",
|
||||
"sync",
|
||||
"workflow",
|
||||
]
|
||||
150
backend/app/tasks/agent.py
Normal file
150
backend/app/tasks/agent.py
Normal file
@@ -0,0 +1,150 @@
|
||||
# app/tasks/agent.py
|
||||
"""
|
||||
Agent execution tasks for Syndarix.
|
||||
|
||||
These tasks handle the lifecycle of AI agent instances:
|
||||
- Spawning new agent instances from agent types
|
||||
- Executing agent steps (LLM calls, tool execution)
|
||||
- Terminating agent instances
|
||||
|
||||
Tasks are routed to the 'agent' queue for dedicated processing.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from app.celery_app import celery_app
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@celery_app.task(bind=True, name="app.tasks.agent.run_agent_step")
|
||||
def run_agent_step(
|
||||
self,
|
||||
agent_instance_id: str,
|
||||
context: dict[str, Any],
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Execute a single step of an agent's workflow.
|
||||
|
||||
This task performs one iteration of the agent loop:
|
||||
1. Load agent instance state
|
||||
2. Call LLM with context and available tools
|
||||
3. Execute tool calls if any
|
||||
4. Update agent state
|
||||
5. Return result for next step or completion
|
||||
|
||||
Args:
|
||||
agent_instance_id: UUID of the agent instance
|
||||
context: Current execution context including:
|
||||
- messages: Conversation history
|
||||
- tools: Available tool definitions
|
||||
- state: Agent state data
|
||||
- metadata: Project/task metadata
|
||||
|
||||
Returns:
|
||||
dict with status and agent_instance_id
|
||||
"""
|
||||
logger.info(
|
||||
f"Running agent step for instance {agent_instance_id} with context keys: {list(context.keys())}"
|
||||
)
|
||||
|
||||
# TODO: Implement actual agent step execution
|
||||
# This will involve:
|
||||
# 1. Loading agent instance from database
|
||||
# 2. Calling LLM provider (via litellm or anthropic SDK)
|
||||
# 3. Processing tool calls through MCP servers
|
||||
# 4. Updating agent state in database
|
||||
# 5. Scheduling next step if needed
|
||||
|
||||
return {
|
||||
"status": "pending",
|
||||
"agent_instance_id": agent_instance_id,
|
||||
}
|
||||
|
||||
|
||||
@celery_app.task(bind=True, name="app.tasks.agent.spawn_agent")
|
||||
def spawn_agent(
|
||||
self,
|
||||
agent_type_id: str,
|
||||
project_id: str,
|
||||
initial_context: dict[str, Any],
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Spawn a new agent instance from an agent type.
|
||||
|
||||
This task creates a new agent instance:
|
||||
1. Load agent type configuration (model, expertise, personality)
|
||||
2. Create agent instance record in database
|
||||
3. Initialize agent state with project context
|
||||
4. Start first agent step
|
||||
|
||||
Args:
|
||||
agent_type_id: UUID of the agent type template
|
||||
project_id: UUID of the project this agent will work on
|
||||
initial_context: Starting context including:
|
||||
- goal: High-level objective
|
||||
- constraints: Any limitations or requirements
|
||||
- assigned_issues: Issues to work on
|
||||
- autonomy_level: FULL_CONTROL, MILESTONE, or AUTONOMOUS
|
||||
|
||||
Returns:
|
||||
dict with status, agent_type_id, and project_id
|
||||
"""
|
||||
logger.info(
|
||||
f"Spawning agent of type {agent_type_id} for project {project_id}"
|
||||
)
|
||||
|
||||
# TODO: Implement agent spawning
|
||||
# This will involve:
|
||||
# 1. Loading agent type from database
|
||||
# 2. Creating agent instance record
|
||||
# 3. Setting up MCP tool access
|
||||
# 4. Initializing agent state
|
||||
# 5. Kicking off first step
|
||||
|
||||
return {
|
||||
"status": "spawned",
|
||||
"agent_type_id": agent_type_id,
|
||||
"project_id": project_id,
|
||||
}
|
||||
|
||||
|
||||
@celery_app.task(bind=True, name="app.tasks.agent.terminate_agent")
|
||||
def terminate_agent(
|
||||
self,
|
||||
agent_instance_id: str,
|
||||
reason: str,
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Terminate an agent instance.
|
||||
|
||||
This task gracefully shuts down an agent:
|
||||
1. Mark agent instance as terminated
|
||||
2. Save final state for audit
|
||||
3. Release any held resources
|
||||
4. Notify relevant subscribers
|
||||
|
||||
Args:
|
||||
agent_instance_id: UUID of the agent instance
|
||||
reason: Reason for termination (completion, error, manual, budget)
|
||||
|
||||
Returns:
|
||||
dict with status and agent_instance_id
|
||||
"""
|
||||
logger.info(
|
||||
f"Terminating agent instance {agent_instance_id} with reason: {reason}"
|
||||
)
|
||||
|
||||
# TODO: Implement agent termination
|
||||
# This will involve:
|
||||
# 1. Loading agent instance
|
||||
# 2. Updating status to terminated
|
||||
# 3. Saving termination reason
|
||||
# 4. Cleaning up any pending tasks
|
||||
# 5. Sending termination event
|
||||
|
||||
return {
|
||||
"status": "terminated",
|
||||
"agent_instance_id": agent_instance_id,
|
||||
}
|
||||
201
backend/app/tasks/cost.py
Normal file
201
backend/app/tasks/cost.py
Normal file
@@ -0,0 +1,201 @@
|
||||
# app/tasks/cost.py
|
||||
"""
|
||||
Cost tracking and budget management tasks for Syndarix.
|
||||
|
||||
These tasks implement multi-layered cost tracking per ADR-012:
|
||||
- Per-agent token usage tracking
|
||||
- Project budget monitoring
|
||||
- Daily cost aggregation
|
||||
- Budget threshold alerts
|
||||
- Cost reporting
|
||||
|
||||
Costs are tracked in real-time in Redis for speed,
|
||||
then aggregated to PostgreSQL for durability.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from app.celery_app import celery_app
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@celery_app.task(bind=True, name="app.tasks.cost.aggregate_daily_costs")
|
||||
def aggregate_daily_costs(self) -> dict[str, Any]:
|
||||
"""
|
||||
Aggregate daily costs from Redis to PostgreSQL.
|
||||
|
||||
This periodic task (runs daily):
|
||||
1. Read accumulated costs from Redis
|
||||
2. Aggregate by project, agent, and model
|
||||
3. Store in PostgreSQL cost_records table
|
||||
4. Clear Redis counters for new day
|
||||
|
||||
Returns:
|
||||
dict with status
|
||||
"""
|
||||
logger.info("Starting daily cost aggregation")
|
||||
|
||||
# TODO: Implement cost aggregation
|
||||
# This will involve:
|
||||
# 1. Fetching cost data from Redis
|
||||
# 2. Grouping by project_id, agent_id, model
|
||||
# 3. Inserting into PostgreSQL cost tables
|
||||
# 4. Resetting Redis counters
|
||||
|
||||
return {
|
||||
"status": "pending",
|
||||
}
|
||||
|
||||
|
||||
@celery_app.task(bind=True, name="app.tasks.cost.check_budget_thresholds")
|
||||
def check_budget_thresholds(
|
||||
self,
|
||||
project_id: str,
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Check if a project has exceeded budget thresholds.
|
||||
|
||||
This task checks budget limits:
|
||||
1. Get current spend from Redis counters
|
||||
2. Compare against project budget limits
|
||||
3. Send alerts if thresholds exceeded
|
||||
4. Pause agents if hard limit reached
|
||||
|
||||
Args:
|
||||
project_id: UUID of the project
|
||||
|
||||
Returns:
|
||||
dict with status and project_id
|
||||
"""
|
||||
logger.info(f"Checking budget thresholds for project {project_id}")
|
||||
|
||||
# TODO: Implement budget checking
|
||||
# This will involve:
|
||||
# 1. Loading project budget configuration
|
||||
# 2. Getting current spend from Redis
|
||||
# 3. Comparing against soft/hard limits
|
||||
# 4. Sending alerts or pausing agents
|
||||
|
||||
return {
|
||||
"status": "pending",
|
||||
"project_id": project_id,
|
||||
}
|
||||
|
||||
|
||||
@celery_app.task(bind=True, name="app.tasks.cost.record_llm_usage")
|
||||
def record_llm_usage(
|
||||
self,
|
||||
agent_id: str,
|
||||
project_id: str,
|
||||
model: str,
|
||||
prompt_tokens: int,
|
||||
completion_tokens: int,
|
||||
cost_usd: float,
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Record LLM usage from an agent call.
|
||||
|
||||
This task tracks each LLM API call:
|
||||
1. Increment Redis counters for real-time tracking
|
||||
2. Store raw usage event for audit
|
||||
3. Trigger budget check if threshold approaching
|
||||
|
||||
Args:
|
||||
agent_id: UUID of the agent instance
|
||||
project_id: UUID of the project
|
||||
model: Model identifier (e.g., claude-opus-4-5-20251101)
|
||||
prompt_tokens: Number of input tokens
|
||||
completion_tokens: Number of output tokens
|
||||
cost_usd: Calculated cost in USD
|
||||
|
||||
Returns:
|
||||
dict with status, agent_id, project_id, and cost_usd
|
||||
"""
|
||||
logger.debug(
|
||||
f"Recording LLM usage for model {model}: "
|
||||
f"{prompt_tokens} prompt + {completion_tokens} completion tokens = ${cost_usd}"
|
||||
)
|
||||
|
||||
# TODO: Implement usage recording
|
||||
# This will involve:
|
||||
# 1. Incrementing Redis counters
|
||||
# 2. Storing usage event
|
||||
# 3. Checking if near budget threshold
|
||||
|
||||
return {
|
||||
"status": "pending",
|
||||
"agent_id": agent_id,
|
||||
"project_id": project_id,
|
||||
"cost_usd": cost_usd,
|
||||
}
|
||||
|
||||
|
||||
@celery_app.task(bind=True, name="app.tasks.cost.generate_cost_report")
|
||||
def generate_cost_report(
|
||||
self,
|
||||
project_id: str,
|
||||
start_date: str,
|
||||
end_date: str,
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Generate a cost report for a project.
|
||||
|
||||
This task creates a detailed cost breakdown:
|
||||
1. Query cost records for date range
|
||||
2. Group by agent, model, and day
|
||||
3. Calculate totals and trends
|
||||
4. Format report for display
|
||||
|
||||
Args:
|
||||
project_id: UUID of the project
|
||||
start_date: Report start date (YYYY-MM-DD)
|
||||
end_date: Report end date (YYYY-MM-DD)
|
||||
|
||||
Returns:
|
||||
dict with status, project_id, and date range
|
||||
"""
|
||||
logger.info(
|
||||
f"Generating cost report for project {project_id} from {start_date} to {end_date}"
|
||||
)
|
||||
|
||||
# TODO: Implement report generation
|
||||
# This will involve:
|
||||
# 1. Querying PostgreSQL for cost records
|
||||
# 2. Aggregating by various dimensions
|
||||
# 3. Calculating totals and averages
|
||||
# 4. Formatting report data
|
||||
|
||||
return {
|
||||
"status": "pending",
|
||||
"project_id": project_id,
|
||||
"start_date": start_date,
|
||||
"end_date": end_date,
|
||||
}
|
||||
|
||||
|
||||
@celery_app.task(bind=True, name="app.tasks.cost.reset_daily_budget_counters")
|
||||
def reset_daily_budget_counters(self) -> dict[str, Any]:
|
||||
"""
|
||||
Reset daily budget counters in Redis.
|
||||
|
||||
This periodic task (runs daily at midnight UTC):
|
||||
1. Archive current day's counters
|
||||
2. Reset all daily budget counters
|
||||
3. Prepare for new day's tracking
|
||||
|
||||
Returns:
|
||||
dict with status
|
||||
"""
|
||||
logger.info("Resetting daily budget counters")
|
||||
|
||||
# TODO: Implement counter reset
|
||||
# This will involve:
|
||||
# 1. Getting all daily counter keys from Redis
|
||||
# 2. Archiving current values
|
||||
# 3. Resetting counters to zero
|
||||
|
||||
return {
|
||||
"status": "pending",
|
||||
}
|
||||
225
backend/app/tasks/git.py
Normal file
225
backend/app/tasks/git.py
Normal file
@@ -0,0 +1,225 @@
|
||||
# app/tasks/git.py
|
||||
"""
|
||||
Git operation tasks for Syndarix.
|
||||
|
||||
These tasks handle Git operations for projects:
|
||||
- Cloning repositories
|
||||
- Creating branches
|
||||
- Committing changes
|
||||
- Pushing to remotes
|
||||
- Creating pull requests
|
||||
|
||||
Tasks are routed to the 'git' queue for dedicated processing.
|
||||
All operations are scoped by project_id for multi-tenancy.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from app.celery_app import celery_app
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@celery_app.task(bind=True, name="app.tasks.git.clone_repository")
|
||||
def clone_repository(
|
||||
self,
|
||||
project_id: str,
|
||||
repo_url: str,
|
||||
branch: str = "main",
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Clone a repository for a project.
|
||||
|
||||
This task clones a Git repository to the project workspace:
|
||||
1. Prepare workspace directory
|
||||
2. Clone repository with credentials
|
||||
3. Checkout specified branch
|
||||
4. Update project metadata
|
||||
|
||||
Args:
|
||||
project_id: UUID of the project
|
||||
repo_url: Git repository URL (HTTPS or SSH)
|
||||
branch: Branch to checkout (default: main)
|
||||
|
||||
Returns:
|
||||
dict with status and project_id
|
||||
"""
|
||||
logger.info(
|
||||
f"Cloning repository {repo_url} for project {project_id} on branch {branch}"
|
||||
)
|
||||
|
||||
# TODO: Implement repository cloning
|
||||
# This will involve:
|
||||
# 1. Getting project credentials from secrets store
|
||||
# 2. Creating workspace directory
|
||||
# 3. Running git clone with proper auth
|
||||
# 4. Checking out the target branch
|
||||
# 5. Updating project record with clone status
|
||||
|
||||
return {
|
||||
"status": "pending",
|
||||
"project_id": project_id,
|
||||
}
|
||||
|
||||
|
||||
@celery_app.task(bind=True, name="app.tasks.git.commit_changes")
|
||||
def commit_changes(
|
||||
self,
|
||||
project_id: str,
|
||||
message: str,
|
||||
files: list[str] | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Commit changes in a project repository.
|
||||
|
||||
This task creates a Git commit:
|
||||
1. Stage specified files (or all if None)
|
||||
2. Create commit with message
|
||||
3. Update commit history record
|
||||
|
||||
Args:
|
||||
project_id: UUID of the project
|
||||
message: Commit message (follows conventional commits)
|
||||
files: List of files to stage, or None for all staged
|
||||
|
||||
Returns:
|
||||
dict with status and project_id
|
||||
"""
|
||||
logger.info(
|
||||
f"Committing changes for project {project_id}: {message}"
|
||||
)
|
||||
|
||||
# TODO: Implement commit operation
|
||||
# This will involve:
|
||||
# 1. Loading project workspace path
|
||||
# 2. Running git add for specified files
|
||||
# 3. Running git commit with message
|
||||
# 4. Recording commit hash in database
|
||||
|
||||
return {
|
||||
"status": "pending",
|
||||
"project_id": project_id,
|
||||
}
|
||||
|
||||
|
||||
@celery_app.task(bind=True, name="app.tasks.git.create_branch")
|
||||
def create_branch(
|
||||
self,
|
||||
project_id: str,
|
||||
branch_name: str,
|
||||
from_ref: str = "HEAD",
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Create a new branch in a project repository.
|
||||
|
||||
This task creates a Git branch:
|
||||
1. Checkout from reference
|
||||
2. Create new branch
|
||||
3. Update branch tracking
|
||||
|
||||
Args:
|
||||
project_id: UUID of the project
|
||||
branch_name: Name of the new branch (e.g., feature/123-description)
|
||||
from_ref: Reference to branch from (default: HEAD)
|
||||
|
||||
Returns:
|
||||
dict with status and project_id
|
||||
"""
|
||||
logger.info(
|
||||
f"Creating branch {branch_name} from {from_ref} for project {project_id}"
|
||||
)
|
||||
|
||||
# TODO: Implement branch creation
|
||||
# This will involve:
|
||||
# 1. Loading project workspace
|
||||
# 2. Running git checkout -b from_ref
|
||||
# 3. Recording branch in database
|
||||
|
||||
return {
|
||||
"status": "pending",
|
||||
"project_id": project_id,
|
||||
}
|
||||
|
||||
|
||||
@celery_app.task(bind=True, name="app.tasks.git.create_pull_request")
|
||||
def create_pull_request(
|
||||
self,
|
||||
project_id: str,
|
||||
title: str,
|
||||
body: str,
|
||||
head_branch: str,
|
||||
base_branch: str = "main",
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Create a pull request for a project.
|
||||
|
||||
This task creates a PR on the external Git provider:
|
||||
1. Push branch if needed
|
||||
2. Create PR via API (Gitea, GitHub, GitLab)
|
||||
3. Store PR reference
|
||||
|
||||
Args:
|
||||
project_id: UUID of the project
|
||||
title: PR title
|
||||
body: PR description (markdown)
|
||||
head_branch: Branch with changes
|
||||
base_branch: Target branch (default: main)
|
||||
|
||||
Returns:
|
||||
dict with status and project_id
|
||||
"""
|
||||
logger.info(
|
||||
f"Creating PR '{title}' from {head_branch} to {base_branch} for project {project_id}"
|
||||
)
|
||||
|
||||
# TODO: Implement PR creation
|
||||
# This will involve:
|
||||
# 1. Loading project and Git provider config
|
||||
# 2. Ensuring head_branch is pushed
|
||||
# 3. Calling provider API to create PR
|
||||
# 4. Storing PR URL and number
|
||||
|
||||
return {
|
||||
"status": "pending",
|
||||
"project_id": project_id,
|
||||
}
|
||||
|
||||
|
||||
@celery_app.task(bind=True, name="app.tasks.git.push_changes")
|
||||
def push_changes(
|
||||
self,
|
||||
project_id: str,
|
||||
branch: str,
|
||||
force: bool = False,
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Push changes to remote repository.
|
||||
|
||||
This task pushes commits to the remote:
|
||||
1. Verify authentication
|
||||
2. Push branch to remote
|
||||
3. Handle push failures
|
||||
|
||||
Args:
|
||||
project_id: UUID of the project
|
||||
branch: Branch to push
|
||||
force: Whether to force push (use with caution)
|
||||
|
||||
Returns:
|
||||
dict with status and project_id
|
||||
"""
|
||||
logger.info(
|
||||
f"Pushing branch {branch} for project {project_id} (force={force})"
|
||||
)
|
||||
|
||||
# TODO: Implement push operation
|
||||
# This will involve:
|
||||
# 1. Loading project credentials
|
||||
# 2. Running git push (with --force if specified)
|
||||
# 3. Handling authentication and conflicts
|
||||
|
||||
return {
|
||||
"status": "pending",
|
||||
"project_id": project_id,
|
||||
}
|
||||
198
backend/app/tasks/sync.py
Normal file
198
backend/app/tasks/sync.py
Normal file
@@ -0,0 +1,198 @@
|
||||
# app/tasks/sync.py
|
||||
"""
|
||||
Issue synchronization tasks for Syndarix.
|
||||
|
||||
These tasks handle bidirectional issue synchronization:
|
||||
- Incremental sync (polling for recent changes)
|
||||
- Full reconciliation (daily comprehensive sync)
|
||||
- Webhook event processing
|
||||
- Pushing local changes to external trackers
|
||||
|
||||
Tasks are routed to the 'sync' queue for dedicated processing.
|
||||
Per ADR-011, sync follows a master/replica model with configurable direction.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from app.celery_app import celery_app
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@celery_app.task(bind=True, name="app.tasks.sync.sync_issues_incremental")
|
||||
def sync_issues_incremental(self) -> dict[str, Any]:
|
||||
"""
|
||||
Perform incremental issue synchronization across all projects.
|
||||
|
||||
This periodic task (runs every 5 minutes):
|
||||
1. Query each project's external tracker for recent changes
|
||||
2. Compare with local issue cache
|
||||
3. Apply updates to local database
|
||||
4. Handle conflicts based on sync direction config
|
||||
|
||||
Returns:
|
||||
dict with status and type
|
||||
"""
|
||||
logger.info("Starting incremental issue sync across all projects")
|
||||
|
||||
# TODO: Implement incremental sync
|
||||
# This will involve:
|
||||
# 1. Loading all active projects with sync enabled
|
||||
# 2. For each project, querying external tracker since last_sync_at
|
||||
# 3. Upserting issues into local database
|
||||
# 4. Updating last_sync_at timestamp
|
||||
|
||||
return {
|
||||
"status": "pending",
|
||||
"type": "incremental",
|
||||
}
|
||||
|
||||
|
||||
@celery_app.task(bind=True, name="app.tasks.sync.sync_issues_full")
|
||||
def sync_issues_full(self) -> dict[str, Any]:
|
||||
"""
|
||||
Perform full issue reconciliation across all projects.
|
||||
|
||||
This periodic task (runs daily):
|
||||
1. Fetch all issues from external trackers
|
||||
2. Compare with local database
|
||||
3. Handle orphaned issues
|
||||
4. Resolve any drift between systems
|
||||
|
||||
Returns:
|
||||
dict with status and type
|
||||
"""
|
||||
logger.info("Starting full issue reconciliation across all projects")
|
||||
|
||||
# TODO: Implement full sync
|
||||
# This will involve:
|
||||
# 1. Loading all active projects
|
||||
# 2. Fetching complete issue lists from external trackers
|
||||
# 3. Comparing with local database
|
||||
# 4. Handling deletes and orphans
|
||||
# 5. Resolving conflicts based on sync config
|
||||
|
||||
return {
|
||||
"status": "pending",
|
||||
"type": "full",
|
||||
}
|
||||
|
||||
|
||||
@celery_app.task(bind=True, name="app.tasks.sync.process_webhook_event")
|
||||
def process_webhook_event(
|
||||
self,
|
||||
provider: str,
|
||||
event_type: str,
|
||||
payload: dict[str, Any],
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Process a webhook event from an external Git provider.
|
||||
|
||||
This task handles real-time updates from:
|
||||
- Gitea: issue.created, issue.updated, pull_request.*, etc.
|
||||
- GitHub: issues, pull_request, push, etc.
|
||||
- GitLab: issue events, merge request events, etc.
|
||||
|
||||
Args:
|
||||
provider: Git provider name (gitea, github, gitlab)
|
||||
event_type: Event type from provider
|
||||
payload: Raw webhook payload
|
||||
|
||||
Returns:
|
||||
dict with status, provider, and event_type
|
||||
"""
|
||||
logger.info(f"Processing webhook event from {provider}: {event_type}")
|
||||
|
||||
# TODO: Implement webhook processing
|
||||
# This will involve:
|
||||
# 1. Validating webhook signature
|
||||
# 2. Parsing provider-specific payload
|
||||
# 3. Mapping to internal event format
|
||||
# 4. Updating local database
|
||||
# 5. Triggering any dependent workflows
|
||||
|
||||
return {
|
||||
"status": "pending",
|
||||
"provider": provider,
|
||||
"event_type": event_type,
|
||||
}
|
||||
|
||||
|
||||
@celery_app.task(bind=True, name="app.tasks.sync.sync_project_issues")
|
||||
def sync_project_issues(
|
||||
self,
|
||||
project_id: str,
|
||||
full: bool = False,
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Synchronize issues for a specific project.
|
||||
|
||||
This task can be triggered manually or by webhooks:
|
||||
1. Connect to project's external tracker
|
||||
2. Fetch issues (incremental or full)
|
||||
3. Update local database
|
||||
|
||||
Args:
|
||||
project_id: UUID of the project
|
||||
full: Whether to do full sync or incremental
|
||||
|
||||
Returns:
|
||||
dict with status and project_id
|
||||
"""
|
||||
logger.info(
|
||||
f"Syncing issues for project {project_id} (full={full})"
|
||||
)
|
||||
|
||||
# TODO: Implement project-specific sync
|
||||
# This will involve:
|
||||
# 1. Loading project configuration
|
||||
# 2. Connecting to external tracker
|
||||
# 3. Fetching issues based on full flag
|
||||
# 4. Upserting to database
|
||||
|
||||
return {
|
||||
"status": "pending",
|
||||
"project_id": project_id,
|
||||
}
|
||||
|
||||
|
||||
@celery_app.task(bind=True, name="app.tasks.sync.push_issue_to_external")
|
||||
def push_issue_to_external(
|
||||
self,
|
||||
project_id: str,
|
||||
issue_id: str,
|
||||
operation: str,
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Push a local issue change to the external tracker.
|
||||
|
||||
This task handles outbound sync when Syndarix is the master:
|
||||
- create: Create new issue in external tracker
|
||||
- update: Update existing issue
|
||||
- close: Close issue in external tracker
|
||||
|
||||
Args:
|
||||
project_id: UUID of the project
|
||||
issue_id: UUID of the local issue
|
||||
operation: Operation type (create, update, close)
|
||||
|
||||
Returns:
|
||||
dict with status, issue_id, and operation
|
||||
"""
|
||||
logger.info(
|
||||
f"Pushing {operation} for issue {issue_id} in project {project_id}"
|
||||
)
|
||||
|
||||
# TODO: Implement outbound sync
|
||||
# This will involve:
|
||||
# 1. Loading issue and project config
|
||||
# 2. Mapping to external tracker format
|
||||
# 3. Calling provider API
|
||||
# 4. Updating external_id mapping
|
||||
|
||||
return {
|
||||
"status": "pending",
|
||||
"issue_id": issue_id,
|
||||
"operation": operation,
|
||||
}
|
||||
213
backend/app/tasks/workflow.py
Normal file
213
backend/app/tasks/workflow.py
Normal file
@@ -0,0 +1,213 @@
|
||||
# app/tasks/workflow.py
|
||||
"""
|
||||
Workflow state management tasks for Syndarix.
|
||||
|
||||
These tasks manage workflow execution and state transitions:
|
||||
- Sprint workflows (planning -> implementation -> review -> done)
|
||||
- Story workflows (todo -> in_progress -> review -> done)
|
||||
- Approval checkpoints for autonomy levels
|
||||
- Stale workflow recovery
|
||||
|
||||
Per ADR-007 and ADR-010, workflow state is durable in PostgreSQL
|
||||
with defined state transitions.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from app.celery_app import celery_app
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@celery_app.task(bind=True, name="app.tasks.workflow.recover_stale_workflows")
|
||||
def recover_stale_workflows(self) -> dict[str, Any]:
|
||||
"""
|
||||
Recover workflows that have become stale.
|
||||
|
||||
This periodic task (runs every 5 minutes):
|
||||
1. Find workflows stuck in intermediate states
|
||||
2. Check for timed-out agent operations
|
||||
3. Retry or escalate based on configuration
|
||||
4. Notify relevant users if needed
|
||||
|
||||
Returns:
|
||||
dict with status and recovered count
|
||||
"""
|
||||
logger.info("Checking for stale workflows to recover")
|
||||
|
||||
# TODO: Implement stale workflow recovery
|
||||
# This will involve:
|
||||
# 1. Querying for workflows with last_updated > threshold
|
||||
# 2. Checking if associated agents are still running
|
||||
# 3. Retrying or resetting stuck workflows
|
||||
# 4. Sending notifications for manual intervention
|
||||
|
||||
return {
|
||||
"status": "pending",
|
||||
"recovered": 0,
|
||||
}
|
||||
|
||||
|
||||
@celery_app.task(bind=True, name="app.tasks.workflow.execute_workflow_step")
|
||||
def execute_workflow_step(
|
||||
self,
|
||||
workflow_id: str,
|
||||
transition: str,
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Execute a state transition for a workflow.
|
||||
|
||||
This task applies a transition to a workflow:
|
||||
1. Validate transition is allowed from current state
|
||||
2. Execute any pre-transition hooks
|
||||
3. Update workflow state
|
||||
4. Execute any post-transition hooks
|
||||
5. Trigger follow-up tasks
|
||||
|
||||
Args:
|
||||
workflow_id: UUID of the workflow
|
||||
transition: Transition to execute (start, approve, reject, etc.)
|
||||
|
||||
Returns:
|
||||
dict with status, workflow_id, and transition
|
||||
"""
|
||||
logger.info(
|
||||
f"Executing transition '{transition}' for workflow {workflow_id}"
|
||||
)
|
||||
|
||||
# TODO: Implement workflow transition
|
||||
# This will involve:
|
||||
# 1. Loading workflow from database
|
||||
# 2. Validating transition from current state
|
||||
# 3. Running pre-transition hooks
|
||||
# 4. Updating state in database
|
||||
# 5. Running post-transition hooks
|
||||
# 6. Scheduling follow-up tasks
|
||||
|
||||
return {
|
||||
"status": "pending",
|
||||
"workflow_id": workflow_id,
|
||||
"transition": transition,
|
||||
}
|
||||
|
||||
|
||||
@celery_app.task(bind=True, name="app.tasks.workflow.handle_approval_response")
|
||||
def handle_approval_response(
|
||||
self,
|
||||
workflow_id: str,
|
||||
approved: bool,
|
||||
comment: str | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Handle a user approval response for a workflow checkpoint.
|
||||
|
||||
This task processes approval decisions:
|
||||
1. Record approval decision with timestamp
|
||||
2. Update workflow state accordingly
|
||||
3. Resume or halt workflow execution
|
||||
4. Notify relevant parties
|
||||
|
||||
Args:
|
||||
workflow_id: UUID of the workflow
|
||||
approved: Whether the checkpoint was approved
|
||||
comment: Optional comment from approver
|
||||
|
||||
Returns:
|
||||
dict with status, workflow_id, and approved flag
|
||||
"""
|
||||
logger.info(
|
||||
f"Handling approval response for workflow {workflow_id}: approved={approved}"
|
||||
)
|
||||
|
||||
# TODO: Implement approval handling
|
||||
# This will involve:
|
||||
# 1. Loading workflow and approval checkpoint
|
||||
# 2. Recording decision with user and timestamp
|
||||
# 3. Transitioning workflow state
|
||||
# 4. Resuming or stopping execution
|
||||
# 5. Sending notifications
|
||||
|
||||
return {
|
||||
"status": "pending",
|
||||
"workflow_id": workflow_id,
|
||||
"approved": approved,
|
||||
}
|
||||
|
||||
|
||||
@celery_app.task(bind=True, name="app.tasks.workflow.start_sprint_workflow")
|
||||
def start_sprint_workflow(
|
||||
self,
|
||||
project_id: str,
|
||||
sprint_id: str,
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Start a new sprint workflow.
|
||||
|
||||
This task initializes sprint execution:
|
||||
1. Create sprint workflow record
|
||||
2. Set up sprint planning phase
|
||||
3. Spawn Product Owner agent for planning
|
||||
4. Begin story assignment
|
||||
|
||||
Args:
|
||||
project_id: UUID of the project
|
||||
sprint_id: UUID of the sprint
|
||||
|
||||
Returns:
|
||||
dict with status and sprint_id
|
||||
"""
|
||||
logger.info(
|
||||
f"Starting sprint workflow for sprint {sprint_id} in project {project_id}"
|
||||
)
|
||||
|
||||
# TODO: Implement sprint workflow initialization
|
||||
# This will involve:
|
||||
# 1. Creating workflow record for sprint
|
||||
# 2. Setting initial state to PLANNING
|
||||
# 3. Spawning PO agent for sprint planning
|
||||
# 4. Setting up monitoring and checkpoints
|
||||
|
||||
return {
|
||||
"status": "pending",
|
||||
"sprint_id": sprint_id,
|
||||
}
|
||||
|
||||
|
||||
@celery_app.task(bind=True, name="app.tasks.workflow.start_story_workflow")
|
||||
def start_story_workflow(
|
||||
self,
|
||||
project_id: str,
|
||||
story_id: str,
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Start a new story workflow.
|
||||
|
||||
This task initializes story execution:
|
||||
1. Create story workflow record
|
||||
2. Spawn appropriate developer agent
|
||||
3. Set up implementation tracking
|
||||
4. Configure approval checkpoints based on autonomy level
|
||||
|
||||
Args:
|
||||
project_id: UUID of the project
|
||||
story_id: UUID of the story/issue
|
||||
|
||||
Returns:
|
||||
dict with status and story_id
|
||||
"""
|
||||
logger.info(
|
||||
f"Starting story workflow for story {story_id} in project {project_id}"
|
||||
)
|
||||
|
||||
# TODO: Implement story workflow initialization
|
||||
# This will involve:
|
||||
# 1. Creating workflow record for story
|
||||
# 2. Determining appropriate agent type
|
||||
# 3. Spawning developer agent
|
||||
# 4. Setting up checkpoints based on autonomy level
|
||||
|
||||
return {
|
||||
"status": "pending",
|
||||
"story_id": story_id,
|
||||
}
|
||||
@@ -22,41 +22,37 @@ dependencies = [
|
||||
"pydantic-settings>=2.2.1",
|
||||
"python-multipart>=0.0.19",
|
||||
"fastapi-utils==0.8.0",
|
||||
|
||||
# Database
|
||||
"sqlalchemy>=2.0.29",
|
||||
"alembic>=1.14.1",
|
||||
"psycopg2-binary>=2.9.9",
|
||||
"asyncpg>=0.29.0",
|
||||
"aiosqlite==0.21.0",
|
||||
|
||||
# Environment configuration
|
||||
"python-dotenv>=1.0.1",
|
||||
|
||||
# API utilities
|
||||
"email-validator>=2.1.0.post1",
|
||||
"ujson>=5.9.0",
|
||||
|
||||
# CORS and security
|
||||
"starlette>=0.40.0",
|
||||
"starlette-csrf>=1.4.5",
|
||||
"slowapi>=0.1.9",
|
||||
|
||||
# Utilities
|
||||
"httpx>=0.27.0",
|
||||
"tenacity>=8.2.3",
|
||||
"pytz>=2024.1",
|
||||
"pillow>=10.3.0",
|
||||
"apscheduler==3.11.0",
|
||||
|
||||
# Security and authentication (pinned for reproducibility)
|
||||
"python-jose==3.4.0",
|
||||
"passlib==1.7.4",
|
||||
"bcrypt==4.2.1",
|
||||
"cryptography==44.0.1",
|
||||
|
||||
# OAuth authentication
|
||||
"authlib>=1.3.0",
|
||||
# Celery for background task processing (Syndarix agent jobs)
|
||||
"celery[redis]>=5.4.0",
|
||||
"sse-starlette>=3.1.1",
|
||||
]
|
||||
|
||||
# Development dependencies
|
||||
|
||||
525
backend/tests/api/routes/test_events.py
Normal file
525
backend/tests/api/routes/test_events.py
Normal file
@@ -0,0 +1,525 @@
|
||||
"""
|
||||
Tests for the SSE events endpoint.
|
||||
|
||||
This module tests the Server-Sent Events endpoint for project event streaming,
|
||||
including:
|
||||
- Authentication and authorization
|
||||
- SSE stream connection and format
|
||||
- Keepalive mechanism
|
||||
- Reconnection support (Last-Event-ID)
|
||||
- Connection cleanup
|
||||
"""
|
||||
|
||||
import json
|
||||
import uuid
|
||||
from collections.abc import AsyncGenerator
|
||||
from datetime import UTC, datetime
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
from fastapi import status
|
||||
from httpx import ASGITransport, AsyncClient
|
||||
|
||||
from app.api.dependencies.event_bus import get_event_bus
|
||||
from app.core.database import get_db
|
||||
from app.main import app
|
||||
from app.schemas.events import Event, EventType
|
||||
from app.services.event_bus import EventBus
|
||||
|
||||
|
||||
class MockEventBus:
|
||||
"""Mock EventBus for testing without Redis."""
|
||||
|
||||
def __init__(self):
|
||||
self.published_events: list[Event] = []
|
||||
self._should_yield_events = True
|
||||
self._events_to_yield: list[str] = []
|
||||
self._connected = True
|
||||
|
||||
@property
|
||||
def is_connected(self) -> bool:
|
||||
return self._connected
|
||||
|
||||
async def connect(self) -> None:
|
||||
self._connected = True
|
||||
|
||||
async def disconnect(self) -> None:
|
||||
self._connected = False
|
||||
|
||||
def get_project_channel(self, project_id: uuid.UUID | str) -> str:
|
||||
"""Get the channel name for a project."""
|
||||
return f"project:{project_id}"
|
||||
|
||||
@staticmethod
|
||||
def create_event(
|
||||
event_type: EventType,
|
||||
project_id: uuid.UUID,
|
||||
actor_type: str,
|
||||
payload: dict | None = None,
|
||||
actor_id: uuid.UUID | None = None,
|
||||
event_id: str | None = None,
|
||||
timestamp: datetime | None = None,
|
||||
) -> Event:
|
||||
"""Create a new Event."""
|
||||
return Event(
|
||||
id=event_id or str(uuid.uuid4()),
|
||||
type=event_type,
|
||||
timestamp=timestamp or datetime.now(UTC),
|
||||
project_id=project_id,
|
||||
actor_id=actor_id,
|
||||
actor_type=actor_type,
|
||||
payload=payload or {},
|
||||
)
|
||||
|
||||
async def publish(self, channel: str, event: Event) -> int:
|
||||
"""Publish an event to a channel."""
|
||||
self.published_events.append(event)
|
||||
return 1
|
||||
|
||||
def add_event_to_yield(self, event_json: str) -> None:
|
||||
"""Add an event JSON string to be yielded by subscribe_sse."""
|
||||
self._events_to_yield.append(event_json)
|
||||
|
||||
async def subscribe_sse(
|
||||
self,
|
||||
project_id: str | uuid.UUID,
|
||||
last_event_id: str | None = None,
|
||||
keepalive_interval: int = 30,
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""Mock subscribe_sse that yields pre-configured events then keepalive."""
|
||||
# First yield any pre-configured events
|
||||
for event_data in self._events_to_yield:
|
||||
yield event_data
|
||||
|
||||
# Then yield keepalive
|
||||
yield ""
|
||||
|
||||
# Then stop to allow test to complete
|
||||
self._should_yield_events = False
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def mock_event_bus():
|
||||
"""Create a mock event bus for testing."""
|
||||
return MockEventBus()
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def client_with_mock_bus(async_test_db, mock_event_bus):
|
||||
"""
|
||||
Create a FastAPI test client with mocked database and event bus.
|
||||
"""
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async def override_get_db():
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
try:
|
||||
yield session
|
||||
finally:
|
||||
pass
|
||||
|
||||
async def override_get_event_bus():
|
||||
return mock_event_bus
|
||||
|
||||
app.dependency_overrides[get_db] = override_get_db
|
||||
app.dependency_overrides[get_event_bus] = override_get_event_bus
|
||||
|
||||
transport = ASGITransport(app=app)
|
||||
async with AsyncClient(transport=transport, base_url="http://test") as test_client:
|
||||
yield test_client
|
||||
|
||||
app.dependency_overrides.clear()
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def user_token_with_mock_bus(client_with_mock_bus, async_test_user):
|
||||
"""Create an access token for the test user."""
|
||||
response = await client_with_mock_bus.post(
|
||||
"/api/v1/auth/login",
|
||||
json={
|
||||
"email": async_test_user.email,
|
||||
"password": "TestPassword123!",
|
||||
},
|
||||
)
|
||||
assert response.status_code == 200, f"Login failed: {response.text}"
|
||||
tokens = response.json()
|
||||
return tokens["access_token"]
|
||||
|
||||
|
||||
class TestSSEEndpointAuthentication:
|
||||
"""Tests for SSE endpoint authentication."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stream_events_requires_authentication(self, client_with_mock_bus):
|
||||
"""Test that SSE endpoint requires authentication."""
|
||||
project_id = uuid.uuid4()
|
||||
|
||||
response = await client_with_mock_bus.get(
|
||||
f"/api/v1/projects/{project_id}/events/stream",
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_401_UNAUTHORIZED
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stream_events_with_invalid_token(self, client_with_mock_bus):
|
||||
"""Test that SSE endpoint rejects invalid tokens."""
|
||||
project_id = uuid.uuid4()
|
||||
|
||||
response = await client_with_mock_bus.get(
|
||||
f"/api/v1/projects/{project_id}/events/stream",
|
||||
headers={"Authorization": "Bearer invalid_token"},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_401_UNAUTHORIZED
|
||||
|
||||
|
||||
class TestSSEEndpointStream:
|
||||
"""Tests for SSE stream functionality."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stream_events_returns_sse_response(
|
||||
self, client_with_mock_bus, user_token_with_mock_bus
|
||||
):
|
||||
"""Test that SSE endpoint returns proper SSE response."""
|
||||
project_id = uuid.uuid4()
|
||||
|
||||
# Make request with a timeout to avoid hanging
|
||||
response = await client_with_mock_bus.get(
|
||||
f"/api/v1/projects/{project_id}/events/stream",
|
||||
headers={"Authorization": f"Bearer {user_token_with_mock_bus}"},
|
||||
timeout=5.0,
|
||||
)
|
||||
|
||||
# The response should start streaming
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
assert "text/event-stream" in response.headers.get("content-type", "")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stream_events_with_events(
|
||||
self, client_with_mock_bus, user_token_with_mock_bus, mock_event_bus
|
||||
):
|
||||
"""Test that SSE endpoint yields events."""
|
||||
project_id = uuid.uuid4()
|
||||
|
||||
# Create a test event and add it to the mock bus
|
||||
test_event = Event(
|
||||
id=str(uuid.uuid4()),
|
||||
type=EventType.AGENT_MESSAGE,
|
||||
timestamp=datetime.now(UTC),
|
||||
project_id=project_id,
|
||||
actor_type="agent",
|
||||
payload={"message": "test"},
|
||||
)
|
||||
mock_event_bus.add_event_to_yield(test_event.model_dump_json())
|
||||
|
||||
# Request the stream
|
||||
response = await client_with_mock_bus.get(
|
||||
f"/api/v1/projects/{project_id}/events/stream",
|
||||
headers={"Authorization": f"Bearer {user_token_with_mock_bus}"},
|
||||
timeout=5.0,
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
|
||||
# Check response contains event data
|
||||
content = response.text
|
||||
assert "agent.message" in content or "data:" in content
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stream_events_with_last_event_id(
|
||||
self, client_with_mock_bus, user_token_with_mock_bus
|
||||
):
|
||||
"""Test that Last-Event-ID header is accepted."""
|
||||
project_id = uuid.uuid4()
|
||||
last_event_id = str(uuid.uuid4())
|
||||
|
||||
response = await client_with_mock_bus.get(
|
||||
f"/api/v1/projects/{project_id}/events/stream",
|
||||
headers={
|
||||
"Authorization": f"Bearer {user_token_with_mock_bus}",
|
||||
"Last-Event-ID": last_event_id,
|
||||
},
|
||||
timeout=5.0,
|
||||
)
|
||||
|
||||
# Should accept the header and return OK
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
|
||||
|
||||
class TestSSEEndpointHeaders:
|
||||
"""Tests for SSE response headers."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stream_events_cache_control_header(
|
||||
self, client_with_mock_bus, user_token_with_mock_bus
|
||||
):
|
||||
"""Test that SSE response has no-cache header."""
|
||||
project_id = uuid.uuid4()
|
||||
|
||||
response = await client_with_mock_bus.get(
|
||||
f"/api/v1/projects/{project_id}/events/stream",
|
||||
headers={"Authorization": f"Bearer {user_token_with_mock_bus}"},
|
||||
timeout=5.0,
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
cache_control = response.headers.get("cache-control", "")
|
||||
assert "no-cache" in cache_control.lower()
|
||||
|
||||
|
||||
class TestTestEventEndpoint:
|
||||
"""Tests for the test event endpoint."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_test_event_requires_auth(self, client_with_mock_bus):
|
||||
"""Test that test event endpoint requires authentication."""
|
||||
project_id = uuid.uuid4()
|
||||
|
||||
response = await client_with_mock_bus.post(
|
||||
f"/api/v1/projects/{project_id}/events/test",
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_401_UNAUTHORIZED
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_test_event_success(
|
||||
self, client_with_mock_bus, user_token_with_mock_bus, mock_event_bus
|
||||
):
|
||||
"""Test sending a test event."""
|
||||
project_id = uuid.uuid4()
|
||||
|
||||
response = await client_with_mock_bus.post(
|
||||
f"/api/v1/projects/{project_id}/events/test",
|
||||
headers={"Authorization": f"Bearer {user_token_with_mock_bus}"},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
data = response.json()
|
||||
assert data["success"] is True
|
||||
assert "event_id" in data
|
||||
assert data["event_type"] == "agent.message"
|
||||
|
||||
# Verify event was published
|
||||
assert len(mock_event_bus.published_events) == 1
|
||||
published = mock_event_bus.published_events[0]
|
||||
assert published.type == EventType.AGENT_MESSAGE
|
||||
assert published.project_id == project_id
|
||||
|
||||
|
||||
class TestEventSchema:
|
||||
"""Tests for the Event schema."""
|
||||
|
||||
def test_event_creation(self):
|
||||
"""Test Event creation with required fields."""
|
||||
project_id = uuid.uuid4()
|
||||
event = Event(
|
||||
id=str(uuid.uuid4()),
|
||||
type=EventType.AGENT_MESSAGE,
|
||||
timestamp=datetime.now(UTC),
|
||||
project_id=project_id,
|
||||
actor_type="agent",
|
||||
payload={"message": "test"},
|
||||
)
|
||||
|
||||
assert event.id is not None
|
||||
assert event.type == EventType.AGENT_MESSAGE
|
||||
assert event.project_id == project_id
|
||||
assert event.actor_type == "agent"
|
||||
assert event.payload == {"message": "test"}
|
||||
|
||||
def test_event_json_serialization(self):
|
||||
"""Test Event JSON serialization."""
|
||||
project_id = uuid.uuid4()
|
||||
event = Event(
|
||||
id="test-id",
|
||||
type=EventType.AGENT_STATUS_CHANGED,
|
||||
timestamp=datetime.now(UTC),
|
||||
project_id=project_id,
|
||||
actor_type="system",
|
||||
payload={"status": "running"},
|
||||
)
|
||||
|
||||
json_str = event.model_dump_json()
|
||||
parsed = json.loads(json_str)
|
||||
|
||||
assert parsed["id"] == "test-id"
|
||||
assert parsed["type"] == "agent.status_changed"
|
||||
assert str(parsed["project_id"]) == str(project_id)
|
||||
assert parsed["payload"]["status"] == "running"
|
||||
|
||||
|
||||
class TestEventBusUnit:
|
||||
"""Unit tests for EventBus class."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_event_bus_not_connected_raises(self):
|
||||
"""Test that accessing redis_client before connect raises."""
|
||||
from app.services.event_bus import EventBusConnectionError
|
||||
|
||||
bus = EventBus()
|
||||
|
||||
with pytest.raises(EventBusConnectionError, match="not connected"):
|
||||
_ = bus.redis_client
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_event_bus_channel_names(self):
|
||||
"""Test channel name generation."""
|
||||
bus = EventBus()
|
||||
project_id = uuid.uuid4()
|
||||
agent_id = uuid.uuid4()
|
||||
user_id = uuid.uuid4()
|
||||
|
||||
assert bus.get_project_channel(project_id) == f"project:{project_id}"
|
||||
assert bus.get_agent_channel(agent_id) == f"agent:{agent_id}"
|
||||
assert bus.get_user_channel(user_id) == f"user:{user_id}"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_event_bus_sequence_counter(self):
|
||||
"""Test sequence counter increments."""
|
||||
bus = EventBus()
|
||||
channel = "test-channel"
|
||||
|
||||
seq1 = bus._get_next_sequence(channel)
|
||||
seq2 = bus._get_next_sequence(channel)
|
||||
seq3 = bus._get_next_sequence(channel)
|
||||
|
||||
assert seq1 == 1
|
||||
assert seq2 == 2
|
||||
assert seq3 == 3
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_event_bus_sequence_per_channel(self):
|
||||
"""Test sequence counter is per-channel."""
|
||||
bus = EventBus()
|
||||
|
||||
seq1 = bus._get_next_sequence("channel-1")
|
||||
seq2 = bus._get_next_sequence("channel-2")
|
||||
seq3 = bus._get_next_sequence("channel-1")
|
||||
|
||||
assert seq1 == 1
|
||||
assert seq2 == 1 # Different channel starts at 1
|
||||
assert seq3 == 2
|
||||
|
||||
def test_event_bus_create_event(self):
|
||||
"""Test EventBus.create_event factory method."""
|
||||
project_id = uuid.uuid4()
|
||||
actor_id = uuid.uuid4()
|
||||
|
||||
event = EventBus.create_event(
|
||||
event_type=EventType.ISSUE_CREATED,
|
||||
project_id=project_id,
|
||||
actor_type="user",
|
||||
actor_id=actor_id,
|
||||
payload={"title": "Test Issue"},
|
||||
)
|
||||
|
||||
assert event.type == EventType.ISSUE_CREATED
|
||||
assert event.project_id == project_id
|
||||
assert event.actor_id == actor_id
|
||||
assert event.actor_type == "user"
|
||||
assert event.payload == {"title": "Test Issue"}
|
||||
|
||||
|
||||
class TestEventBusIntegration:
|
||||
"""Integration tests for EventBus with mocked Redis."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_event_bus_connect_disconnect(self):
|
||||
"""Test EventBus connect and disconnect."""
|
||||
with patch("app.services.event_bus.redis.from_url") as mock_redis:
|
||||
mock_client = AsyncMock()
|
||||
mock_redis.return_value = mock_client
|
||||
mock_client.ping = AsyncMock()
|
||||
mock_client.pubsub = lambda: AsyncMock()
|
||||
|
||||
bus = EventBus(redis_url="redis://localhost:6379/0")
|
||||
|
||||
# Connect
|
||||
await bus.connect()
|
||||
mock_client.ping.assert_called_once()
|
||||
assert bus._redis_client is not None
|
||||
assert bus.is_connected
|
||||
|
||||
# Disconnect
|
||||
await bus.disconnect()
|
||||
mock_client.aclose.assert_called_once()
|
||||
assert bus._redis_client is None
|
||||
assert not bus.is_connected
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_event_bus_publish(self):
|
||||
"""Test EventBus event publishing."""
|
||||
with patch("app.services.event_bus.redis.from_url") as mock_redis:
|
||||
mock_client = AsyncMock()
|
||||
mock_redis.return_value = mock_client
|
||||
mock_client.ping = AsyncMock()
|
||||
mock_client.publish = AsyncMock(return_value=1)
|
||||
mock_client.pubsub = lambda: AsyncMock()
|
||||
|
||||
bus = EventBus()
|
||||
await bus.connect()
|
||||
|
||||
project_id = uuid.uuid4()
|
||||
event = EventBus.create_event(
|
||||
event_type=EventType.AGENT_SPAWNED,
|
||||
project_id=project_id,
|
||||
actor_type="system",
|
||||
payload={"agent_name": "test-agent"},
|
||||
)
|
||||
|
||||
channel = bus.get_project_channel(project_id)
|
||||
result = await bus.publish(channel, event)
|
||||
|
||||
# Verify publish was called
|
||||
mock_client.publish.assert_called_once()
|
||||
call_args = mock_client.publish.call_args
|
||||
|
||||
# Check channel name
|
||||
assert call_args[0][0] == f"project:{project_id}"
|
||||
|
||||
# Check result
|
||||
assert result == 1
|
||||
|
||||
await bus.disconnect()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_event_bus_connect_failure(self):
|
||||
"""Test EventBus handles connection failure."""
|
||||
from app.services.event_bus import EventBusConnectionError
|
||||
|
||||
with patch("app.services.event_bus.redis.from_url") as mock_redis:
|
||||
mock_client = AsyncMock()
|
||||
mock_redis.return_value = mock_client
|
||||
|
||||
import redis.asyncio as redis_async
|
||||
|
||||
mock_client.ping = AsyncMock(
|
||||
side_effect=redis_async.ConnectionError("Connection refused")
|
||||
)
|
||||
|
||||
bus = EventBus()
|
||||
|
||||
with pytest.raises(EventBusConnectionError, match="Failed to connect"):
|
||||
await bus.connect()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_event_bus_already_connected(self):
|
||||
"""Test EventBus connect when already connected is a no-op."""
|
||||
with patch("app.services.event_bus.redis.from_url") as mock_redis:
|
||||
mock_client = AsyncMock()
|
||||
mock_redis.return_value = mock_client
|
||||
mock_client.ping = AsyncMock()
|
||||
mock_client.pubsub = lambda: AsyncMock()
|
||||
|
||||
bus = EventBus()
|
||||
|
||||
# First connect
|
||||
await bus.connect()
|
||||
assert mock_client.ping.call_count == 1
|
||||
|
||||
# Second connect should be a no-op
|
||||
await bus.connect()
|
||||
assert mock_client.ping.call_count == 1
|
||||
|
||||
await bus.disconnect()
|
||||
784
backend/tests/core/test_redis.py
Normal file
784
backend/tests/core/test_redis.py
Normal file
@@ -0,0 +1,784 @@
|
||||
"""
|
||||
Tests for Redis client utility functions (app/core/redis.py).
|
||||
|
||||
Covers:
|
||||
- Cache operations (get, set, delete, expire)
|
||||
- JSON serialization helpers
|
||||
- Pub/sub operations
|
||||
- Health check
|
||||
- Connection pooling
|
||||
- Error handling
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from redis.exceptions import ConnectionError, RedisError, TimeoutError
|
||||
|
||||
from app.core.redis import (
|
||||
DEFAULT_CACHE_TTL,
|
||||
POOL_MAX_CONNECTIONS,
|
||||
RedisClient,
|
||||
check_redis_health,
|
||||
close_redis,
|
||||
get_redis,
|
||||
redis_client,
|
||||
)
|
||||
|
||||
|
||||
class TestRedisClientInit:
|
||||
"""Test RedisClient initialization."""
|
||||
|
||||
def test_default_url_from_settings(self):
|
||||
"""Test that default URL comes from settings."""
|
||||
with patch("app.core.redis.settings") as mock_settings:
|
||||
mock_settings.REDIS_URL = "redis://test:6379/0"
|
||||
client = RedisClient()
|
||||
assert client._url == "redis://test:6379/0"
|
||||
|
||||
def test_custom_url_override(self):
|
||||
"""Test that custom URL overrides settings."""
|
||||
client = RedisClient(url="redis://custom:6379/1")
|
||||
assert client._url == "redis://custom:6379/1"
|
||||
|
||||
def test_initial_state(self):
|
||||
"""Test initial client state."""
|
||||
client = RedisClient(url="redis://localhost:6379/0")
|
||||
assert client._pool is None
|
||||
assert client._client is None
|
||||
|
||||
|
||||
class TestCacheOperations:
|
||||
"""Test cache get/set/delete operations."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cache_set_success(self):
|
||||
"""Test setting a cache value."""
|
||||
client = RedisClient(url="redis://localhost:6379/0")
|
||||
|
||||
mock_redis = AsyncMock()
|
||||
mock_redis.set = AsyncMock(return_value=True)
|
||||
|
||||
with patch.object(client, "_get_client", return_value=mock_redis):
|
||||
result = await client.cache_set("test-key", "test-value", ttl=60)
|
||||
|
||||
assert result is True
|
||||
mock_redis.set.assert_called_once_with("test-key", "test-value", ex=60)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cache_set_default_ttl(self):
|
||||
"""Test setting a cache value with default TTL."""
|
||||
client = RedisClient(url="redis://localhost:6379/0")
|
||||
|
||||
mock_redis = AsyncMock()
|
||||
mock_redis.set = AsyncMock(return_value=True)
|
||||
|
||||
with patch.object(client, "_get_client", return_value=mock_redis):
|
||||
result = await client.cache_set("test-key", "test-value")
|
||||
|
||||
assert result is True
|
||||
mock_redis.set.assert_called_once_with(
|
||||
"test-key", "test-value", ex=DEFAULT_CACHE_TTL
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cache_set_connection_error(self):
|
||||
"""Test cache_set handles connection errors."""
|
||||
client = RedisClient(url="redis://localhost:6379/0")
|
||||
|
||||
mock_redis = AsyncMock()
|
||||
mock_redis.set = AsyncMock(side_effect=ConnectionError("Connection refused"))
|
||||
|
||||
with patch.object(client, "_get_client", return_value=mock_redis):
|
||||
result = await client.cache_set("test-key", "test-value")
|
||||
|
||||
assert result is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cache_set_timeout_error(self):
|
||||
"""Test cache_set handles timeout errors."""
|
||||
client = RedisClient(url="redis://localhost:6379/0")
|
||||
|
||||
mock_redis = AsyncMock()
|
||||
mock_redis.set = AsyncMock(side_effect=TimeoutError("Timeout"))
|
||||
|
||||
with patch.object(client, "_get_client", return_value=mock_redis):
|
||||
result = await client.cache_set("test-key", "test-value")
|
||||
|
||||
assert result is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cache_set_redis_error(self):
|
||||
"""Test cache_set handles generic Redis errors."""
|
||||
client = RedisClient(url="redis://localhost:6379/0")
|
||||
|
||||
mock_redis = AsyncMock()
|
||||
mock_redis.set = AsyncMock(side_effect=RedisError("Unknown error"))
|
||||
|
||||
with patch.object(client, "_get_client", return_value=mock_redis):
|
||||
result = await client.cache_set("test-key", "test-value")
|
||||
|
||||
assert result is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cache_get_success(self):
|
||||
"""Test getting a cached value."""
|
||||
client = RedisClient(url="redis://localhost:6379/0")
|
||||
|
||||
mock_redis = AsyncMock()
|
||||
mock_redis.get = AsyncMock(return_value="cached-value")
|
||||
|
||||
with patch.object(client, "_get_client", return_value=mock_redis):
|
||||
result = await client.cache_get("test-key")
|
||||
|
||||
assert result == "cached-value"
|
||||
mock_redis.get.assert_called_once_with("test-key")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cache_get_miss(self):
|
||||
"""Test cache miss returns None."""
|
||||
client = RedisClient(url="redis://localhost:6379/0")
|
||||
|
||||
mock_redis = AsyncMock()
|
||||
mock_redis.get = AsyncMock(return_value=None)
|
||||
|
||||
with patch.object(client, "_get_client", return_value=mock_redis):
|
||||
result = await client.cache_get("nonexistent-key")
|
||||
|
||||
assert result is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cache_get_connection_error(self):
|
||||
"""Test cache_get handles connection errors."""
|
||||
client = RedisClient(url="redis://localhost:6379/0")
|
||||
|
||||
mock_redis = AsyncMock()
|
||||
mock_redis.get = AsyncMock(side_effect=ConnectionError("Connection refused"))
|
||||
|
||||
with patch.object(client, "_get_client", return_value=mock_redis):
|
||||
result = await client.cache_get("test-key")
|
||||
|
||||
assert result is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cache_delete_success(self):
|
||||
"""Test deleting a cache key."""
|
||||
client = RedisClient(url="redis://localhost:6379/0")
|
||||
|
||||
mock_redis = AsyncMock()
|
||||
mock_redis.delete = AsyncMock(return_value=1)
|
||||
|
||||
with patch.object(client, "_get_client", return_value=mock_redis):
|
||||
result = await client.cache_delete("test-key")
|
||||
|
||||
assert result is True
|
||||
mock_redis.delete.assert_called_once_with("test-key")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cache_delete_nonexistent_key(self):
|
||||
"""Test deleting a nonexistent key returns False."""
|
||||
client = RedisClient(url="redis://localhost:6379/0")
|
||||
|
||||
mock_redis = AsyncMock()
|
||||
mock_redis.delete = AsyncMock(return_value=0)
|
||||
|
||||
with patch.object(client, "_get_client", return_value=mock_redis):
|
||||
result = await client.cache_delete("nonexistent-key")
|
||||
|
||||
assert result is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cache_delete_connection_error(self):
|
||||
"""Test cache_delete handles connection errors."""
|
||||
client = RedisClient(url="redis://localhost:6379/0")
|
||||
|
||||
mock_redis = AsyncMock()
|
||||
mock_redis.delete = AsyncMock(side_effect=ConnectionError("Connection refused"))
|
||||
|
||||
with patch.object(client, "_get_client", return_value=mock_redis):
|
||||
result = await client.cache_delete("test-key")
|
||||
|
||||
assert result is False
|
||||
|
||||
|
||||
class TestCacheDeletePattern:
|
||||
"""Test cache_delete_pattern operation."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cache_delete_pattern_success(self):
|
||||
"""Test deleting keys by pattern."""
|
||||
client = RedisClient(url="redis://localhost:6379/0")
|
||||
|
||||
mock_redis = AsyncMock()
|
||||
mock_redis.delete = AsyncMock(return_value=1)
|
||||
|
||||
# Create async iterator for scan_iter
|
||||
async def mock_scan_iter(pattern):
|
||||
for key in ["user:1", "user:2", "user:3"]:
|
||||
yield key
|
||||
|
||||
mock_redis.scan_iter = mock_scan_iter
|
||||
|
||||
with patch.object(client, "_get_client", return_value=mock_redis):
|
||||
result = await client.cache_delete_pattern("user:*")
|
||||
|
||||
assert result == 3
|
||||
assert mock_redis.delete.call_count == 3
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cache_delete_pattern_no_matches(self):
|
||||
"""Test deleting pattern with no matches."""
|
||||
client = RedisClient(url="redis://localhost:6379/0")
|
||||
|
||||
mock_redis = AsyncMock()
|
||||
|
||||
async def mock_scan_iter(pattern):
|
||||
if False: # Empty iterator
|
||||
yield
|
||||
|
||||
mock_redis.scan_iter = mock_scan_iter
|
||||
|
||||
with patch.object(client, "_get_client", return_value=mock_redis):
|
||||
result = await client.cache_delete_pattern("nonexistent:*")
|
||||
|
||||
assert result == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cache_delete_pattern_error(self):
|
||||
"""Test cache_delete_pattern handles errors."""
|
||||
client = RedisClient(url="redis://localhost:6379/0")
|
||||
|
||||
mock_redis = AsyncMock()
|
||||
|
||||
async def mock_scan_iter(pattern):
|
||||
raise ConnectionError("Connection lost")
|
||||
if False: # Make it a generator
|
||||
yield
|
||||
|
||||
mock_redis.scan_iter = mock_scan_iter
|
||||
|
||||
with patch.object(client, "_get_client", return_value=mock_redis):
|
||||
result = await client.cache_delete_pattern("user:*")
|
||||
|
||||
assert result == 0
|
||||
|
||||
|
||||
class TestCacheExpire:
|
||||
"""Test cache_expire operation."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cache_expire_success(self):
|
||||
"""Test setting TTL on existing key."""
|
||||
client = RedisClient(url="redis://localhost:6379/0")
|
||||
|
||||
mock_redis = AsyncMock()
|
||||
mock_redis.expire = AsyncMock(return_value=True)
|
||||
|
||||
with patch.object(client, "_get_client", return_value=mock_redis):
|
||||
result = await client.cache_expire("test-key", 120)
|
||||
|
||||
assert result is True
|
||||
mock_redis.expire.assert_called_once_with("test-key", 120)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cache_expire_nonexistent_key(self):
|
||||
"""Test setting TTL on nonexistent key."""
|
||||
client = RedisClient(url="redis://localhost:6379/0")
|
||||
|
||||
mock_redis = AsyncMock()
|
||||
mock_redis.expire = AsyncMock(return_value=False)
|
||||
|
||||
with patch.object(client, "_get_client", return_value=mock_redis):
|
||||
result = await client.cache_expire("nonexistent-key", 120)
|
||||
|
||||
assert result is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cache_expire_error(self):
|
||||
"""Test cache_expire handles errors."""
|
||||
client = RedisClient(url="redis://localhost:6379/0")
|
||||
|
||||
mock_redis = AsyncMock()
|
||||
mock_redis.expire = AsyncMock(side_effect=ConnectionError("Error"))
|
||||
|
||||
with patch.object(client, "_get_client", return_value=mock_redis):
|
||||
result = await client.cache_expire("test-key", 120)
|
||||
|
||||
assert result is False
|
||||
|
||||
|
||||
class TestCacheHelpers:
|
||||
"""Test cache helper methods (exists, ttl)."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cache_exists_true(self):
|
||||
"""Test cache_exists returns True for existing key."""
|
||||
client = RedisClient(url="redis://localhost:6379/0")
|
||||
|
||||
mock_redis = AsyncMock()
|
||||
mock_redis.exists = AsyncMock(return_value=1)
|
||||
|
||||
with patch.object(client, "_get_client", return_value=mock_redis):
|
||||
result = await client.cache_exists("test-key")
|
||||
|
||||
assert result is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cache_exists_false(self):
|
||||
"""Test cache_exists returns False for nonexistent key."""
|
||||
client = RedisClient(url="redis://localhost:6379/0")
|
||||
|
||||
mock_redis = AsyncMock()
|
||||
mock_redis.exists = AsyncMock(return_value=0)
|
||||
|
||||
with patch.object(client, "_get_client", return_value=mock_redis):
|
||||
result = await client.cache_exists("nonexistent-key")
|
||||
|
||||
assert result is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cache_exists_error(self):
|
||||
"""Test cache_exists handles errors."""
|
||||
client = RedisClient(url="redis://localhost:6379/0")
|
||||
|
||||
mock_redis = AsyncMock()
|
||||
mock_redis.exists = AsyncMock(side_effect=ConnectionError("Error"))
|
||||
|
||||
with patch.object(client, "_get_client", return_value=mock_redis):
|
||||
result = await client.cache_exists("test-key")
|
||||
|
||||
assert result is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cache_ttl_with_ttl(self):
|
||||
"""Test cache_ttl returns remaining TTL."""
|
||||
client = RedisClient(url="redis://localhost:6379/0")
|
||||
|
||||
mock_redis = AsyncMock()
|
||||
mock_redis.ttl = AsyncMock(return_value=300)
|
||||
|
||||
with patch.object(client, "_get_client", return_value=mock_redis):
|
||||
result = await client.cache_ttl("test-key")
|
||||
|
||||
assert result == 300
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cache_ttl_no_ttl(self):
|
||||
"""Test cache_ttl returns -1 for key without TTL."""
|
||||
client = RedisClient(url="redis://localhost:6379/0")
|
||||
|
||||
mock_redis = AsyncMock()
|
||||
mock_redis.ttl = AsyncMock(return_value=-1)
|
||||
|
||||
with patch.object(client, "_get_client", return_value=mock_redis):
|
||||
result = await client.cache_ttl("test-key")
|
||||
|
||||
assert result == -1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cache_ttl_nonexistent_key(self):
|
||||
"""Test cache_ttl returns -2 for nonexistent key."""
|
||||
client = RedisClient(url="redis://localhost:6379/0")
|
||||
|
||||
mock_redis = AsyncMock()
|
||||
mock_redis.ttl = AsyncMock(return_value=-2)
|
||||
|
||||
with patch.object(client, "_get_client", return_value=mock_redis):
|
||||
result = await client.cache_ttl("nonexistent-key")
|
||||
|
||||
assert result == -2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cache_ttl_error(self):
|
||||
"""Test cache_ttl handles errors."""
|
||||
client = RedisClient(url="redis://localhost:6379/0")
|
||||
|
||||
mock_redis = AsyncMock()
|
||||
mock_redis.ttl = AsyncMock(side_effect=ConnectionError("Error"))
|
||||
|
||||
with patch.object(client, "_get_client", return_value=mock_redis):
|
||||
result = await client.cache_ttl("test-key")
|
||||
|
||||
assert result == -2
|
||||
|
||||
|
||||
class TestJsonOperations:
|
||||
"""Test JSON serialization cache operations."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cache_set_json_success(self):
|
||||
"""Test setting a JSON value in cache."""
|
||||
client = RedisClient(url="redis://localhost:6379/0")
|
||||
|
||||
mock_redis = AsyncMock()
|
||||
mock_redis.set = AsyncMock(return_value=True)
|
||||
|
||||
with patch.object(client, "_get_client", return_value=mock_redis):
|
||||
data = {"user": "test", "count": 42}
|
||||
result = await client.cache_set_json("test-key", data, ttl=60)
|
||||
|
||||
assert result is True
|
||||
mock_redis.set.assert_called_once()
|
||||
# Verify JSON was serialized
|
||||
call_args = mock_redis.set.call_args
|
||||
assert call_args[0][1] == json.dumps(data)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cache_set_json_serialization_error(self):
|
||||
"""Test cache_set_json handles serialization errors."""
|
||||
client = RedisClient(url="redis://localhost:6379/0")
|
||||
|
||||
# Object that can't be serialized
|
||||
class NonSerializable:
|
||||
pass
|
||||
|
||||
result = await client.cache_set_json("test-key", NonSerializable())
|
||||
assert result is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cache_get_json_success(self):
|
||||
"""Test getting a JSON value from cache."""
|
||||
client = RedisClient(url="redis://localhost:6379/0")
|
||||
|
||||
mock_redis = AsyncMock()
|
||||
data = {"user": "test", "count": 42}
|
||||
mock_redis.get = AsyncMock(return_value=json.dumps(data))
|
||||
|
||||
with patch.object(client, "_get_client", return_value=mock_redis):
|
||||
result = await client.cache_get_json("test-key")
|
||||
|
||||
assert result == data
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cache_get_json_miss(self):
|
||||
"""Test cache_get_json returns None on cache miss."""
|
||||
client = RedisClient(url="redis://localhost:6379/0")
|
||||
|
||||
mock_redis = AsyncMock()
|
||||
mock_redis.get = AsyncMock(return_value=None)
|
||||
|
||||
with patch.object(client, "_get_client", return_value=mock_redis):
|
||||
result = await client.cache_get_json("nonexistent-key")
|
||||
|
||||
assert result is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cache_get_json_invalid_json(self):
|
||||
"""Test cache_get_json handles invalid JSON."""
|
||||
client = RedisClient(url="redis://localhost:6379/0")
|
||||
|
||||
mock_redis = AsyncMock()
|
||||
mock_redis.get = AsyncMock(return_value="not valid json {{{")
|
||||
|
||||
with patch.object(client, "_get_client", return_value=mock_redis):
|
||||
result = await client.cache_get_json("test-key")
|
||||
|
||||
assert result is None
|
||||
|
||||
|
||||
class TestPubSubOperations:
|
||||
"""Test pub/sub operations."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_publish_string_message(self):
|
||||
"""Test publishing a string message."""
|
||||
client = RedisClient(url="redis://localhost:6379/0")
|
||||
|
||||
mock_redis = AsyncMock()
|
||||
mock_redis.publish = AsyncMock(return_value=2)
|
||||
|
||||
with patch.object(client, "_get_client", return_value=mock_redis):
|
||||
result = await client.publish("test-channel", "hello world")
|
||||
|
||||
assert result == 2
|
||||
mock_redis.publish.assert_called_once_with("test-channel", "hello world")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_publish_dict_message(self):
|
||||
"""Test publishing a dict message (JSON serialized)."""
|
||||
client = RedisClient(url="redis://localhost:6379/0")
|
||||
|
||||
mock_redis = AsyncMock()
|
||||
mock_redis.publish = AsyncMock(return_value=1)
|
||||
|
||||
with patch.object(client, "_get_client", return_value=mock_redis):
|
||||
data = {"event": "user_created", "user_id": 123}
|
||||
result = await client.publish("events", data)
|
||||
|
||||
assert result == 1
|
||||
mock_redis.publish.assert_called_once_with("events", json.dumps(data))
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_publish_connection_error(self):
|
||||
"""Test publish handles connection errors."""
|
||||
client = RedisClient(url="redis://localhost:6379/0")
|
||||
|
||||
mock_redis = AsyncMock()
|
||||
mock_redis.publish = AsyncMock(side_effect=ConnectionError("Connection lost"))
|
||||
|
||||
with patch.object(client, "_get_client", return_value=mock_redis):
|
||||
result = await client.publish("test-channel", "hello")
|
||||
|
||||
assert result == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_subscribe_context_manager(self):
|
||||
"""Test subscribe context manager."""
|
||||
client = RedisClient(url="redis://localhost:6379/0")
|
||||
|
||||
mock_pubsub = AsyncMock()
|
||||
mock_pubsub.subscribe = AsyncMock()
|
||||
mock_pubsub.unsubscribe = AsyncMock()
|
||||
mock_pubsub.close = AsyncMock()
|
||||
|
||||
mock_redis = AsyncMock()
|
||||
mock_redis.pubsub = MagicMock(return_value=mock_pubsub)
|
||||
|
||||
with patch.object(client, "_get_client", return_value=mock_redis):
|
||||
async with client.subscribe("channel1", "channel2") as pubsub:
|
||||
assert pubsub is mock_pubsub
|
||||
mock_pubsub.subscribe.assert_called_once_with("channel1", "channel2")
|
||||
|
||||
# After exiting context, should unsubscribe and close
|
||||
mock_pubsub.unsubscribe.assert_called_once_with("channel1", "channel2")
|
||||
mock_pubsub.close.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_psubscribe_context_manager(self):
|
||||
"""Test pattern subscribe context manager."""
|
||||
client = RedisClient(url="redis://localhost:6379/0")
|
||||
|
||||
mock_pubsub = AsyncMock()
|
||||
mock_pubsub.psubscribe = AsyncMock()
|
||||
mock_pubsub.punsubscribe = AsyncMock()
|
||||
mock_pubsub.close = AsyncMock()
|
||||
|
||||
mock_redis = AsyncMock()
|
||||
mock_redis.pubsub = MagicMock(return_value=mock_pubsub)
|
||||
|
||||
with patch.object(client, "_get_client", return_value=mock_redis):
|
||||
async with client.psubscribe("user:*", "event:*") as pubsub:
|
||||
assert pubsub is mock_pubsub
|
||||
mock_pubsub.psubscribe.assert_called_once_with("user:*", "event:*")
|
||||
|
||||
mock_pubsub.punsubscribe.assert_called_once_with("user:*", "event:*")
|
||||
mock_pubsub.close.assert_called_once()
|
||||
|
||||
|
||||
class TestHealthCheck:
|
||||
"""Test health check functionality."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_health_check_success(self):
|
||||
"""Test health check returns True when Redis is healthy."""
|
||||
client = RedisClient(url="redis://localhost:6379/0")
|
||||
|
||||
mock_redis = AsyncMock()
|
||||
mock_redis.ping = AsyncMock(return_value=True)
|
||||
|
||||
with patch.object(client, "_get_client", return_value=mock_redis):
|
||||
result = await client.health_check()
|
||||
|
||||
assert result is True
|
||||
mock_redis.ping.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_health_check_connection_error(self):
|
||||
"""Test health check returns False on connection error."""
|
||||
client = RedisClient(url="redis://localhost:6379/0")
|
||||
|
||||
mock_redis = AsyncMock()
|
||||
mock_redis.ping = AsyncMock(side_effect=ConnectionError("Connection refused"))
|
||||
|
||||
with patch.object(client, "_get_client", return_value=mock_redis):
|
||||
result = await client.health_check()
|
||||
|
||||
assert result is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_health_check_timeout_error(self):
|
||||
"""Test health check returns False on timeout."""
|
||||
client = RedisClient(url="redis://localhost:6379/0")
|
||||
|
||||
mock_redis = AsyncMock()
|
||||
mock_redis.ping = AsyncMock(side_effect=TimeoutError("Timeout"))
|
||||
|
||||
with patch.object(client, "_get_client", return_value=mock_redis):
|
||||
result = await client.health_check()
|
||||
|
||||
assert result is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_health_check_redis_error(self):
|
||||
"""Test health check returns False on Redis error."""
|
||||
client = RedisClient(url="redis://localhost:6379/0")
|
||||
|
||||
mock_redis = AsyncMock()
|
||||
mock_redis.ping = AsyncMock(side_effect=RedisError("Unknown error"))
|
||||
|
||||
with patch.object(client, "_get_client", return_value=mock_redis):
|
||||
result = await client.health_check()
|
||||
|
||||
assert result is False
|
||||
|
||||
|
||||
class TestConnectionPooling:
|
||||
"""Test connection pooling functionality."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_pool_initialization(self):
|
||||
"""Test that pool is lazily initialized."""
|
||||
client = RedisClient(url="redis://localhost:6379/0")
|
||||
|
||||
assert client._pool is None
|
||||
|
||||
with patch("app.core.redis.ConnectionPool") as MockPool:
|
||||
mock_pool = MagicMock()
|
||||
MockPool.from_url = MagicMock(return_value=mock_pool)
|
||||
|
||||
pool = await client._ensure_pool()
|
||||
|
||||
assert pool is mock_pool
|
||||
MockPool.from_url.assert_called_once_with(
|
||||
"redis://localhost:6379/0",
|
||||
max_connections=POOL_MAX_CONNECTIONS,
|
||||
socket_timeout=10,
|
||||
socket_connect_timeout=10,
|
||||
decode_responses=True,
|
||||
health_check_interval=30,
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_pool_reuses_existing(self):
|
||||
"""Test that pool is reused after initialization."""
|
||||
client = RedisClient(url="redis://localhost:6379/0")
|
||||
|
||||
mock_pool = MagicMock()
|
||||
client._pool = mock_pool
|
||||
|
||||
pool = await client._ensure_pool()
|
||||
|
||||
assert pool is mock_pool
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_close_disposes_resources(self):
|
||||
"""Test that close() disposes pool and client."""
|
||||
client = RedisClient(url="redis://localhost:6379/0")
|
||||
|
||||
mock_client = AsyncMock()
|
||||
mock_pool = AsyncMock()
|
||||
mock_pool.disconnect = AsyncMock()
|
||||
|
||||
client._client = mock_client
|
||||
client._pool = mock_pool
|
||||
|
||||
await client.close()
|
||||
|
||||
mock_client.close.assert_called_once()
|
||||
mock_pool.disconnect.assert_called_once()
|
||||
assert client._client is None
|
||||
assert client._pool is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_close_handles_none(self):
|
||||
"""Test that close() handles None client and pool gracefully."""
|
||||
client = RedisClient(url="redis://localhost:6379/0")
|
||||
|
||||
# Should not raise
|
||||
await client.close()
|
||||
|
||||
assert client._client is None
|
||||
assert client._pool is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_pool_info_not_initialized(self):
|
||||
"""Test pool info when not initialized."""
|
||||
client = RedisClient(url="redis://localhost:6379/0")
|
||||
|
||||
info = await client.get_pool_info()
|
||||
|
||||
assert info == {"status": "not_initialized"}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_pool_info_active(self):
|
||||
"""Test pool info when active."""
|
||||
client = RedisClient(url="redis://user:pass@localhost:6379/0")
|
||||
|
||||
mock_pool = MagicMock()
|
||||
client._pool = mock_pool
|
||||
|
||||
info = await client.get_pool_info()
|
||||
|
||||
assert info["status"] == "active"
|
||||
assert info["max_connections"] == POOL_MAX_CONNECTIONS
|
||||
# Password should be hidden
|
||||
assert "pass" not in info["url"]
|
||||
assert "localhost:6379/0" in info["url"]
|
||||
|
||||
|
||||
class TestModuleLevelFunctions:
|
||||
"""Test module-level convenience functions."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_redis_dependency(self):
|
||||
"""Test get_redis FastAPI dependency."""
|
||||
redis_gen = get_redis()
|
||||
|
||||
client = await redis_gen.__anext__()
|
||||
assert client is redis_client
|
||||
|
||||
# Cleanup
|
||||
with pytest.raises(StopAsyncIteration):
|
||||
await redis_gen.__anext__()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_check_redis_health(self):
|
||||
"""Test module-level check_redis_health function."""
|
||||
with patch.object(redis_client, "health_check", return_value=True) as mock:
|
||||
result = await check_redis_health()
|
||||
|
||||
assert result is True
|
||||
mock.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_close_redis(self):
|
||||
"""Test module-level close_redis function."""
|
||||
with patch.object(redis_client, "close") as mock:
|
||||
await close_redis()
|
||||
|
||||
mock.assert_called_once()
|
||||
|
||||
|
||||
class TestThreadSafety:
|
||||
"""Test thread-safety of pool initialization."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_concurrent_pool_initialization(self):
|
||||
"""Test that concurrent _ensure_pool calls create only one pool."""
|
||||
client = RedisClient(url="redis://localhost:6379/0")
|
||||
|
||||
call_count = 0
|
||||
mock_pool = MagicMock()
|
||||
|
||||
def counting_from_url(*args, **kwargs):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
return mock_pool
|
||||
|
||||
with patch("app.core.redis.ConnectionPool") as MockPool:
|
||||
MockPool.from_url = MagicMock(side_effect=counting_from_url)
|
||||
|
||||
# Start multiple concurrent _ensure_pool calls
|
||||
results = await asyncio.gather(
|
||||
client._ensure_pool(),
|
||||
client._ensure_pool(),
|
||||
client._ensure_pool(),
|
||||
)
|
||||
|
||||
# All results should be the same pool instance
|
||||
assert results[0] is results[1] is results[2]
|
||||
assert results[0] is mock_pool
|
||||
# Pool should only be created once despite concurrent calls
|
||||
assert call_count == 1
|
||||
2
backend/tests/crud/syndarix/__init__.py
Normal file
2
backend/tests/crud/syndarix/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
||||
# tests/crud/syndarix/__init__.py
|
||||
"""Syndarix CRUD operation tests."""
|
||||
218
backend/tests/crud/syndarix/conftest.py
Normal file
218
backend/tests/crud/syndarix/conftest.py
Normal file
@@ -0,0 +1,218 @@
|
||||
# tests/crud/syndarix/conftest.py
|
||||
"""
|
||||
Shared fixtures for Syndarix CRUD tests.
|
||||
"""
|
||||
|
||||
import uuid
|
||||
from datetime import date, timedelta
|
||||
from decimal import Decimal
|
||||
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
|
||||
from app.models.syndarix import (
|
||||
AgentInstance,
|
||||
AgentStatus,
|
||||
AgentType,
|
||||
AutonomyLevel,
|
||||
Issue,
|
||||
IssuePriority,
|
||||
IssueStatus,
|
||||
Project,
|
||||
ProjectStatus,
|
||||
Sprint,
|
||||
SprintStatus,
|
||||
SyncStatus,
|
||||
)
|
||||
from app.models.user import User
|
||||
from app.schemas.syndarix import (
|
||||
AgentInstanceCreate,
|
||||
AgentTypeCreate,
|
||||
IssueCreate,
|
||||
ProjectCreate,
|
||||
SprintCreate,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def project_create_data():
|
||||
"""Return data for creating a project via schema."""
|
||||
return ProjectCreate(
|
||||
name="Test Project",
|
||||
slug="test-project-crud",
|
||||
description="A test project for CRUD testing",
|
||||
autonomy_level=AutonomyLevel.MILESTONE,
|
||||
status=ProjectStatus.ACTIVE,
|
||||
settings={"mcp_servers": ["gitea"]},
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def agent_type_create_data():
|
||||
"""Return data for creating an agent type via schema."""
|
||||
return AgentTypeCreate(
|
||||
name="Backend Engineer",
|
||||
slug="backend-engineer-crud",
|
||||
description="Specialized in backend development",
|
||||
expertise=["python", "fastapi", "postgresql"],
|
||||
personality_prompt="You are an expert backend engineer with deep knowledge of Python and FastAPI.",
|
||||
primary_model="claude-opus-4-5-20251101",
|
||||
fallback_models=["claude-sonnet-4-20250514"],
|
||||
model_params={"temperature": 0.7, "max_tokens": 4096},
|
||||
mcp_servers=["gitea", "file-system"],
|
||||
tool_permissions={"allowed": ["*"], "denied": []},
|
||||
is_active=True,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sprint_create_data():
|
||||
"""Return data for creating a sprint via schema."""
|
||||
today = date.today()
|
||||
return {
|
||||
"name": "Sprint 1",
|
||||
"number": 1,
|
||||
"goal": "Complete initial setup and core features",
|
||||
"start_date": today,
|
||||
"end_date": today + timedelta(days=14),
|
||||
"status": SprintStatus.PLANNED,
|
||||
"planned_points": 21,
|
||||
"completed_points": 0,
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def issue_create_data():
|
||||
"""Return data for creating an issue via schema."""
|
||||
return {
|
||||
"title": "Implement user authentication",
|
||||
"body": "As a user, I want to log in securely so that I can access my account.",
|
||||
"status": IssueStatus.OPEN,
|
||||
"priority": IssuePriority.HIGH,
|
||||
"labels": ["backend", "security"],
|
||||
"story_points": 5,
|
||||
}
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def test_owner_crud(async_test_db):
|
||||
"""Create a test user to be used as project owner in CRUD tests."""
|
||||
from app.core.auth import get_password_hash
|
||||
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
user = User(
|
||||
id=uuid.uuid4(),
|
||||
email="crud-owner@example.com",
|
||||
password_hash=get_password_hash("TestPassword123!"),
|
||||
first_name="CRUD",
|
||||
last_name="Owner",
|
||||
is_active=True,
|
||||
is_superuser=False,
|
||||
)
|
||||
session.add(user)
|
||||
await session.commit()
|
||||
await session.refresh(user)
|
||||
return user
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def test_project_crud(async_test_db, test_owner_crud, project_create_data):
|
||||
"""Create a test project in the database for CRUD tests."""
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
project = Project(
|
||||
id=uuid.uuid4(),
|
||||
name=project_create_data.name,
|
||||
slug=project_create_data.slug,
|
||||
description=project_create_data.description,
|
||||
autonomy_level=project_create_data.autonomy_level,
|
||||
status=project_create_data.status,
|
||||
settings=project_create_data.settings,
|
||||
owner_id=test_owner_crud.id,
|
||||
)
|
||||
session.add(project)
|
||||
await session.commit()
|
||||
await session.refresh(project)
|
||||
return project
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def test_agent_type_crud(async_test_db, agent_type_create_data):
|
||||
"""Create a test agent type in the database for CRUD tests."""
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
agent_type = AgentType(
|
||||
id=uuid.uuid4(),
|
||||
name=agent_type_create_data.name,
|
||||
slug=agent_type_create_data.slug,
|
||||
description=agent_type_create_data.description,
|
||||
expertise=agent_type_create_data.expertise,
|
||||
personality_prompt=agent_type_create_data.personality_prompt,
|
||||
primary_model=agent_type_create_data.primary_model,
|
||||
fallback_models=agent_type_create_data.fallback_models,
|
||||
model_params=agent_type_create_data.model_params,
|
||||
mcp_servers=agent_type_create_data.mcp_servers,
|
||||
tool_permissions=agent_type_create_data.tool_permissions,
|
||||
is_active=agent_type_create_data.is_active,
|
||||
)
|
||||
session.add(agent_type)
|
||||
await session.commit()
|
||||
await session.refresh(agent_type)
|
||||
return agent_type
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def test_agent_instance_crud(async_test_db, test_project_crud, test_agent_type_crud):
|
||||
"""Create a test agent instance in the database for CRUD tests."""
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
agent_instance = AgentInstance(
|
||||
id=uuid.uuid4(),
|
||||
agent_type_id=test_agent_type_crud.id,
|
||||
project_id=test_project_crud.id,
|
||||
status=AgentStatus.IDLE,
|
||||
current_task=None,
|
||||
short_term_memory={},
|
||||
long_term_memory_ref=None,
|
||||
session_id=None,
|
||||
tasks_completed=0,
|
||||
tokens_used=0,
|
||||
cost_incurred=Decimal("0.0000"),
|
||||
)
|
||||
session.add(agent_instance)
|
||||
await session.commit()
|
||||
await session.refresh(agent_instance)
|
||||
return agent_instance
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def test_sprint_crud(async_test_db, test_project_crud, sprint_create_data):
|
||||
"""Create a test sprint in the database for CRUD tests."""
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
sprint = Sprint(
|
||||
id=uuid.uuid4(),
|
||||
project_id=test_project_crud.id,
|
||||
**sprint_create_data,
|
||||
)
|
||||
session.add(sprint)
|
||||
await session.commit()
|
||||
await session.refresh(sprint)
|
||||
return sprint
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def test_issue_crud(async_test_db, test_project_crud, issue_create_data):
|
||||
"""Create a test issue in the database for CRUD tests."""
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
issue = Issue(
|
||||
id=uuid.uuid4(),
|
||||
project_id=test_project_crud.id,
|
||||
**issue_create_data,
|
||||
)
|
||||
session.add(issue)
|
||||
await session.commit()
|
||||
await session.refresh(issue)
|
||||
return issue
|
||||
386
backend/tests/crud/syndarix/test_agent_instance_crud.py
Normal file
386
backend/tests/crud/syndarix/test_agent_instance_crud.py
Normal file
@@ -0,0 +1,386 @@
|
||||
# tests/crud/syndarix/test_agent_instance_crud.py
|
||||
"""
|
||||
Tests for AgentInstance CRUD operations.
|
||||
"""
|
||||
|
||||
import uuid
|
||||
from decimal import Decimal
|
||||
|
||||
import pytest
|
||||
|
||||
from app.crud.syndarix import agent_instance as agent_instance_crud
|
||||
from app.models.syndarix import AgentStatus
|
||||
from app.schemas.syndarix import AgentInstanceCreate, AgentInstanceUpdate
|
||||
|
||||
|
||||
class TestAgentInstanceCreate:
|
||||
"""Tests for agent instance creation."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_agent_instance_success(self, async_test_db, test_project_crud, test_agent_type_crud):
|
||||
"""Test successfully creating an agent instance."""
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
instance_data = AgentInstanceCreate(
|
||||
agent_type_id=test_agent_type_crud.id,
|
||||
project_id=test_project_crud.id,
|
||||
status=AgentStatus.IDLE,
|
||||
current_task=None,
|
||||
short_term_memory={"context": "initial"},
|
||||
long_term_memory_ref="project-123/agent-456",
|
||||
session_id="session-abc",
|
||||
)
|
||||
result = await agent_instance_crud.create(session, obj_in=instance_data)
|
||||
|
||||
assert result.id is not None
|
||||
assert result.agent_type_id == test_agent_type_crud.id
|
||||
assert result.project_id == test_project_crud.id
|
||||
assert result.status == AgentStatus.IDLE
|
||||
assert result.short_term_memory == {"context": "initial"}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_agent_instance_minimal(self, async_test_db, test_project_crud, test_agent_type_crud):
|
||||
"""Test creating agent instance with minimal fields."""
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
instance_data = AgentInstanceCreate(
|
||||
agent_type_id=test_agent_type_crud.id,
|
||||
project_id=test_project_crud.id,
|
||||
)
|
||||
result = await agent_instance_crud.create(session, obj_in=instance_data)
|
||||
|
||||
assert result.status == AgentStatus.IDLE # Default
|
||||
assert result.tasks_completed == 0
|
||||
assert result.tokens_used == 0
|
||||
|
||||
|
||||
class TestAgentInstanceRead:
|
||||
"""Tests for agent instance read operations."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_agent_instance_by_id(self, async_test_db, test_agent_instance_crud):
|
||||
"""Test getting agent instance by ID."""
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
result = await agent_instance_crud.get(session, id=str(test_agent_instance_crud.id))
|
||||
|
||||
assert result is not None
|
||||
assert result.id == test_agent_instance_crud.id
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_agent_instance_by_id_not_found(self, async_test_db):
|
||||
"""Test getting non-existent agent instance returns None."""
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
result = await agent_instance_crud.get(session, id=str(uuid.uuid4()))
|
||||
assert result is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_with_details(self, async_test_db, test_agent_instance_crud):
|
||||
"""Test getting agent instance with related details."""
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
result = await agent_instance_crud.get_with_details(
|
||||
session,
|
||||
instance_id=test_agent_instance_crud.id,
|
||||
)
|
||||
|
||||
assert result is not None
|
||||
assert result["instance"].id == test_agent_instance_crud.id
|
||||
assert result["agent_type_name"] is not None
|
||||
assert result["project_name"] is not None
|
||||
|
||||
|
||||
class TestAgentInstanceUpdate:
|
||||
"""Tests for agent instance update operations."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_agent_instance_status(self, async_test_db, test_agent_instance_crud):
|
||||
"""Test updating agent instance status."""
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
instance = await agent_instance_crud.get(session, id=str(test_agent_instance_crud.id))
|
||||
|
||||
update_data = AgentInstanceUpdate(
|
||||
status=AgentStatus.WORKING,
|
||||
current_task="Processing feature request",
|
||||
)
|
||||
result = await agent_instance_crud.update(session, db_obj=instance, obj_in=update_data)
|
||||
|
||||
assert result.status == AgentStatus.WORKING
|
||||
assert result.current_task == "Processing feature request"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_agent_instance_memory(self, async_test_db, test_agent_instance_crud):
|
||||
"""Test updating agent instance short-term memory."""
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
instance = await agent_instance_crud.get(session, id=str(test_agent_instance_crud.id))
|
||||
|
||||
new_memory = {"conversation": ["msg1", "msg2"], "decisions": {"key": "value"}}
|
||||
update_data = AgentInstanceUpdate(short_term_memory=new_memory)
|
||||
result = await agent_instance_crud.update(session, db_obj=instance, obj_in=update_data)
|
||||
|
||||
assert result.short_term_memory == new_memory
|
||||
|
||||
|
||||
class TestAgentInstanceStatusUpdate:
|
||||
"""Tests for agent instance status update method."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_status(self, async_test_db, test_agent_instance_crud):
|
||||
"""Test updating agent instance status via dedicated method."""
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
result = await agent_instance_crud.update_status(
|
||||
session,
|
||||
instance_id=test_agent_instance_crud.id,
|
||||
status=AgentStatus.WORKING,
|
||||
current_task="Working on feature X",
|
||||
)
|
||||
|
||||
assert result is not None
|
||||
assert result.status == AgentStatus.WORKING
|
||||
assert result.current_task == "Working on feature X"
|
||||
assert result.last_activity_at is not None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_status_nonexistent(self, async_test_db):
|
||||
"""Test updating status of non-existent instance returns None."""
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
result = await agent_instance_crud.update_status(
|
||||
session,
|
||||
instance_id=uuid.uuid4(),
|
||||
status=AgentStatus.WORKING,
|
||||
)
|
||||
assert result is None
|
||||
|
||||
|
||||
class TestAgentInstanceTerminate:
|
||||
"""Tests for agent instance termination."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_terminate_agent_instance(self, async_test_db, test_project_crud, test_agent_type_crud):
|
||||
"""Test terminating an agent instance."""
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
# Create an instance to terminate
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
instance_data = AgentInstanceCreate(
|
||||
agent_type_id=test_agent_type_crud.id,
|
||||
project_id=test_project_crud.id,
|
||||
status=AgentStatus.WORKING,
|
||||
)
|
||||
created = await agent_instance_crud.create(session, obj_in=instance_data)
|
||||
instance_id = created.id
|
||||
|
||||
# Terminate
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
result = await agent_instance_crud.terminate(session, instance_id=instance_id)
|
||||
|
||||
assert result is not None
|
||||
assert result.status == AgentStatus.TERMINATED
|
||||
assert result.terminated_at is not None
|
||||
assert result.current_task is None
|
||||
assert result.session_id is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_terminate_nonexistent_instance(self, async_test_db):
|
||||
"""Test terminating non-existent instance returns None."""
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
result = await agent_instance_crud.terminate(session, instance_id=uuid.uuid4())
|
||||
assert result is None
|
||||
|
||||
|
||||
class TestAgentInstanceMetrics:
|
||||
"""Tests for agent instance metrics operations."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_record_task_completion(self, async_test_db, test_agent_instance_crud):
|
||||
"""Test recording task completion with metrics."""
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
result = await agent_instance_crud.record_task_completion(
|
||||
session,
|
||||
instance_id=test_agent_instance_crud.id,
|
||||
tokens_used=1500,
|
||||
cost_incurred=Decimal("0.0150"),
|
||||
)
|
||||
|
||||
assert result is not None
|
||||
assert result.tasks_completed == 1
|
||||
assert result.tokens_used == 1500
|
||||
assert result.cost_incurred == Decimal("0.0150")
|
||||
assert result.last_activity_at is not None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_record_multiple_task_completions(self, async_test_db, test_project_crud, test_agent_type_crud):
|
||||
"""Test recording multiple task completions accumulates metrics."""
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
# Create fresh instance
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
instance_data = AgentInstanceCreate(
|
||||
agent_type_id=test_agent_type_crud.id,
|
||||
project_id=test_project_crud.id,
|
||||
)
|
||||
created = await agent_instance_crud.create(session, obj_in=instance_data)
|
||||
instance_id = created.id
|
||||
|
||||
# Record first task
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
await agent_instance_crud.record_task_completion(
|
||||
session,
|
||||
instance_id=instance_id,
|
||||
tokens_used=1000,
|
||||
cost_incurred=Decimal("0.0100"),
|
||||
)
|
||||
|
||||
# Record second task
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
result = await agent_instance_crud.record_task_completion(
|
||||
session,
|
||||
instance_id=instance_id,
|
||||
tokens_used=2000,
|
||||
cost_incurred=Decimal("0.0200"),
|
||||
)
|
||||
|
||||
assert result.tasks_completed == 2
|
||||
assert result.tokens_used == 3000
|
||||
assert result.cost_incurred == Decimal("0.0300")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_project_metrics(self, async_test_db, test_project_crud, test_agent_instance_crud):
|
||||
"""Test getting aggregated metrics for a project."""
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
result = await agent_instance_crud.get_project_metrics(
|
||||
session,
|
||||
project_id=test_project_crud.id,
|
||||
)
|
||||
|
||||
assert result is not None
|
||||
assert "total_instances" in result
|
||||
assert "active_instances" in result
|
||||
assert "idle_instances" in result
|
||||
assert "total_tasks_completed" in result
|
||||
assert "total_tokens_used" in result
|
||||
assert "total_cost_incurred" in result
|
||||
|
||||
|
||||
class TestAgentInstanceByProject:
|
||||
"""Tests for getting instances by project."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_by_project(self, async_test_db, test_project_crud, test_agent_instance_crud):
|
||||
"""Test getting instances by project."""
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
instances, total = await agent_instance_crud.get_by_project(
|
||||
session,
|
||||
project_id=test_project_crud.id,
|
||||
)
|
||||
|
||||
assert total >= 1
|
||||
assert all(i.project_id == test_project_crud.id for i in instances)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_by_project_with_status(self, async_test_db, test_project_crud, test_agent_type_crud):
|
||||
"""Test getting instances by project filtered by status."""
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
# Create instances with different statuses
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
idle_instance = AgentInstanceCreate(
|
||||
agent_type_id=test_agent_type_crud.id,
|
||||
project_id=test_project_crud.id,
|
||||
status=AgentStatus.IDLE,
|
||||
)
|
||||
await agent_instance_crud.create(session, obj_in=idle_instance)
|
||||
|
||||
working_instance = AgentInstanceCreate(
|
||||
agent_type_id=test_agent_type_crud.id,
|
||||
project_id=test_project_crud.id,
|
||||
status=AgentStatus.WORKING,
|
||||
)
|
||||
await agent_instance_crud.create(session, obj_in=working_instance)
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
instances, total = await agent_instance_crud.get_by_project(
|
||||
session,
|
||||
project_id=test_project_crud.id,
|
||||
status=AgentStatus.WORKING,
|
||||
)
|
||||
|
||||
assert all(i.status == AgentStatus.WORKING for i in instances)
|
||||
|
||||
|
||||
class TestAgentInstanceByAgentType:
|
||||
"""Tests for getting instances by agent type."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_by_agent_type(self, async_test_db, test_agent_type_crud, test_agent_instance_crud):
|
||||
"""Test getting instances by agent type."""
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
instances = await agent_instance_crud.get_by_agent_type(
|
||||
session,
|
||||
agent_type_id=test_agent_type_crud.id,
|
||||
)
|
||||
|
||||
assert len(instances) >= 1
|
||||
assert all(i.agent_type_id == test_agent_type_crud.id for i in instances)
|
||||
|
||||
|
||||
class TestBulkTerminate:
|
||||
"""Tests for bulk termination of instances."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_bulk_terminate_by_project(self, async_test_db, test_project_crud, test_agent_type_crud):
|
||||
"""Test bulk terminating all instances in a project."""
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
# Create multiple instances
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
for i in range(3):
|
||||
instance_data = AgentInstanceCreate(
|
||||
agent_type_id=test_agent_type_crud.id,
|
||||
project_id=test_project_crud.id,
|
||||
status=AgentStatus.WORKING if i < 2 else AgentStatus.IDLE,
|
||||
)
|
||||
await agent_instance_crud.create(session, obj_in=instance_data)
|
||||
|
||||
# Bulk terminate
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
count = await agent_instance_crud.bulk_terminate_by_project(
|
||||
session,
|
||||
project_id=test_project_crud.id,
|
||||
)
|
||||
|
||||
assert count >= 3
|
||||
|
||||
# Verify all are terminated
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
instances, _ = await agent_instance_crud.get_by_project(
|
||||
session,
|
||||
project_id=test_project_crud.id,
|
||||
)
|
||||
|
||||
for instance in instances:
|
||||
assert instance.status == AgentStatus.TERMINATED
|
||||
353
backend/tests/crud/syndarix/test_agent_type_crud.py
Normal file
353
backend/tests/crud/syndarix/test_agent_type_crud.py
Normal file
@@ -0,0 +1,353 @@
|
||||
# tests/crud/syndarix/test_agent_type_crud.py
|
||||
"""
|
||||
Tests for AgentType CRUD operations.
|
||||
"""
|
||||
|
||||
import uuid
|
||||
|
||||
import pytest
|
||||
|
||||
from app.crud.syndarix import agent_type as agent_type_crud
|
||||
from app.schemas.syndarix import AgentTypeCreate, AgentTypeUpdate
|
||||
|
||||
|
||||
class TestAgentTypeCreate:
|
||||
"""Tests for agent type creation."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_agent_type_success(self, async_test_db):
|
||||
"""Test successfully creating an agent type."""
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
agent_type_data = AgentTypeCreate(
|
||||
name="QA Engineer",
|
||||
slug="qa-engineer",
|
||||
description="Specialized in testing and quality assurance",
|
||||
expertise=["testing", "pytest", "playwright"],
|
||||
personality_prompt="You are an expert QA engineer...",
|
||||
primary_model="claude-opus-4-5-20251101",
|
||||
fallback_models=["claude-sonnet-4-20250514"],
|
||||
model_params={"temperature": 0.5},
|
||||
mcp_servers=["gitea"],
|
||||
tool_permissions={"allowed": ["*"]},
|
||||
is_active=True,
|
||||
)
|
||||
result = await agent_type_crud.create(session, obj_in=agent_type_data)
|
||||
|
||||
assert result.id is not None
|
||||
assert result.name == "QA Engineer"
|
||||
assert result.slug == "qa-engineer"
|
||||
assert result.expertise == ["testing", "pytest", "playwright"]
|
||||
assert result.is_active is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_agent_type_duplicate_slug_fails(self, async_test_db, test_agent_type_crud):
|
||||
"""Test creating agent type with duplicate slug raises ValueError."""
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
agent_type_data = AgentTypeCreate(
|
||||
name="Duplicate Agent",
|
||||
slug=test_agent_type_crud.slug, # Duplicate slug
|
||||
personality_prompt="Duplicate",
|
||||
primary_model="claude-opus-4-5-20251101",
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
await agent_type_crud.create(session, obj_in=agent_type_data)
|
||||
|
||||
assert "already exists" in str(exc_info.value).lower()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_agent_type_minimal_fields(self, async_test_db):
|
||||
"""Test creating agent type with minimal required fields."""
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
agent_type_data = AgentTypeCreate(
|
||||
name="Minimal Agent",
|
||||
slug="minimal-agent",
|
||||
personality_prompt="You are an assistant.",
|
||||
primary_model="claude-opus-4-5-20251101",
|
||||
)
|
||||
result = await agent_type_crud.create(session, obj_in=agent_type_data)
|
||||
|
||||
assert result.name == "Minimal Agent"
|
||||
assert result.expertise == [] # Default
|
||||
assert result.fallback_models == [] # Default
|
||||
assert result.is_active is True # Default
|
||||
|
||||
|
||||
class TestAgentTypeRead:
|
||||
"""Tests for agent type read operations."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_agent_type_by_id(self, async_test_db, test_agent_type_crud):
|
||||
"""Test getting agent type by ID."""
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
result = await agent_type_crud.get(session, id=str(test_agent_type_crud.id))
|
||||
|
||||
assert result is not None
|
||||
assert result.id == test_agent_type_crud.id
|
||||
assert result.name == test_agent_type_crud.name
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_agent_type_by_id_not_found(self, async_test_db):
|
||||
"""Test getting non-existent agent type returns None."""
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
result = await agent_type_crud.get(session, id=str(uuid.uuid4()))
|
||||
assert result is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_agent_type_by_slug(self, async_test_db, test_agent_type_crud):
|
||||
"""Test getting agent type by slug."""
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
result = await agent_type_crud.get_by_slug(session, slug=test_agent_type_crud.slug)
|
||||
|
||||
assert result is not None
|
||||
assert result.slug == test_agent_type_crud.slug
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_agent_type_by_slug_not_found(self, async_test_db):
|
||||
"""Test getting non-existent slug returns None."""
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
result = await agent_type_crud.get_by_slug(session, slug="non-existent-agent")
|
||||
assert result is None
|
||||
|
||||
|
||||
class TestAgentTypeUpdate:
|
||||
"""Tests for agent type update operations."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_agent_type_basic_fields(self, async_test_db, test_agent_type_crud):
|
||||
"""Test updating basic agent type fields."""
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
agent_type = await agent_type_crud.get(session, id=str(test_agent_type_crud.id))
|
||||
|
||||
update_data = AgentTypeUpdate(
|
||||
name="Updated Agent Name",
|
||||
description="Updated description",
|
||||
)
|
||||
result = await agent_type_crud.update(session, db_obj=agent_type, obj_in=update_data)
|
||||
|
||||
assert result.name == "Updated Agent Name"
|
||||
assert result.description == "Updated description"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_agent_type_expertise(self, async_test_db, test_agent_type_crud):
|
||||
"""Test updating agent type expertise."""
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
agent_type = await agent_type_crud.get(session, id=str(test_agent_type_crud.id))
|
||||
|
||||
update_data = AgentTypeUpdate(
|
||||
expertise=["new-skill", "another-skill"],
|
||||
)
|
||||
result = await agent_type_crud.update(session, db_obj=agent_type, obj_in=update_data)
|
||||
|
||||
assert "new-skill" in result.expertise
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_agent_type_model_params(self, async_test_db, test_agent_type_crud):
|
||||
"""Test updating agent type model parameters."""
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
agent_type = await agent_type_crud.get(session, id=str(test_agent_type_crud.id))
|
||||
|
||||
new_params = {"temperature": 0.9, "max_tokens": 8192}
|
||||
update_data = AgentTypeUpdate(model_params=new_params)
|
||||
result = await agent_type_crud.update(session, db_obj=agent_type, obj_in=update_data)
|
||||
|
||||
assert result.model_params == new_params
|
||||
|
||||
|
||||
class TestAgentTypeDelete:
|
||||
"""Tests for agent type delete operations."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_agent_type(self, async_test_db):
|
||||
"""Test deleting an agent type."""
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
# Create an agent type to delete
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
agent_type_data = AgentTypeCreate(
|
||||
name="Delete Me Agent",
|
||||
slug="delete-me-agent",
|
||||
personality_prompt="Delete test",
|
||||
primary_model="claude-opus-4-5-20251101",
|
||||
)
|
||||
created = await agent_type_crud.create(session, obj_in=agent_type_data)
|
||||
agent_type_id = created.id
|
||||
|
||||
# Delete the agent type
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
result = await agent_type_crud.remove(session, id=str(agent_type_id))
|
||||
assert result is not None
|
||||
|
||||
# Verify deletion
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
deleted = await agent_type_crud.get(session, id=str(agent_type_id))
|
||||
assert deleted is None
|
||||
|
||||
|
||||
class TestAgentTypeFilters:
|
||||
"""Tests for agent type filtering and search."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_multi_with_filters_active(self, async_test_db):
|
||||
"""Test filtering agent types by is_active."""
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
# Create active and inactive agent types
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
active_type = AgentTypeCreate(
|
||||
name="Active Agent Type",
|
||||
slug="active-agent-type-filter",
|
||||
personality_prompt="Active",
|
||||
primary_model="claude-opus-4-5-20251101",
|
||||
is_active=True,
|
||||
)
|
||||
await agent_type_crud.create(session, obj_in=active_type)
|
||||
|
||||
inactive_type = AgentTypeCreate(
|
||||
name="Inactive Agent Type",
|
||||
slug="inactive-agent-type-filter",
|
||||
personality_prompt="Inactive",
|
||||
primary_model="claude-opus-4-5-20251101",
|
||||
is_active=False,
|
||||
)
|
||||
await agent_type_crud.create(session, obj_in=inactive_type)
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
active_types, _ = await agent_type_crud.get_multi_with_filters(
|
||||
session,
|
||||
is_active=True,
|
||||
)
|
||||
|
||||
assert all(at.is_active for at in active_types)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_multi_with_filters_search(self, async_test_db):
|
||||
"""Test searching agent types by name."""
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
agent_type_data = AgentTypeCreate(
|
||||
name="Searchable Agent Type",
|
||||
slug="searchable-agent-type",
|
||||
description="This is searchable",
|
||||
personality_prompt="Searchable",
|
||||
primary_model="claude-opus-4-5-20251101",
|
||||
)
|
||||
await agent_type_crud.create(session, obj_in=agent_type_data)
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
agent_types, total = await agent_type_crud.get_multi_with_filters(
|
||||
session,
|
||||
search="Searchable",
|
||||
)
|
||||
|
||||
assert total >= 1
|
||||
assert any(at.name == "Searchable Agent Type" for at in agent_types)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_multi_with_filters_pagination(self, async_test_db):
|
||||
"""Test pagination of agent type results."""
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
for i in range(5):
|
||||
agent_type_data = AgentTypeCreate(
|
||||
name=f"Page Agent Type {i}",
|
||||
slug=f"page-agent-type-{i}",
|
||||
personality_prompt=f"Page {i}",
|
||||
primary_model="claude-opus-4-5-20251101",
|
||||
)
|
||||
await agent_type_crud.create(session, obj_in=agent_type_data)
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
page1, total = await agent_type_crud.get_multi_with_filters(
|
||||
session,
|
||||
skip=0,
|
||||
limit=2,
|
||||
)
|
||||
|
||||
assert len(page1) <= 2
|
||||
|
||||
|
||||
class TestAgentTypeSpecialMethods:
|
||||
"""Tests for special agent type CRUD methods."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_deactivate_agent_type(self, async_test_db):
|
||||
"""Test deactivating an agent type."""
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
# Create an active agent type
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
agent_type_data = AgentTypeCreate(
|
||||
name="Deactivate Me",
|
||||
slug="deactivate-me-agent",
|
||||
personality_prompt="Deactivate",
|
||||
primary_model="claude-opus-4-5-20251101",
|
||||
is_active=True,
|
||||
)
|
||||
created = await agent_type_crud.create(session, obj_in=agent_type_data)
|
||||
agent_type_id = created.id
|
||||
|
||||
# Deactivate
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
result = await agent_type_crud.deactivate(session, agent_type_id=agent_type_id)
|
||||
|
||||
assert result is not None
|
||||
assert result.is_active is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_deactivate_nonexistent_agent_type(self, async_test_db):
|
||||
"""Test deactivating non-existent agent type returns None."""
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
result = await agent_type_crud.deactivate(session, agent_type_id=uuid.uuid4())
|
||||
assert result is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_with_instance_count(self, async_test_db, test_agent_type_crud, test_agent_instance_crud):
|
||||
"""Test getting agent type with instance count."""
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
result = await agent_type_crud.get_with_instance_count(
|
||||
session,
|
||||
agent_type_id=test_agent_type_crud.id,
|
||||
)
|
||||
|
||||
assert result is not None
|
||||
assert result["agent_type"].id == test_agent_type_crud.id
|
||||
assert result["instance_count"] >= 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_with_instance_count_not_found(self, async_test_db):
|
||||
"""Test getting non-existent agent type with count returns None."""
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
result = await agent_type_crud.get_with_instance_count(
|
||||
session,
|
||||
agent_type_id=uuid.uuid4(),
|
||||
)
|
||||
assert result is None
|
||||
556
backend/tests/crud/syndarix/test_issue_crud.py
Normal file
556
backend/tests/crud/syndarix/test_issue_crud.py
Normal file
@@ -0,0 +1,556 @@
|
||||
# tests/crud/syndarix/test_issue_crud.py
|
||||
"""
|
||||
Tests for Issue CRUD operations.
|
||||
"""
|
||||
|
||||
import uuid
|
||||
from datetime import UTC, datetime
|
||||
|
||||
import pytest
|
||||
|
||||
from app.crud.syndarix import issue as issue_crud
|
||||
from app.models.syndarix import IssuePriority, IssueStatus, SyncStatus
|
||||
from app.schemas.syndarix import IssueCreate, IssueUpdate
|
||||
|
||||
|
||||
class TestIssueCreate:
|
||||
"""Tests for issue creation."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_issue_success(self, async_test_db, test_project_crud):
|
||||
"""Test successfully creating an issue."""
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
issue_data = IssueCreate(
|
||||
project_id=test_project_crud.id,
|
||||
title="Test Issue",
|
||||
body="This is a test issue body",
|
||||
status=IssueStatus.OPEN,
|
||||
priority=IssuePriority.HIGH,
|
||||
labels=["bug", "security"],
|
||||
story_points=5,
|
||||
)
|
||||
result = await issue_crud.create(session, obj_in=issue_data)
|
||||
|
||||
assert result.id is not None
|
||||
assert result.title == "Test Issue"
|
||||
assert result.body == "This is a test issue body"
|
||||
assert result.status == IssueStatus.OPEN
|
||||
assert result.priority == IssuePriority.HIGH
|
||||
assert result.labels == ["bug", "security"]
|
||||
assert result.story_points == 5
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_issue_with_external_tracker(self, async_test_db, test_project_crud):
|
||||
"""Test creating issue with external tracker info."""
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
issue_data = IssueCreate(
|
||||
project_id=test_project_crud.id,
|
||||
title="External Issue",
|
||||
external_tracker="gitea",
|
||||
external_id="gitea-123",
|
||||
external_url="https://gitea.example.com/issues/123",
|
||||
external_number=123,
|
||||
)
|
||||
result = await issue_crud.create(session, obj_in=issue_data)
|
||||
|
||||
assert result.external_tracker == "gitea"
|
||||
assert result.external_id == "gitea-123"
|
||||
assert result.external_number == 123
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_issue_minimal(self, async_test_db, test_project_crud):
|
||||
"""Test creating issue with minimal fields."""
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
issue_data = IssueCreate(
|
||||
project_id=test_project_crud.id,
|
||||
title="Minimal Issue",
|
||||
)
|
||||
result = await issue_crud.create(session, obj_in=issue_data)
|
||||
|
||||
assert result.title == "Minimal Issue"
|
||||
assert result.body == "" # Default
|
||||
assert result.status == IssueStatus.OPEN # Default
|
||||
assert result.priority == IssuePriority.MEDIUM # Default
|
||||
|
||||
|
||||
class TestIssueRead:
|
||||
"""Tests for issue read operations."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_issue_by_id(self, async_test_db, test_issue_crud):
|
||||
"""Test getting issue by ID."""
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
result = await issue_crud.get(session, id=str(test_issue_crud.id))
|
||||
|
||||
assert result is not None
|
||||
assert result.id == test_issue_crud.id
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_issue_by_id_not_found(self, async_test_db):
|
||||
"""Test getting non-existent issue returns None."""
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
result = await issue_crud.get(session, id=str(uuid.uuid4()))
|
||||
assert result is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_with_details(self, async_test_db, test_issue_crud):
|
||||
"""Test getting issue with related details."""
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
result = await issue_crud.get_with_details(
|
||||
session,
|
||||
issue_id=test_issue_crud.id,
|
||||
)
|
||||
|
||||
assert result is not None
|
||||
assert result["issue"].id == test_issue_crud.id
|
||||
assert result["project_name"] is not None
|
||||
|
||||
|
||||
class TestIssueUpdate:
|
||||
"""Tests for issue update operations."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_issue_basic_fields(self, async_test_db, test_issue_crud):
|
||||
"""Test updating basic issue fields."""
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
issue = await issue_crud.get(session, id=str(test_issue_crud.id))
|
||||
|
||||
update_data = IssueUpdate(
|
||||
title="Updated Title",
|
||||
body="Updated body content",
|
||||
)
|
||||
result = await issue_crud.update(session, db_obj=issue, obj_in=update_data)
|
||||
|
||||
assert result.title == "Updated Title"
|
||||
assert result.body == "Updated body content"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_issue_status(self, async_test_db, test_issue_crud):
|
||||
"""Test updating issue status."""
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
issue = await issue_crud.get(session, id=str(test_issue_crud.id))
|
||||
|
||||
update_data = IssueUpdate(status=IssueStatus.IN_PROGRESS)
|
||||
result = await issue_crud.update(session, db_obj=issue, obj_in=update_data)
|
||||
|
||||
assert result.status == IssueStatus.IN_PROGRESS
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_issue_priority(self, async_test_db, test_issue_crud):
|
||||
"""Test updating issue priority."""
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
issue = await issue_crud.get(session, id=str(test_issue_crud.id))
|
||||
|
||||
update_data = IssueUpdate(priority=IssuePriority.CRITICAL)
|
||||
result = await issue_crud.update(session, db_obj=issue, obj_in=update_data)
|
||||
|
||||
assert result.priority == IssuePriority.CRITICAL
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_issue_labels(self, async_test_db, test_issue_crud):
|
||||
"""Test updating issue labels."""
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
issue = await issue_crud.get(session, id=str(test_issue_crud.id))
|
||||
|
||||
update_data = IssueUpdate(labels=["new-label", "updated"])
|
||||
result = await issue_crud.update(session, db_obj=issue, obj_in=update_data)
|
||||
|
||||
assert "new-label" in result.labels
|
||||
|
||||
|
||||
class TestIssueAssignment:
|
||||
"""Tests for issue assignment operations."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_assign_to_agent(self, async_test_db, test_issue_crud, test_agent_instance_crud):
|
||||
"""Test assigning issue to an agent."""
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
result = await issue_crud.assign_to_agent(
|
||||
session,
|
||||
issue_id=test_issue_crud.id,
|
||||
agent_id=test_agent_instance_crud.id,
|
||||
)
|
||||
|
||||
assert result is not None
|
||||
assert result.assigned_agent_id == test_agent_instance_crud.id
|
||||
assert result.human_assignee is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_unassign_agent(self, async_test_db, test_issue_crud, test_agent_instance_crud):
|
||||
"""Test unassigning agent from issue."""
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
# First assign
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
await issue_crud.assign_to_agent(
|
||||
session,
|
||||
issue_id=test_issue_crud.id,
|
||||
agent_id=test_agent_instance_crud.id,
|
||||
)
|
||||
|
||||
# Then unassign
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
result = await issue_crud.assign_to_agent(
|
||||
session,
|
||||
issue_id=test_issue_crud.id,
|
||||
agent_id=None,
|
||||
)
|
||||
|
||||
assert result.assigned_agent_id is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_assign_to_human(self, async_test_db, test_issue_crud):
|
||||
"""Test assigning issue to a human."""
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
result = await issue_crud.assign_to_human(
|
||||
session,
|
||||
issue_id=test_issue_crud.id,
|
||||
human_assignee="developer@example.com",
|
||||
)
|
||||
|
||||
assert result is not None
|
||||
assert result.human_assignee == "developer@example.com"
|
||||
assert result.assigned_agent_id is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_assign_to_human_clears_agent(self, async_test_db, test_issue_crud, test_agent_instance_crud):
|
||||
"""Test assigning to human clears agent assignment."""
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
# First assign to agent
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
await issue_crud.assign_to_agent(
|
||||
session,
|
||||
issue_id=test_issue_crud.id,
|
||||
agent_id=test_agent_instance_crud.id,
|
||||
)
|
||||
|
||||
# Then assign to human
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
result = await issue_crud.assign_to_human(
|
||||
session,
|
||||
issue_id=test_issue_crud.id,
|
||||
human_assignee="developer@example.com",
|
||||
)
|
||||
|
||||
assert result.human_assignee == "developer@example.com"
|
||||
assert result.assigned_agent_id is None
|
||||
|
||||
|
||||
class TestIssueLifecycle:
|
||||
"""Tests for issue lifecycle operations."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_close_issue(self, async_test_db, test_issue_crud):
|
||||
"""Test closing an issue."""
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
result = await issue_crud.close_issue(session, issue_id=test_issue_crud.id)
|
||||
|
||||
assert result is not None
|
||||
assert result.status == IssueStatus.CLOSED
|
||||
assert result.closed_at is not None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_reopen_issue(self, async_test_db, test_project_crud):
|
||||
"""Test reopening a closed issue."""
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
# Create and close an issue
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
issue_data = IssueCreate(
|
||||
project_id=test_project_crud.id,
|
||||
title="Issue to Reopen",
|
||||
)
|
||||
created = await issue_crud.create(session, obj_in=issue_data)
|
||||
await issue_crud.close_issue(session, issue_id=created.id)
|
||||
issue_id = created.id
|
||||
|
||||
# Reopen
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
result = await issue_crud.reopen_issue(session, issue_id=issue_id)
|
||||
|
||||
assert result is not None
|
||||
assert result.status == IssueStatus.OPEN
|
||||
assert result.closed_at is None
|
||||
|
||||
|
||||
class TestIssueByProject:
|
||||
"""Tests for getting issues by project."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_by_project(self, async_test_db, test_project_crud, test_issue_crud):
|
||||
"""Test getting issues by project."""
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
issues, total = await issue_crud.get_by_project(
|
||||
session,
|
||||
project_id=test_project_crud.id,
|
||||
)
|
||||
|
||||
assert total >= 1
|
||||
assert all(i.project_id == test_project_crud.id for i in issues)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_by_project_with_status(self, async_test_db, test_project_crud):
|
||||
"""Test filtering issues by status."""
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
# Create issues with different statuses
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
open_issue = IssueCreate(
|
||||
project_id=test_project_crud.id,
|
||||
title="Open Issue Filter",
|
||||
status=IssueStatus.OPEN,
|
||||
)
|
||||
await issue_crud.create(session, obj_in=open_issue)
|
||||
|
||||
closed_issue = IssueCreate(
|
||||
project_id=test_project_crud.id,
|
||||
title="Closed Issue Filter",
|
||||
status=IssueStatus.CLOSED,
|
||||
)
|
||||
await issue_crud.create(session, obj_in=closed_issue)
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
issues, _ = await issue_crud.get_by_project(
|
||||
session,
|
||||
project_id=test_project_crud.id,
|
||||
status=IssueStatus.OPEN,
|
||||
)
|
||||
|
||||
assert all(i.status == IssueStatus.OPEN for i in issues)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_by_project_with_priority(self, async_test_db, test_project_crud):
|
||||
"""Test filtering issues by priority."""
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
high_issue = IssueCreate(
|
||||
project_id=test_project_crud.id,
|
||||
title="High Priority Issue",
|
||||
priority=IssuePriority.HIGH,
|
||||
)
|
||||
await issue_crud.create(session, obj_in=high_issue)
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
issues, _ = await issue_crud.get_by_project(
|
||||
session,
|
||||
project_id=test_project_crud.id,
|
||||
priority=IssuePriority.HIGH,
|
||||
)
|
||||
|
||||
assert all(i.priority == IssuePriority.HIGH for i in issues)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_by_project_with_search(self, async_test_db, test_project_crud):
|
||||
"""Test searching issues by title/body."""
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
searchable_issue = IssueCreate(
|
||||
project_id=test_project_crud.id,
|
||||
title="Searchable Unique Title",
|
||||
body="This body contains searchable content",
|
||||
)
|
||||
await issue_crud.create(session, obj_in=searchable_issue)
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
issues, total = await issue_crud.get_by_project(
|
||||
session,
|
||||
project_id=test_project_crud.id,
|
||||
search="Searchable Unique",
|
||||
)
|
||||
|
||||
assert total >= 1
|
||||
assert any(i.title == "Searchable Unique Title" for i in issues)
|
||||
|
||||
|
||||
class TestIssueBySprint:
|
||||
"""Tests for getting issues by sprint."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_by_sprint(self, async_test_db, test_project_crud, test_sprint_crud):
|
||||
"""Test getting issues by sprint."""
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
# Create issue in sprint
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
issue_data = IssueCreate(
|
||||
project_id=test_project_crud.id,
|
||||
title="Sprint Issue",
|
||||
sprint_id=test_sprint_crud.id,
|
||||
)
|
||||
await issue_crud.create(session, obj_in=issue_data)
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
issues = await issue_crud.get_by_sprint(
|
||||
session,
|
||||
sprint_id=test_sprint_crud.id,
|
||||
)
|
||||
|
||||
assert len(issues) >= 1
|
||||
assert all(i.sprint_id == test_sprint_crud.id for i in issues)
|
||||
|
||||
|
||||
class TestIssueSyncStatus:
|
||||
"""Tests for issue sync status operations."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_sync_status(self, async_test_db, test_project_crud):
|
||||
"""Test updating issue sync status."""
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
# Create issue with external tracker
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
issue_data = IssueCreate(
|
||||
project_id=test_project_crud.id,
|
||||
title="Sync Status Issue",
|
||||
external_tracker="gitea",
|
||||
external_id="gitea-456",
|
||||
)
|
||||
created = await issue_crud.create(session, obj_in=issue_data)
|
||||
issue_id = created.id
|
||||
|
||||
# Update sync status
|
||||
now = datetime.now(UTC)
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
result = await issue_crud.update_sync_status(
|
||||
session,
|
||||
issue_id=issue_id,
|
||||
sync_status=SyncStatus.PENDING,
|
||||
last_synced_at=now,
|
||||
)
|
||||
|
||||
assert result is not None
|
||||
assert result.sync_status == SyncStatus.PENDING
|
||||
assert result.last_synced_at is not None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_pending_sync(self, async_test_db, test_project_crud):
|
||||
"""Test getting issues pending sync."""
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
# Create issue with pending sync
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
issue_data = IssueCreate(
|
||||
project_id=test_project_crud.id,
|
||||
title="Pending Sync Issue",
|
||||
external_tracker="gitea",
|
||||
external_id="gitea-789",
|
||||
)
|
||||
created = await issue_crud.create(session, obj_in=issue_data)
|
||||
|
||||
# Set to pending
|
||||
await issue_crud.update_sync_status(
|
||||
session,
|
||||
issue_id=created.id,
|
||||
sync_status=SyncStatus.PENDING,
|
||||
)
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
issues = await issue_crud.get_pending_sync(session)
|
||||
|
||||
assert any(i.sync_status == SyncStatus.PENDING for i in issues)
|
||||
|
||||
|
||||
class TestIssueExternalTracker:
|
||||
"""Tests for external tracker operations."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_by_external_id(self, async_test_db, test_project_crud):
|
||||
"""Test getting issue by external tracker ID."""
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
# Create issue with external ID
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
issue_data = IssueCreate(
|
||||
project_id=test_project_crud.id,
|
||||
title="External ID Issue",
|
||||
external_tracker="github",
|
||||
external_id="github-unique-123",
|
||||
)
|
||||
await issue_crud.create(session, obj_in=issue_data)
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
result = await issue_crud.get_by_external_id(
|
||||
session,
|
||||
external_tracker="github",
|
||||
external_id="github-unique-123",
|
||||
)
|
||||
|
||||
assert result is not None
|
||||
assert result.external_id == "github-unique-123"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_by_external_id_not_found(self, async_test_db):
|
||||
"""Test getting non-existent external ID returns None."""
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
result = await issue_crud.get_by_external_id(
|
||||
session,
|
||||
external_tracker="gitea",
|
||||
external_id="non-existent",
|
||||
)
|
||||
assert result is None
|
||||
|
||||
|
||||
class TestIssueStats:
|
||||
"""Tests for issue statistics."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_project_stats(self, async_test_db, test_project_crud):
|
||||
"""Test getting issue statistics for a project."""
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
# Create issues with various statuses and priorities
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
for status in [IssueStatus.OPEN, IssueStatus.IN_PROGRESS, IssueStatus.CLOSED]:
|
||||
issue_data = IssueCreate(
|
||||
project_id=test_project_crud.id,
|
||||
title=f"Stats Issue {status.value}",
|
||||
status=status,
|
||||
story_points=3,
|
||||
)
|
||||
await issue_crud.create(session, obj_in=issue_data)
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
stats = await issue_crud.get_project_stats(
|
||||
session,
|
||||
project_id=test_project_crud.id,
|
||||
)
|
||||
|
||||
assert "total" in stats
|
||||
assert "open" in stats
|
||||
assert "in_progress" in stats
|
||||
assert "closed" in stats
|
||||
assert "by_priority" in stats
|
||||
assert "total_story_points" in stats
|
||||
409
backend/tests/crud/syndarix/test_project_crud.py
Normal file
409
backend/tests/crud/syndarix/test_project_crud.py
Normal file
@@ -0,0 +1,409 @@
|
||||
# tests/crud/syndarix/test_project_crud.py
|
||||
"""
|
||||
Tests for Project CRUD operations.
|
||||
"""
|
||||
|
||||
import uuid
|
||||
|
||||
import pytest
|
||||
|
||||
from app.crud.syndarix import project as project_crud
|
||||
from app.models.syndarix import AutonomyLevel, ProjectStatus
|
||||
from app.schemas.syndarix import ProjectCreate, ProjectUpdate
|
||||
|
||||
|
||||
class TestProjectCreate:
|
||||
"""Tests for project creation."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_project_success(self, async_test_db, test_owner_crud):
|
||||
"""Test successfully creating a project."""
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
project_data = ProjectCreate(
|
||||
name="New Project",
|
||||
slug="new-project",
|
||||
description="A brand new project",
|
||||
autonomy_level=AutonomyLevel.MILESTONE,
|
||||
status=ProjectStatus.ACTIVE,
|
||||
settings={"key": "value"},
|
||||
owner_id=test_owner_crud.id,
|
||||
)
|
||||
result = await project_crud.create(session, obj_in=project_data)
|
||||
|
||||
assert result.id is not None
|
||||
assert result.name == "New Project"
|
||||
assert result.slug == "new-project"
|
||||
assert result.description == "A brand new project"
|
||||
assert result.autonomy_level == AutonomyLevel.MILESTONE
|
||||
assert result.status == ProjectStatus.ACTIVE
|
||||
assert result.settings == {"key": "value"}
|
||||
assert result.owner_id == test_owner_crud.id
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_project_duplicate_slug_fails(self, async_test_db, test_project_crud):
|
||||
"""Test creating project with duplicate slug raises ValueError."""
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
project_data = ProjectCreate(
|
||||
name="Duplicate Project",
|
||||
slug=test_project_crud.slug, # Duplicate slug
|
||||
description="This should fail",
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
await project_crud.create(session, obj_in=project_data)
|
||||
|
||||
assert "already exists" in str(exc_info.value).lower()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_project_minimal_fields(self, async_test_db):
|
||||
"""Test creating project with minimal required fields."""
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
project_data = ProjectCreate(
|
||||
name="Minimal Project",
|
||||
slug="minimal-project",
|
||||
)
|
||||
result = await project_crud.create(session, obj_in=project_data)
|
||||
|
||||
assert result.name == "Minimal Project"
|
||||
assert result.slug == "minimal-project"
|
||||
assert result.autonomy_level == AutonomyLevel.MILESTONE # Default
|
||||
assert result.status == ProjectStatus.ACTIVE # Default
|
||||
|
||||
|
||||
class TestProjectRead:
|
||||
"""Tests for project read operations."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_project_by_id(self, async_test_db, test_project_crud):
|
||||
"""Test getting project by ID."""
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
result = await project_crud.get(session, id=str(test_project_crud.id))
|
||||
|
||||
assert result is not None
|
||||
assert result.id == test_project_crud.id
|
||||
assert result.name == test_project_crud.name
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_project_by_id_not_found(self, async_test_db):
|
||||
"""Test getting non-existent project returns None."""
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
result = await project_crud.get(session, id=str(uuid.uuid4()))
|
||||
assert result is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_project_by_slug(self, async_test_db, test_project_crud):
|
||||
"""Test getting project by slug."""
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
result = await project_crud.get_by_slug(session, slug=test_project_crud.slug)
|
||||
|
||||
assert result is not None
|
||||
assert result.slug == test_project_crud.slug
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_project_by_slug_not_found(self, async_test_db):
|
||||
"""Test getting non-existent slug returns None."""
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
result = await project_crud.get_by_slug(session, slug="non-existent-slug")
|
||||
assert result is None
|
||||
|
||||
|
||||
class TestProjectUpdate:
|
||||
"""Tests for project update operations."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_project_basic_fields(self, async_test_db, test_project_crud):
|
||||
"""Test updating basic project fields."""
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
project = await project_crud.get(session, id=str(test_project_crud.id))
|
||||
|
||||
update_data = ProjectUpdate(
|
||||
name="Updated Project Name",
|
||||
description="Updated description",
|
||||
)
|
||||
result = await project_crud.update(session, db_obj=project, obj_in=update_data)
|
||||
|
||||
assert result.name == "Updated Project Name"
|
||||
assert result.description == "Updated description"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_project_status(self, async_test_db, test_project_crud):
|
||||
"""Test updating project status."""
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
project = await project_crud.get(session, id=str(test_project_crud.id))
|
||||
|
||||
update_data = ProjectUpdate(status=ProjectStatus.PAUSED)
|
||||
result = await project_crud.update(session, db_obj=project, obj_in=update_data)
|
||||
|
||||
assert result.status == ProjectStatus.PAUSED
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_project_autonomy_level(self, async_test_db, test_project_crud):
|
||||
"""Test updating project autonomy level."""
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
project = await project_crud.get(session, id=str(test_project_crud.id))
|
||||
|
||||
update_data = ProjectUpdate(autonomy_level=AutonomyLevel.AUTONOMOUS)
|
||||
result = await project_crud.update(session, db_obj=project, obj_in=update_data)
|
||||
|
||||
assert result.autonomy_level == AutonomyLevel.AUTONOMOUS
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_project_settings(self, async_test_db, test_project_crud):
|
||||
"""Test updating project settings."""
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
project = await project_crud.get(session, id=str(test_project_crud.id))
|
||||
|
||||
new_settings = {"mcp_servers": ["gitea", "slack"], "webhook_url": "https://example.com"}
|
||||
update_data = ProjectUpdate(settings=new_settings)
|
||||
result = await project_crud.update(session, db_obj=project, obj_in=update_data)
|
||||
|
||||
assert result.settings == new_settings
|
||||
|
||||
|
||||
class TestProjectDelete:
|
||||
"""Tests for project delete operations."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_project(self, async_test_db, test_owner_crud):
|
||||
"""Test deleting a project."""
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
# Create a project to delete
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
project_data = ProjectCreate(
|
||||
name="Delete Me",
|
||||
slug="delete-me-project",
|
||||
owner_id=test_owner_crud.id,
|
||||
)
|
||||
created = await project_crud.create(session, obj_in=project_data)
|
||||
project_id = created.id
|
||||
|
||||
# Delete the project
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
result = await project_crud.remove(session, id=str(project_id))
|
||||
assert result is not None
|
||||
assert result.id == project_id
|
||||
|
||||
# Verify deletion
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
deleted = await project_crud.get(session, id=str(project_id))
|
||||
assert deleted is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_nonexistent_project(self, async_test_db):
|
||||
"""Test deleting non-existent project returns None."""
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
result = await project_crud.remove(session, id=str(uuid.uuid4()))
|
||||
assert result is None
|
||||
|
||||
|
||||
class TestProjectFilters:
|
||||
"""Tests for project filtering and search."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_multi_with_filters_status(self, async_test_db, test_owner_crud):
|
||||
"""Test filtering projects by status."""
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
# Create multiple projects with different statuses
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
for i, status in enumerate(ProjectStatus):
|
||||
project_data = ProjectCreate(
|
||||
name=f"Project {status.value}",
|
||||
slug=f"project-filter-{status.value}-{i}",
|
||||
status=status,
|
||||
owner_id=test_owner_crud.id,
|
||||
)
|
||||
await project_crud.create(session, obj_in=project_data)
|
||||
|
||||
# Filter by ACTIVE status
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
projects, total = await project_crud.get_multi_with_filters(
|
||||
session,
|
||||
status=ProjectStatus.ACTIVE,
|
||||
)
|
||||
|
||||
assert all(p.status == ProjectStatus.ACTIVE for p in projects)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_multi_with_filters_search(self, async_test_db, test_owner_crud):
|
||||
"""Test searching projects by name/slug."""
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
project_data = ProjectCreate(
|
||||
name="Searchable Project",
|
||||
slug="searchable-unique-slug",
|
||||
description="This project is searchable",
|
||||
owner_id=test_owner_crud.id,
|
||||
)
|
||||
await project_crud.create(session, obj_in=project_data)
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
projects, total = await project_crud.get_multi_with_filters(
|
||||
session,
|
||||
search="Searchable",
|
||||
)
|
||||
|
||||
assert total >= 1
|
||||
assert any(p.name == "Searchable Project" for p in projects)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_multi_with_filters_owner(self, async_test_db, test_owner_crud, test_project_crud):
|
||||
"""Test filtering projects by owner."""
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
projects, total = await project_crud.get_multi_with_filters(
|
||||
session,
|
||||
owner_id=test_owner_crud.id,
|
||||
)
|
||||
|
||||
assert total >= 1
|
||||
assert all(p.owner_id == test_owner_crud.id for p in projects)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_multi_with_filters_pagination(self, async_test_db, test_owner_crud):
|
||||
"""Test pagination of project results."""
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
# Create multiple projects
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
for i in range(5):
|
||||
project_data = ProjectCreate(
|
||||
name=f"Page Project {i}",
|
||||
slug=f"page-project-{i}",
|
||||
owner_id=test_owner_crud.id,
|
||||
)
|
||||
await project_crud.create(session, obj_in=project_data)
|
||||
|
||||
# Get first page
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
page1, total = await project_crud.get_multi_with_filters(
|
||||
session,
|
||||
skip=0,
|
||||
limit=2,
|
||||
owner_id=test_owner_crud.id,
|
||||
)
|
||||
|
||||
assert len(page1) <= 2
|
||||
assert total >= 5
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_multi_with_filters_sorting(self, async_test_db, test_owner_crud):
|
||||
"""Test sorting project results."""
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
for i, name in enumerate(["Charlie", "Alice", "Bob"]):
|
||||
project_data = ProjectCreate(
|
||||
name=name,
|
||||
slug=f"sort-project-{name.lower()}",
|
||||
owner_id=test_owner_crud.id,
|
||||
)
|
||||
await project_crud.create(session, obj_in=project_data)
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
projects, _ = await project_crud.get_multi_with_filters(
|
||||
session,
|
||||
sort_by="name",
|
||||
sort_order="asc",
|
||||
owner_id=test_owner_crud.id,
|
||||
)
|
||||
|
||||
names = [p.name for p in projects if p.name in ["Alice", "Bob", "Charlie"]]
|
||||
assert names == sorted(names)
|
||||
|
||||
|
||||
class TestProjectSpecialMethods:
|
||||
"""Tests for special project CRUD methods."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_archive_project(self, async_test_db, test_project_crud):
|
||||
"""Test archiving a project."""
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
result = await project_crud.archive_project(session, project_id=test_project_crud.id)
|
||||
|
||||
assert result is not None
|
||||
assert result.status == ProjectStatus.ARCHIVED
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_archive_nonexistent_project(self, async_test_db):
|
||||
"""Test archiving non-existent project returns None."""
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
result = await project_crud.archive_project(session, project_id=uuid.uuid4())
|
||||
assert result is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_projects_by_owner(self, async_test_db, test_owner_crud, test_project_crud):
|
||||
"""Test getting all projects by owner."""
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
projects = await project_crud.get_projects_by_owner(
|
||||
session,
|
||||
owner_id=test_owner_crud.id,
|
||||
)
|
||||
|
||||
assert len(projects) >= 1
|
||||
assert all(p.owner_id == test_owner_crud.id for p in projects)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_projects_by_owner_with_status(self, async_test_db, test_owner_crud):
|
||||
"""Test getting projects by owner filtered by status."""
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
# Create projects with different statuses
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
active_project = ProjectCreate(
|
||||
name="Active Owner Project",
|
||||
slug="active-owner-project",
|
||||
status=ProjectStatus.ACTIVE,
|
||||
owner_id=test_owner_crud.id,
|
||||
)
|
||||
await project_crud.create(session, obj_in=active_project)
|
||||
|
||||
paused_project = ProjectCreate(
|
||||
name="Paused Owner Project",
|
||||
slug="paused-owner-project",
|
||||
status=ProjectStatus.PAUSED,
|
||||
owner_id=test_owner_crud.id,
|
||||
)
|
||||
await project_crud.create(session, obj_in=paused_project)
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
projects = await project_crud.get_projects_by_owner(
|
||||
session,
|
||||
owner_id=test_owner_crud.id,
|
||||
status=ProjectStatus.ACTIVE,
|
||||
)
|
||||
|
||||
assert all(p.status == ProjectStatus.ACTIVE for p in projects)
|
||||
524
backend/tests/crud/syndarix/test_sprint_crud.py
Normal file
524
backend/tests/crud/syndarix/test_sprint_crud.py
Normal file
@@ -0,0 +1,524 @@
|
||||
# tests/crud/syndarix/test_sprint_crud.py
|
||||
"""
|
||||
Tests for Sprint CRUD operations.
|
||||
"""
|
||||
|
||||
import uuid
|
||||
from datetime import date, timedelta
|
||||
|
||||
import pytest
|
||||
|
||||
from app.crud.syndarix import sprint as sprint_crud
|
||||
from app.models.syndarix import SprintStatus
|
||||
from app.schemas.syndarix import SprintCreate, SprintUpdate
|
||||
|
||||
|
||||
class TestSprintCreate:
|
||||
"""Tests for sprint creation."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_sprint_success(self, async_test_db, test_project_crud):
|
||||
"""Test successfully creating a sprint."""
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
today = date.today()
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
sprint_data = SprintCreate(
|
||||
project_id=test_project_crud.id,
|
||||
name="Sprint 1",
|
||||
number=1,
|
||||
goal="Complete initial setup",
|
||||
start_date=today,
|
||||
end_date=today + timedelta(days=14),
|
||||
status=SprintStatus.PLANNED,
|
||||
planned_points=21,
|
||||
)
|
||||
result = await sprint_crud.create(session, obj_in=sprint_data)
|
||||
|
||||
assert result.id is not None
|
||||
assert result.name == "Sprint 1"
|
||||
assert result.number == 1
|
||||
assert result.goal == "Complete initial setup"
|
||||
assert result.status == SprintStatus.PLANNED
|
||||
assert result.planned_points == 21
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_sprint_minimal(self, async_test_db, test_project_crud):
|
||||
"""Test creating sprint with minimal fields."""
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
today = date.today()
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
sprint_data = SprintCreate(
|
||||
project_id=test_project_crud.id,
|
||||
name="Minimal Sprint",
|
||||
number=1,
|
||||
start_date=today,
|
||||
end_date=today + timedelta(days=14),
|
||||
)
|
||||
result = await sprint_crud.create(session, obj_in=sprint_data)
|
||||
|
||||
assert result.name == "Minimal Sprint"
|
||||
assert result.status == SprintStatus.PLANNED # Default
|
||||
assert result.goal is None
|
||||
assert result.planned_points is None
|
||||
|
||||
|
||||
class TestSprintRead:
|
||||
"""Tests for sprint read operations."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_sprint_by_id(self, async_test_db, test_sprint_crud):
|
||||
"""Test getting sprint by ID."""
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
result = await sprint_crud.get(session, id=str(test_sprint_crud.id))
|
||||
|
||||
assert result is not None
|
||||
assert result.id == test_sprint_crud.id
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_sprint_by_id_not_found(self, async_test_db):
|
||||
"""Test getting non-existent sprint returns None."""
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
result = await sprint_crud.get(session, id=str(uuid.uuid4()))
|
||||
assert result is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_with_details(self, async_test_db, test_sprint_crud):
|
||||
"""Test getting sprint with related details."""
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
result = await sprint_crud.get_with_details(
|
||||
session,
|
||||
sprint_id=test_sprint_crud.id,
|
||||
)
|
||||
|
||||
assert result is not None
|
||||
assert result["sprint"].id == test_sprint_crud.id
|
||||
assert result["project_name"] is not None
|
||||
assert "issue_count" in result
|
||||
assert "open_issues" in result
|
||||
assert "completed_issues" in result
|
||||
|
||||
|
||||
class TestSprintUpdate:
|
||||
"""Tests for sprint update operations."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_sprint_basic_fields(self, async_test_db, test_sprint_crud):
|
||||
"""Test updating basic sprint fields."""
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
sprint = await sprint_crud.get(session, id=str(test_sprint_crud.id))
|
||||
|
||||
update_data = SprintUpdate(
|
||||
name="Updated Sprint Name",
|
||||
goal="Updated goal",
|
||||
)
|
||||
result = await sprint_crud.update(session, db_obj=sprint, obj_in=update_data)
|
||||
|
||||
assert result.name == "Updated Sprint Name"
|
||||
assert result.goal == "Updated goal"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_sprint_dates(self, async_test_db, test_sprint_crud):
|
||||
"""Test updating sprint dates."""
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
today = date.today()
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
sprint = await sprint_crud.get(session, id=str(test_sprint_crud.id))
|
||||
|
||||
update_data = SprintUpdate(
|
||||
start_date=today + timedelta(days=1),
|
||||
end_date=today + timedelta(days=21),
|
||||
)
|
||||
result = await sprint_crud.update(session, db_obj=sprint, obj_in=update_data)
|
||||
|
||||
assert result.start_date == today + timedelta(days=1)
|
||||
assert result.end_date == today + timedelta(days=21)
|
||||
|
||||
|
||||
class TestSprintLifecycle:
|
||||
"""Tests for sprint lifecycle operations."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_start_sprint(self, async_test_db, test_sprint_crud):
|
||||
"""Test starting a planned sprint."""
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
result = await sprint_crud.start_sprint(
|
||||
session,
|
||||
sprint_id=test_sprint_crud.id,
|
||||
)
|
||||
|
||||
assert result is not None
|
||||
assert result.status == SprintStatus.ACTIVE
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_start_sprint_with_custom_date(self, async_test_db, test_project_crud):
|
||||
"""Test starting sprint with custom start date."""
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
today = date.today()
|
||||
|
||||
# Create a planned sprint
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
sprint_data = SprintCreate(
|
||||
project_id=test_project_crud.id,
|
||||
name="Start Date Sprint",
|
||||
number=10,
|
||||
start_date=today,
|
||||
end_date=today + timedelta(days=14),
|
||||
status=SprintStatus.PLANNED,
|
||||
)
|
||||
created = await sprint_crud.create(session, obj_in=sprint_data)
|
||||
sprint_id = created.id
|
||||
|
||||
# Start with custom date
|
||||
new_start = today + timedelta(days=2)
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
result = await sprint_crud.start_sprint(
|
||||
session,
|
||||
sprint_id=sprint_id,
|
||||
start_date=new_start,
|
||||
)
|
||||
|
||||
assert result.status == SprintStatus.ACTIVE
|
||||
assert result.start_date == new_start
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_start_sprint_already_active_fails(self, async_test_db, test_project_crud):
|
||||
"""Test starting an already active sprint raises ValueError."""
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
today = date.today()
|
||||
|
||||
# Create and start a sprint
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
sprint_data = SprintCreate(
|
||||
project_id=test_project_crud.id,
|
||||
name="Already Active Sprint",
|
||||
number=20,
|
||||
start_date=today,
|
||||
end_date=today + timedelta(days=14),
|
||||
status=SprintStatus.ACTIVE,
|
||||
)
|
||||
created = await sprint_crud.create(session, obj_in=sprint_data)
|
||||
sprint_id = created.id
|
||||
|
||||
# Try to start again
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
await sprint_crud.start_sprint(session, sprint_id=sprint_id)
|
||||
|
||||
assert "cannot start sprint" in str(exc_info.value).lower()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_complete_sprint(self, async_test_db, test_project_crud):
|
||||
"""Test completing an active sprint."""
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
today = date.today()
|
||||
|
||||
# Create an active sprint
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
sprint_data = SprintCreate(
|
||||
project_id=test_project_crud.id,
|
||||
name="Complete Me Sprint",
|
||||
number=30,
|
||||
start_date=today - timedelta(days=14),
|
||||
end_date=today,
|
||||
status=SprintStatus.ACTIVE,
|
||||
planned_points=21,
|
||||
)
|
||||
created = await sprint_crud.create(session, obj_in=sprint_data)
|
||||
sprint_id = created.id
|
||||
|
||||
# Complete
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
result = await sprint_crud.complete_sprint(session, sprint_id=sprint_id)
|
||||
|
||||
assert result is not None
|
||||
assert result.status == SprintStatus.COMPLETED
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_complete_planned_sprint_fails(self, async_test_db, test_project_crud):
|
||||
"""Test completing a planned sprint raises ValueError."""
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
today = date.today()
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
sprint_data = SprintCreate(
|
||||
project_id=test_project_crud.id,
|
||||
name="Planned Sprint",
|
||||
number=40,
|
||||
start_date=today,
|
||||
end_date=today + timedelta(days=14),
|
||||
status=SprintStatus.PLANNED,
|
||||
)
|
||||
created = await sprint_crud.create(session, obj_in=sprint_data)
|
||||
sprint_id = created.id
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
await sprint_crud.complete_sprint(session, sprint_id=sprint_id)
|
||||
|
||||
assert "cannot complete sprint" in str(exc_info.value).lower()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cancel_sprint(self, async_test_db, test_project_crud):
|
||||
"""Test cancelling a sprint."""
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
today = date.today()
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
sprint_data = SprintCreate(
|
||||
project_id=test_project_crud.id,
|
||||
name="Cancel Me Sprint",
|
||||
number=50,
|
||||
start_date=today,
|
||||
end_date=today + timedelta(days=14),
|
||||
status=SprintStatus.ACTIVE,
|
||||
)
|
||||
created = await sprint_crud.create(session, obj_in=sprint_data)
|
||||
sprint_id = created.id
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
result = await sprint_crud.cancel_sprint(session, sprint_id=sprint_id)
|
||||
|
||||
assert result is not None
|
||||
assert result.status == SprintStatus.CANCELLED
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cancel_completed_sprint_fails(self, async_test_db, test_project_crud):
|
||||
"""Test cancelling a completed sprint raises ValueError."""
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
today = date.today()
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
sprint_data = SprintCreate(
|
||||
project_id=test_project_crud.id,
|
||||
name="Completed Sprint",
|
||||
number=60,
|
||||
start_date=today - timedelta(days=14),
|
||||
end_date=today,
|
||||
status=SprintStatus.COMPLETED,
|
||||
)
|
||||
created = await sprint_crud.create(session, obj_in=sprint_data)
|
||||
sprint_id = created.id
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
await sprint_crud.cancel_sprint(session, sprint_id=sprint_id)
|
||||
|
||||
assert "cannot cancel sprint" in str(exc_info.value).lower()
|
||||
|
||||
|
||||
class TestSprintByProject:
|
||||
"""Tests for getting sprints by project."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_by_project(self, async_test_db, test_project_crud, test_sprint_crud):
|
||||
"""Test getting sprints by project."""
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
sprints, total = await sprint_crud.get_by_project(
|
||||
session,
|
||||
project_id=test_project_crud.id,
|
||||
)
|
||||
|
||||
assert total >= 1
|
||||
assert all(s.project_id == test_project_crud.id for s in sprints)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_by_project_with_status(self, async_test_db, test_project_crud):
|
||||
"""Test filtering sprints by status."""
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
today = date.today()
|
||||
|
||||
# Create sprints with different statuses
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
planned_sprint = SprintCreate(
|
||||
project_id=test_project_crud.id,
|
||||
name="Planned Filter Sprint",
|
||||
number=70,
|
||||
start_date=today,
|
||||
end_date=today + timedelta(days=14),
|
||||
status=SprintStatus.PLANNED,
|
||||
)
|
||||
await sprint_crud.create(session, obj_in=planned_sprint)
|
||||
|
||||
active_sprint = SprintCreate(
|
||||
project_id=test_project_crud.id,
|
||||
name="Active Filter Sprint",
|
||||
number=71,
|
||||
start_date=today,
|
||||
end_date=today + timedelta(days=14),
|
||||
status=SprintStatus.ACTIVE,
|
||||
)
|
||||
await sprint_crud.create(session, obj_in=active_sprint)
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
sprints, _ = await sprint_crud.get_by_project(
|
||||
session,
|
||||
project_id=test_project_crud.id,
|
||||
status=SprintStatus.ACTIVE,
|
||||
)
|
||||
|
||||
assert all(s.status == SprintStatus.ACTIVE for s in sprints)
|
||||
|
||||
|
||||
class TestSprintActiveSprint:
|
||||
"""Tests for active sprint operations."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_active_sprint(self, async_test_db, test_project_crud):
|
||||
"""Test getting active sprint for a project."""
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
today = date.today()
|
||||
|
||||
# Create an active sprint
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
sprint_data = SprintCreate(
|
||||
project_id=test_project_crud.id,
|
||||
name="Active Sprint",
|
||||
number=80,
|
||||
start_date=today,
|
||||
end_date=today + timedelta(days=14),
|
||||
status=SprintStatus.ACTIVE,
|
||||
)
|
||||
await sprint_crud.create(session, obj_in=sprint_data)
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
result = await sprint_crud.get_active_sprint(
|
||||
session,
|
||||
project_id=test_project_crud.id,
|
||||
)
|
||||
|
||||
assert result is not None
|
||||
assert result.status == SprintStatus.ACTIVE
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_active_sprint_none(self, async_test_db, test_project_crud):
|
||||
"""Test getting active sprint when none exists."""
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
# Note: test_sprint_crud has PLANNED status by default
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
result = await sprint_crud.get_active_sprint(
|
||||
session,
|
||||
project_id=test_project_crud.id,
|
||||
)
|
||||
|
||||
# May or may not be None depending on other tests
|
||||
if result is not None:
|
||||
assert result.status == SprintStatus.ACTIVE
|
||||
|
||||
|
||||
class TestSprintNextNumber:
|
||||
"""Tests for getting next sprint number."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_next_sprint_number(self, async_test_db, test_project_crud):
|
||||
"""Test getting next sprint number."""
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
today = date.today()
|
||||
|
||||
# Create sprints with numbers
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
for i in range(1, 4):
|
||||
sprint_data = SprintCreate(
|
||||
project_id=test_project_crud.id,
|
||||
name=f"Number Sprint {i}",
|
||||
number=i,
|
||||
start_date=today,
|
||||
end_date=today + timedelta(days=14),
|
||||
)
|
||||
await sprint_crud.create(session, obj_in=sprint_data)
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
next_number = await sprint_crud.get_next_sprint_number(
|
||||
session,
|
||||
project_id=test_project_crud.id,
|
||||
)
|
||||
|
||||
assert next_number >= 4
|
||||
|
||||
|
||||
class TestSprintVelocity:
|
||||
"""Tests for sprint velocity operations."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_velocity(self, async_test_db, test_project_crud):
|
||||
"""Test getting velocity data for completed sprints."""
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
today = date.today()
|
||||
|
||||
# Create completed sprints with points
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
for i in range(1, 4):
|
||||
sprint_data = SprintCreate(
|
||||
project_id=test_project_crud.id,
|
||||
name=f"Velocity Sprint {i}",
|
||||
number=100 + i,
|
||||
start_date=today - timedelta(days=14 * i),
|
||||
end_date=today - timedelta(days=14 * (i - 1)),
|
||||
status=SprintStatus.COMPLETED,
|
||||
planned_points=20,
|
||||
completed_points=15 + i,
|
||||
)
|
||||
await sprint_crud.create(session, obj_in=sprint_data)
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
velocity_data = await sprint_crud.get_velocity(
|
||||
session,
|
||||
project_id=test_project_crud.id,
|
||||
limit=5,
|
||||
)
|
||||
|
||||
assert len(velocity_data) >= 1
|
||||
for data in velocity_data:
|
||||
assert "sprint_number" in data
|
||||
assert "sprint_name" in data
|
||||
assert "planned_points" in data
|
||||
assert "completed_points" in data
|
||||
assert "velocity" in data
|
||||
|
||||
|
||||
class TestSprintWithIssueCounts:
|
||||
"""Tests for getting sprints with issue counts."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_sprints_with_issue_counts(self, async_test_db, test_project_crud, test_sprint_crud):
|
||||
"""Test getting sprints with issue counts."""
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
results, total = await sprint_crud.get_sprints_with_issue_counts(
|
||||
session,
|
||||
project_id=test_project_crud.id,
|
||||
)
|
||||
|
||||
assert total >= 1
|
||||
for result in results:
|
||||
assert "sprint" in result
|
||||
assert "issue_count" in result
|
||||
assert "open_issues" in result
|
||||
assert "completed_issues" in result
|
||||
2
backend/tests/models/syndarix/__init__.py
Normal file
2
backend/tests/models/syndarix/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
||||
# tests/models/syndarix/__init__.py
|
||||
"""Syndarix model unit tests."""
|
||||
192
backend/tests/models/syndarix/conftest.py
Normal file
192
backend/tests/models/syndarix/conftest.py
Normal file
@@ -0,0 +1,192 @@
|
||||
# tests/models/syndarix/conftest.py
|
||||
"""
|
||||
Shared fixtures for Syndarix model tests.
|
||||
"""
|
||||
|
||||
import uuid
|
||||
from datetime import date, timedelta
|
||||
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
|
||||
from app.models.syndarix import (
|
||||
AgentInstance,
|
||||
AgentStatus,
|
||||
AgentType,
|
||||
AutonomyLevel,
|
||||
Issue,
|
||||
IssuePriority,
|
||||
IssueStatus,
|
||||
Project,
|
||||
ProjectStatus,
|
||||
Sprint,
|
||||
SprintStatus,
|
||||
SyncStatus,
|
||||
)
|
||||
from app.models.user import User
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_project_data():
|
||||
"""Return sample project data for testing."""
|
||||
return {
|
||||
"name": "Test Project",
|
||||
"slug": "test-project",
|
||||
"description": "A test project for unit testing",
|
||||
"autonomy_level": AutonomyLevel.MILESTONE,
|
||||
"status": ProjectStatus.ACTIVE,
|
||||
"settings": {"mcp_servers": ["gitea", "slack"]},
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_agent_type_data():
|
||||
"""Return sample agent type data for testing."""
|
||||
return {
|
||||
"name": "Backend Engineer",
|
||||
"slug": "backend-engineer",
|
||||
"description": "Specialized in backend development",
|
||||
"expertise": ["python", "fastapi", "postgresql"],
|
||||
"personality_prompt": "You are an expert backend engineer...",
|
||||
"primary_model": "claude-opus-4-5-20251101",
|
||||
"fallback_models": ["claude-sonnet-4-20250514"],
|
||||
"model_params": {"temperature": 0.7, "max_tokens": 4096},
|
||||
"mcp_servers": ["gitea", "file-system"],
|
||||
"tool_permissions": {"allowed": ["*"], "denied": []},
|
||||
"is_active": True,
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_sprint_data():
|
||||
"""Return sample sprint data for testing."""
|
||||
today = date.today()
|
||||
return {
|
||||
"name": "Sprint 1",
|
||||
"number": 1,
|
||||
"goal": "Complete initial setup and core features",
|
||||
"start_date": today,
|
||||
"end_date": today + timedelta(days=14),
|
||||
"status": SprintStatus.PLANNED,
|
||||
"planned_points": 21,
|
||||
"completed_points": 0,
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_issue_data():
|
||||
"""Return sample issue data for testing."""
|
||||
return {
|
||||
"title": "Implement user authentication",
|
||||
"body": "As a user, I want to log in securely...",
|
||||
"status": IssueStatus.OPEN,
|
||||
"priority": IssuePriority.HIGH,
|
||||
"labels": ["backend", "security"],
|
||||
"story_points": 5,
|
||||
}
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def test_owner(async_test_db):
|
||||
"""Create a test user to be used as project owner."""
|
||||
from app.core.auth import get_password_hash
|
||||
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
user = User(
|
||||
id=uuid.uuid4(),
|
||||
email="owner@example.com",
|
||||
password_hash=get_password_hash("TestPassword123!"),
|
||||
first_name="Test",
|
||||
last_name="Owner",
|
||||
is_active=True,
|
||||
is_superuser=False,
|
||||
)
|
||||
session.add(user)
|
||||
await session.commit()
|
||||
await session.refresh(user)
|
||||
return user
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def test_project(async_test_db, test_owner, sample_project_data):
|
||||
"""Create a test project in the database."""
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
project = Project(
|
||||
id=uuid.uuid4(),
|
||||
owner_id=test_owner.id,
|
||||
**sample_project_data,
|
||||
)
|
||||
session.add(project)
|
||||
await session.commit()
|
||||
await session.refresh(project)
|
||||
return project
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def test_agent_type(async_test_db, sample_agent_type_data):
|
||||
"""Create a test agent type in the database."""
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
agent_type = AgentType(
|
||||
id=uuid.uuid4(),
|
||||
**sample_agent_type_data,
|
||||
)
|
||||
session.add(agent_type)
|
||||
await session.commit()
|
||||
await session.refresh(agent_type)
|
||||
return agent_type
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def test_agent_instance(async_test_db, test_project, test_agent_type):
|
||||
"""Create a test agent instance in the database."""
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
agent_instance = AgentInstance(
|
||||
id=uuid.uuid4(),
|
||||
agent_type_id=test_agent_type.id,
|
||||
project_id=test_project.id,
|
||||
status=AgentStatus.IDLE,
|
||||
current_task=None,
|
||||
short_term_memory={},
|
||||
long_term_memory_ref=None,
|
||||
session_id=None,
|
||||
)
|
||||
session.add(agent_instance)
|
||||
await session.commit()
|
||||
await session.refresh(agent_instance)
|
||||
return agent_instance
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def test_sprint(async_test_db, test_project, sample_sprint_data):
|
||||
"""Create a test sprint in the database."""
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
sprint = Sprint(
|
||||
id=uuid.uuid4(),
|
||||
project_id=test_project.id,
|
||||
**sample_sprint_data,
|
||||
)
|
||||
session.add(sprint)
|
||||
await session.commit()
|
||||
await session.refresh(sprint)
|
||||
return sprint
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def test_issue(async_test_db, test_project, sample_issue_data):
|
||||
"""Create a test issue in the database."""
|
||||
_test_engine, AsyncTestingSessionLocal = async_test_db
|
||||
async with AsyncTestingSessionLocal() as session:
|
||||
issue = Issue(
|
||||
id=uuid.uuid4(),
|
||||
project_id=test_project.id,
|
||||
**sample_issue_data,
|
||||
)
|
||||
session.add(issue)
|
||||
await session.commit()
|
||||
await session.refresh(issue)
|
||||
return issue
|
||||
424
backend/tests/models/syndarix/test_agent_instance.py
Normal file
424
backend/tests/models/syndarix/test_agent_instance.py
Normal file
@@ -0,0 +1,424 @@
|
||||
# tests/models/syndarix/test_agent_instance.py
|
||||
"""
|
||||
Unit tests for the AgentInstance model.
|
||||
"""
|
||||
|
||||
import uuid
|
||||
from datetime import UTC, datetime
|
||||
from decimal import Decimal
|
||||
|
||||
import pytest
|
||||
|
||||
from app.models.syndarix import (
|
||||
AgentInstance,
|
||||
AgentStatus,
|
||||
AgentType,
|
||||
Project,
|
||||
)
|
||||
|
||||
|
||||
class TestAgentInstanceModel:
|
||||
"""Tests for AgentInstance model creation and fields."""
|
||||
|
||||
def test_create_agent_instance_with_required_fields(self, db_session):
|
||||
"""Test creating an agent instance with only required fields."""
|
||||
# First create dependencies
|
||||
project = Project(
|
||||
id=uuid.uuid4(),
|
||||
name="Test Project",
|
||||
slug="test-project-instance",
|
||||
)
|
||||
db_session.add(project)
|
||||
|
||||
agent_type = AgentType(
|
||||
id=uuid.uuid4(),
|
||||
name="Test Agent",
|
||||
slug="test-agent-instance",
|
||||
personality_prompt="Test",
|
||||
primary_model="claude-opus-4-5-20251101",
|
||||
)
|
||||
db_session.add(agent_type)
|
||||
db_session.commit()
|
||||
|
||||
# Create agent instance
|
||||
instance = AgentInstance(
|
||||
id=uuid.uuid4(),
|
||||
agent_type_id=agent_type.id,
|
||||
project_id=project.id,
|
||||
)
|
||||
db_session.add(instance)
|
||||
db_session.commit()
|
||||
|
||||
retrieved = db_session.query(AgentInstance).filter_by(project_id=project.id).first()
|
||||
|
||||
assert retrieved is not None
|
||||
assert retrieved.agent_type_id == agent_type.id
|
||||
assert retrieved.project_id == project.id
|
||||
assert retrieved.status == AgentStatus.IDLE # Default
|
||||
assert retrieved.current_task is None
|
||||
assert retrieved.short_term_memory == {}
|
||||
assert retrieved.long_term_memory_ref is None
|
||||
assert retrieved.session_id is None
|
||||
assert retrieved.tasks_completed == 0
|
||||
assert retrieved.tokens_used == 0
|
||||
assert retrieved.cost_incurred == Decimal("0")
|
||||
|
||||
def test_create_agent_instance_with_all_fields(self, db_session):
|
||||
"""Test creating an agent instance with all optional fields."""
|
||||
# First create dependencies
|
||||
project = Project(
|
||||
id=uuid.uuid4(),
|
||||
name="Full Project",
|
||||
slug="full-project-instance",
|
||||
)
|
||||
db_session.add(project)
|
||||
|
||||
agent_type = AgentType(
|
||||
id=uuid.uuid4(),
|
||||
name="Full Agent",
|
||||
slug="full-agent-instance",
|
||||
personality_prompt="Test",
|
||||
primary_model="claude-opus-4-5-20251101",
|
||||
)
|
||||
db_session.add(agent_type)
|
||||
db_session.commit()
|
||||
|
||||
instance_id = uuid.uuid4()
|
||||
now = datetime.now(UTC)
|
||||
|
||||
instance = AgentInstance(
|
||||
id=instance_id,
|
||||
agent_type_id=agent_type.id,
|
||||
project_id=project.id,
|
||||
status=AgentStatus.WORKING,
|
||||
current_task="Implementing user authentication",
|
||||
short_term_memory={"context": "Working on auth", "recent_files": ["auth.py"]},
|
||||
long_term_memory_ref="project-123/agent-456",
|
||||
session_id="session-abc-123",
|
||||
last_activity_at=now,
|
||||
tasks_completed=5,
|
||||
tokens_used=10000,
|
||||
cost_incurred=Decimal("0.5000"),
|
||||
)
|
||||
db_session.add(instance)
|
||||
db_session.commit()
|
||||
|
||||
retrieved = db_session.query(AgentInstance).filter_by(id=instance_id).first()
|
||||
|
||||
assert retrieved.status == AgentStatus.WORKING
|
||||
assert retrieved.current_task == "Implementing user authentication"
|
||||
assert retrieved.short_term_memory == {"context": "Working on auth", "recent_files": ["auth.py"]}
|
||||
assert retrieved.long_term_memory_ref == "project-123/agent-456"
|
||||
assert retrieved.session_id == "session-abc-123"
|
||||
assert retrieved.tasks_completed == 5
|
||||
assert retrieved.tokens_used == 10000
|
||||
assert retrieved.cost_incurred == Decimal("0.5000")
|
||||
|
||||
def test_agent_instance_timestamps(self, db_session):
|
||||
"""Test that timestamps are automatically set."""
|
||||
project = Project(id=uuid.uuid4(), name="Timestamp Project", slug="timestamp-project-ai")
|
||||
agent_type = AgentType(
|
||||
id=uuid.uuid4(),
|
||||
name="Timestamp Agent",
|
||||
slug="timestamp-agent-ai",
|
||||
personality_prompt="Test",
|
||||
primary_model="claude-opus-4-5-20251101",
|
||||
)
|
||||
db_session.add(project)
|
||||
db_session.add(agent_type)
|
||||
db_session.commit()
|
||||
|
||||
instance = AgentInstance(
|
||||
id=uuid.uuid4(),
|
||||
agent_type_id=agent_type.id,
|
||||
project_id=project.id,
|
||||
)
|
||||
db_session.add(instance)
|
||||
db_session.commit()
|
||||
|
||||
assert isinstance(instance.created_at, datetime)
|
||||
assert isinstance(instance.updated_at, datetime)
|
||||
|
||||
def test_agent_instance_string_representation(self, db_session):
|
||||
"""Test the string representation of an agent instance."""
|
||||
project = Project(id=uuid.uuid4(), name="Repr Project", slug="repr-project-ai")
|
||||
agent_type = AgentType(
|
||||
id=uuid.uuid4(),
|
||||
name="Repr Agent",
|
||||
slug="repr-agent-ai",
|
||||
personality_prompt="Test",
|
||||
primary_model="claude-opus-4-5-20251101",
|
||||
)
|
||||
db_session.add(project)
|
||||
db_session.add(agent_type)
|
||||
db_session.commit()
|
||||
|
||||
instance_id = uuid.uuid4()
|
||||
instance = AgentInstance(
|
||||
id=instance_id,
|
||||
agent_type_id=agent_type.id,
|
||||
project_id=project.id,
|
||||
status=AgentStatus.IDLE,
|
||||
)
|
||||
|
||||
repr_str = repr(instance)
|
||||
assert str(instance_id) in repr_str
|
||||
assert str(agent_type.id) in repr_str
|
||||
assert str(project.id) in repr_str
|
||||
assert "idle" in repr_str
|
||||
|
||||
|
||||
class TestAgentInstanceStatus:
|
||||
"""Tests for AgentInstance status transitions."""
|
||||
|
||||
def test_all_agent_statuses(self, db_session):
|
||||
"""Test that all agent statuses can be stored."""
|
||||
project = Project(id=uuid.uuid4(), name="Status Project", slug="status-project-ai")
|
||||
agent_type = AgentType(
|
||||
id=uuid.uuid4(),
|
||||
name="Status Agent",
|
||||
slug="status-agent-ai",
|
||||
personality_prompt="Test",
|
||||
primary_model="claude-opus-4-5-20251101",
|
||||
)
|
||||
db_session.add(project)
|
||||
db_session.add(agent_type)
|
||||
db_session.commit()
|
||||
|
||||
for status in AgentStatus:
|
||||
instance = AgentInstance(
|
||||
id=uuid.uuid4(),
|
||||
agent_type_id=agent_type.id,
|
||||
project_id=project.id,
|
||||
status=status,
|
||||
)
|
||||
db_session.add(instance)
|
||||
db_session.commit()
|
||||
|
||||
retrieved = db_session.query(AgentInstance).filter_by(id=instance.id).first()
|
||||
assert retrieved.status == status
|
||||
|
||||
def test_status_update(self, db_session):
|
||||
"""Test updating agent instance status."""
|
||||
project = Project(id=uuid.uuid4(), name="Update Status Project", slug="update-status-project-ai")
|
||||
agent_type = AgentType(
|
||||
id=uuid.uuid4(),
|
||||
name="Update Status Agent",
|
||||
slug="update-status-agent-ai",
|
||||
personality_prompt="Test",
|
||||
primary_model="claude-opus-4-5-20251101",
|
||||
)
|
||||
db_session.add(project)
|
||||
db_session.add(agent_type)
|
||||
db_session.commit()
|
||||
|
||||
instance = AgentInstance(
|
||||
id=uuid.uuid4(),
|
||||
agent_type_id=agent_type.id,
|
||||
project_id=project.id,
|
||||
status=AgentStatus.IDLE,
|
||||
)
|
||||
db_session.add(instance)
|
||||
db_session.commit()
|
||||
|
||||
# Update to WORKING
|
||||
instance.status = AgentStatus.WORKING
|
||||
instance.current_task = "Processing feature request"
|
||||
db_session.commit()
|
||||
|
||||
retrieved = db_session.query(AgentInstance).filter_by(id=instance.id).first()
|
||||
assert retrieved.status == AgentStatus.WORKING
|
||||
assert retrieved.current_task == "Processing feature request"
|
||||
|
||||
def test_terminate_agent_instance(self, db_session):
|
||||
"""Test terminating an agent instance."""
|
||||
project = Project(id=uuid.uuid4(), name="Terminate Project", slug="terminate-project-ai")
|
||||
agent_type = AgentType(
|
||||
id=uuid.uuid4(),
|
||||
name="Terminate Agent",
|
||||
slug="terminate-agent-ai",
|
||||
personality_prompt="Test",
|
||||
primary_model="claude-opus-4-5-20251101",
|
||||
)
|
||||
db_session.add(project)
|
||||
db_session.add(agent_type)
|
||||
db_session.commit()
|
||||
|
||||
instance = AgentInstance(
|
||||
id=uuid.uuid4(),
|
||||
agent_type_id=agent_type.id,
|
||||
project_id=project.id,
|
||||
status=AgentStatus.WORKING,
|
||||
current_task="Working on something",
|
||||
session_id="active-session",
|
||||
)
|
||||
db_session.add(instance)
|
||||
db_session.commit()
|
||||
|
||||
# Terminate
|
||||
now = datetime.now(UTC)
|
||||
instance.status = AgentStatus.TERMINATED
|
||||
instance.terminated_at = now
|
||||
instance.current_task = None
|
||||
instance.session_id = None
|
||||
db_session.commit()
|
||||
|
||||
retrieved = db_session.query(AgentInstance).filter_by(id=instance.id).first()
|
||||
assert retrieved.status == AgentStatus.TERMINATED
|
||||
assert retrieved.terminated_at is not None
|
||||
assert retrieved.current_task is None
|
||||
assert retrieved.session_id is None
|
||||
|
||||
|
||||
class TestAgentInstanceMetrics:
|
||||
"""Tests for AgentInstance usage metrics."""
|
||||
|
||||
def test_increment_metrics(self, db_session):
|
||||
"""Test incrementing usage metrics."""
|
||||
project = Project(id=uuid.uuid4(), name="Metrics Project", slug="metrics-project-ai")
|
||||
agent_type = AgentType(
|
||||
id=uuid.uuid4(),
|
||||
name="Metrics Agent",
|
||||
slug="metrics-agent-ai",
|
||||
personality_prompt="Test",
|
||||
primary_model="claude-opus-4-5-20251101",
|
||||
)
|
||||
db_session.add(project)
|
||||
db_session.add(agent_type)
|
||||
db_session.commit()
|
||||
|
||||
instance = AgentInstance(
|
||||
id=uuid.uuid4(),
|
||||
agent_type_id=agent_type.id,
|
||||
project_id=project.id,
|
||||
)
|
||||
db_session.add(instance)
|
||||
db_session.commit()
|
||||
|
||||
# Record task completion
|
||||
instance.tasks_completed += 1
|
||||
instance.tokens_used += 1500
|
||||
instance.cost_incurred += Decimal("0.0150")
|
||||
db_session.commit()
|
||||
|
||||
retrieved = db_session.query(AgentInstance).filter_by(id=instance.id).first()
|
||||
assert retrieved.tasks_completed == 1
|
||||
assert retrieved.tokens_used == 1500
|
||||
assert retrieved.cost_incurred == Decimal("0.0150")
|
||||
|
||||
# Record another task
|
||||
retrieved.tasks_completed += 1
|
||||
retrieved.tokens_used += 2500
|
||||
retrieved.cost_incurred += Decimal("0.0250")
|
||||
db_session.commit()
|
||||
|
||||
updated = db_session.query(AgentInstance).filter_by(id=instance.id).first()
|
||||
assert updated.tasks_completed == 2
|
||||
assert updated.tokens_used == 4000
|
||||
assert updated.cost_incurred == Decimal("0.0400")
|
||||
|
||||
def test_large_token_count(self, db_session):
|
||||
"""Test handling large token counts."""
|
||||
project = Project(id=uuid.uuid4(), name="Large Tokens Project", slug="large-tokens-project-ai")
|
||||
agent_type = AgentType(
|
||||
id=uuid.uuid4(),
|
||||
name="Large Tokens Agent",
|
||||
slug="large-tokens-agent-ai",
|
||||
personality_prompt="Test",
|
||||
primary_model="claude-opus-4-5-20251101",
|
||||
)
|
||||
db_session.add(project)
|
||||
db_session.add(agent_type)
|
||||
db_session.commit()
|
||||
|
||||
instance = AgentInstance(
|
||||
id=uuid.uuid4(),
|
||||
agent_type_id=agent_type.id,
|
||||
project_id=project.id,
|
||||
tokens_used=10_000_000_000, # 10 billion tokens
|
||||
cost_incurred=Decimal("100000.0000"), # $100,000
|
||||
)
|
||||
db_session.add(instance)
|
||||
db_session.commit()
|
||||
|
||||
retrieved = db_session.query(AgentInstance).filter_by(id=instance.id).first()
|
||||
assert retrieved.tokens_used == 10_000_000_000
|
||||
assert retrieved.cost_incurred == Decimal("100000.0000")
|
||||
|
||||
|
||||
class TestAgentInstanceShortTermMemory:
|
||||
"""Tests for AgentInstance short-term memory JSON field."""
|
||||
|
||||
def test_store_complex_memory(self, db_session):
|
||||
"""Test storing complex short-term memory."""
|
||||
project = Project(id=uuid.uuid4(), name="Memory Project", slug="memory-project-ai")
|
||||
agent_type = AgentType(
|
||||
id=uuid.uuid4(),
|
||||
name="Memory Agent",
|
||||
slug="memory-agent-ai",
|
||||
personality_prompt="Test",
|
||||
primary_model="claude-opus-4-5-20251101",
|
||||
)
|
||||
db_session.add(project)
|
||||
db_session.add(agent_type)
|
||||
db_session.commit()
|
||||
|
||||
memory = {
|
||||
"conversation_history": [
|
||||
{"role": "user", "content": "Implement feature X"},
|
||||
{"role": "assistant", "content": "I'll start by..."},
|
||||
],
|
||||
"recent_files": ["auth.py", "models.py", "test_auth.py"],
|
||||
"decisions": {
|
||||
"architecture": "Use repository pattern",
|
||||
"testing": "TDD approach",
|
||||
},
|
||||
"blockers": [],
|
||||
"context_tokens": 2048,
|
||||
}
|
||||
|
||||
instance = AgentInstance(
|
||||
id=uuid.uuid4(),
|
||||
agent_type_id=agent_type.id,
|
||||
project_id=project.id,
|
||||
short_term_memory=memory,
|
||||
)
|
||||
db_session.add(instance)
|
||||
db_session.commit()
|
||||
|
||||
retrieved = db_session.query(AgentInstance).filter_by(id=instance.id).first()
|
||||
assert retrieved.short_term_memory == memory
|
||||
assert len(retrieved.short_term_memory["conversation_history"]) == 2
|
||||
assert "auth.py" in retrieved.short_term_memory["recent_files"]
|
||||
|
||||
def test_update_memory(self, db_session):
|
||||
"""Test updating short-term memory."""
|
||||
project = Project(id=uuid.uuid4(), name="Update Memory Project", slug="update-memory-project-ai")
|
||||
agent_type = AgentType(
|
||||
id=uuid.uuid4(),
|
||||
name="Update Memory Agent",
|
||||
slug="update-memory-agent-ai",
|
||||
personality_prompt="Test",
|
||||
primary_model="claude-opus-4-5-20251101",
|
||||
)
|
||||
db_session.add(project)
|
||||
db_session.add(agent_type)
|
||||
db_session.commit()
|
||||
|
||||
instance = AgentInstance(
|
||||
id=uuid.uuid4(),
|
||||
agent_type_id=agent_type.id,
|
||||
project_id=project.id,
|
||||
short_term_memory={"initial": "state"},
|
||||
)
|
||||
db_session.add(instance)
|
||||
db_session.commit()
|
||||
|
||||
# Update memory
|
||||
instance.short_term_memory = {"updated": "state", "new_key": "new_value"}
|
||||
db_session.commit()
|
||||
|
||||
retrieved = db_session.query(AgentInstance).filter_by(id=instance.id).first()
|
||||
assert "initial" not in retrieved.short_term_memory
|
||||
assert retrieved.short_term_memory["updated"] == "state"
|
||||
assert retrieved.short_term_memory["new_key"] == "new_value"
|
||||
315
backend/tests/models/syndarix/test_agent_type.py
Normal file
315
backend/tests/models/syndarix/test_agent_type.py
Normal file
@@ -0,0 +1,315 @@
|
||||
# tests/models/syndarix/test_agent_type.py
|
||||
"""
|
||||
Unit tests for the AgentType model.
|
||||
"""
|
||||
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
|
||||
from app.models.syndarix import AgentType
|
||||
|
||||
|
||||
class TestAgentTypeModel:
|
||||
"""Tests for AgentType model creation and fields."""
|
||||
|
||||
def test_create_agent_type_with_required_fields(self, db_session):
|
||||
"""Test creating an agent type with only required fields."""
|
||||
agent_type = AgentType(
|
||||
id=uuid.uuid4(),
|
||||
name="Test Agent",
|
||||
slug="test-agent",
|
||||
personality_prompt="You are a helpful assistant.",
|
||||
primary_model="claude-opus-4-5-20251101",
|
||||
)
|
||||
db_session.add(agent_type)
|
||||
db_session.commit()
|
||||
|
||||
retrieved = db_session.query(AgentType).filter_by(slug="test-agent").first()
|
||||
|
||||
assert retrieved is not None
|
||||
assert retrieved.name == "Test Agent"
|
||||
assert retrieved.slug == "test-agent"
|
||||
assert retrieved.personality_prompt == "You are a helpful assistant."
|
||||
assert retrieved.primary_model == "claude-opus-4-5-20251101"
|
||||
assert retrieved.is_active is True # Default
|
||||
assert retrieved.expertise == [] # Default empty list
|
||||
assert retrieved.fallback_models == [] # Default empty list
|
||||
assert retrieved.model_params == {} # Default empty dict
|
||||
assert retrieved.mcp_servers == [] # Default empty list
|
||||
assert retrieved.tool_permissions == {} # Default empty dict
|
||||
|
||||
def test_create_agent_type_with_all_fields(self, db_session):
|
||||
"""Test creating an agent type with all optional fields."""
|
||||
agent_type_id = uuid.uuid4()
|
||||
|
||||
agent_type = AgentType(
|
||||
id=agent_type_id,
|
||||
name="Full Agent Type",
|
||||
slug="full-agent-type",
|
||||
description="A fully configured agent type",
|
||||
expertise=["python", "fastapi", "testing"],
|
||||
personality_prompt="You are an expert Python developer...",
|
||||
primary_model="claude-opus-4-5-20251101",
|
||||
fallback_models=["claude-sonnet-4-20250514", "gpt-4o"],
|
||||
model_params={"temperature": 0.7, "max_tokens": 4096},
|
||||
mcp_servers=["gitea", "file-system", "slack"],
|
||||
tool_permissions={"allowed": ["*"], "denied": ["dangerous_tool"]},
|
||||
is_active=True,
|
||||
)
|
||||
db_session.add(agent_type)
|
||||
db_session.commit()
|
||||
|
||||
retrieved = db_session.query(AgentType).filter_by(id=agent_type_id).first()
|
||||
|
||||
assert retrieved.name == "Full Agent Type"
|
||||
assert retrieved.description == "A fully configured agent type"
|
||||
assert retrieved.expertise == ["python", "fastapi", "testing"]
|
||||
assert retrieved.fallback_models == ["claude-sonnet-4-20250514", "gpt-4o"]
|
||||
assert retrieved.model_params == {"temperature": 0.7, "max_tokens": 4096}
|
||||
assert retrieved.mcp_servers == ["gitea", "file-system", "slack"]
|
||||
assert retrieved.tool_permissions == {"allowed": ["*"], "denied": ["dangerous_tool"]}
|
||||
assert retrieved.is_active is True
|
||||
|
||||
def test_agent_type_unique_slug_constraint(self, db_session):
|
||||
"""Test that agent types cannot have duplicate slugs."""
|
||||
agent_type1 = AgentType(
|
||||
id=uuid.uuid4(),
|
||||
name="Agent One",
|
||||
slug="duplicate-agent-slug",
|
||||
personality_prompt="First agent",
|
||||
primary_model="claude-opus-4-5-20251101",
|
||||
)
|
||||
db_session.add(agent_type1)
|
||||
db_session.commit()
|
||||
|
||||
agent_type2 = AgentType(
|
||||
id=uuid.uuid4(),
|
||||
name="Agent Two",
|
||||
slug="duplicate-agent-slug", # Same slug
|
||||
personality_prompt="Second agent",
|
||||
primary_model="claude-opus-4-5-20251101",
|
||||
)
|
||||
db_session.add(agent_type2)
|
||||
|
||||
with pytest.raises(IntegrityError):
|
||||
db_session.commit()
|
||||
|
||||
db_session.rollback()
|
||||
|
||||
def test_agent_type_timestamps(self, db_session):
|
||||
"""Test that timestamps are automatically set."""
|
||||
agent_type = AgentType(
|
||||
id=uuid.uuid4(),
|
||||
name="Timestamp Agent",
|
||||
slug="timestamp-agent",
|
||||
personality_prompt="Test",
|
||||
primary_model="claude-opus-4-5-20251101",
|
||||
)
|
||||
db_session.add(agent_type)
|
||||
db_session.commit()
|
||||
|
||||
retrieved = db_session.query(AgentType).filter_by(slug="timestamp-agent").first()
|
||||
|
||||
assert isinstance(retrieved.created_at, datetime)
|
||||
assert isinstance(retrieved.updated_at, datetime)
|
||||
|
||||
def test_agent_type_update(self, db_session):
|
||||
"""Test updating agent type fields."""
|
||||
agent_type = AgentType(
|
||||
id=uuid.uuid4(),
|
||||
name="Original Agent",
|
||||
slug="original-agent",
|
||||
personality_prompt="Original prompt",
|
||||
primary_model="claude-opus-4-5-20251101",
|
||||
is_active=True,
|
||||
)
|
||||
db_session.add(agent_type)
|
||||
db_session.commit()
|
||||
|
||||
original_created_at = agent_type.created_at
|
||||
|
||||
# Update fields
|
||||
agent_type.name = "Updated Agent"
|
||||
agent_type.is_active = False
|
||||
agent_type.expertise = ["new", "skills"]
|
||||
db_session.commit()
|
||||
|
||||
retrieved = db_session.query(AgentType).filter_by(slug="original-agent").first()
|
||||
|
||||
assert retrieved.name == "Updated Agent"
|
||||
assert retrieved.is_active is False
|
||||
assert retrieved.expertise == ["new", "skills"]
|
||||
assert retrieved.created_at == original_created_at
|
||||
assert retrieved.updated_at > original_created_at
|
||||
|
||||
def test_agent_type_delete(self, db_session):
|
||||
"""Test deleting an agent type."""
|
||||
agent_type_id = uuid.uuid4()
|
||||
agent_type = AgentType(
|
||||
id=agent_type_id,
|
||||
name="Delete Me",
|
||||
slug="delete-me-agent",
|
||||
personality_prompt="Delete test",
|
||||
primary_model="claude-opus-4-5-20251101",
|
||||
)
|
||||
db_session.add(agent_type)
|
||||
db_session.commit()
|
||||
|
||||
db_session.delete(agent_type)
|
||||
db_session.commit()
|
||||
|
||||
deleted = db_session.query(AgentType).filter_by(id=agent_type_id).first()
|
||||
assert deleted is None
|
||||
|
||||
def test_agent_type_string_representation(self, db_session):
|
||||
"""Test the string representation of an agent type."""
|
||||
agent_type = AgentType(
|
||||
id=uuid.uuid4(),
|
||||
name="Repr Agent",
|
||||
slug="repr-agent",
|
||||
personality_prompt="Test",
|
||||
primary_model="claude-opus-4-5-20251101",
|
||||
is_active=True,
|
||||
)
|
||||
|
||||
assert str(agent_type) == "<AgentType Repr Agent (repr-agent) active=True>"
|
||||
assert repr(agent_type) == "<AgentType Repr Agent (repr-agent) active=True>"
|
||||
|
||||
|
||||
class TestAgentTypeJsonFields:
|
||||
"""Tests for AgentType JSON fields."""
|
||||
|
||||
def test_complex_expertise_list(self, db_session):
|
||||
"""Test storing a list of expertise areas."""
|
||||
expertise = ["python", "fastapi", "sqlalchemy", "postgresql", "redis", "docker"]
|
||||
|
||||
agent_type = AgentType(
|
||||
id=uuid.uuid4(),
|
||||
name="Expert Agent",
|
||||
slug="expert-agent",
|
||||
personality_prompt="Prompt",
|
||||
primary_model="claude-opus-4-5-20251101",
|
||||
expertise=expertise,
|
||||
)
|
||||
db_session.add(agent_type)
|
||||
db_session.commit()
|
||||
|
||||
retrieved = db_session.query(AgentType).filter_by(slug="expert-agent").first()
|
||||
assert retrieved.expertise == expertise
|
||||
assert "python" in retrieved.expertise
|
||||
assert len(retrieved.expertise) == 6
|
||||
|
||||
def test_complex_model_params(self, db_session):
|
||||
"""Test storing complex model parameters."""
|
||||
model_params = {
|
||||
"temperature": 0.7,
|
||||
"max_tokens": 4096,
|
||||
"top_p": 0.9,
|
||||
"frequency_penalty": 0.1,
|
||||
"presence_penalty": 0.1,
|
||||
"stop_sequences": ["###", "END"],
|
||||
}
|
||||
|
||||
agent_type = AgentType(
|
||||
id=uuid.uuid4(),
|
||||
name="Params Agent",
|
||||
slug="params-agent",
|
||||
personality_prompt="Prompt",
|
||||
primary_model="claude-opus-4-5-20251101",
|
||||
model_params=model_params,
|
||||
)
|
||||
db_session.add(agent_type)
|
||||
db_session.commit()
|
||||
|
||||
retrieved = db_session.query(AgentType).filter_by(slug="params-agent").first()
|
||||
assert retrieved.model_params == model_params
|
||||
assert retrieved.model_params["temperature"] == 0.7
|
||||
assert retrieved.model_params["stop_sequences"] == ["###", "END"]
|
||||
|
||||
def test_complex_tool_permissions(self, db_session):
|
||||
"""Test storing complex tool permissions."""
|
||||
tool_permissions = {
|
||||
"allowed": ["file:read", "file:write", "git:commit"],
|
||||
"denied": ["file:delete", "system:exec"],
|
||||
"require_approval": ["git:push", "gitea:create_pr"],
|
||||
"limits": {
|
||||
"file:write": {"max_size_mb": 10},
|
||||
"git:commit": {"require_message": True},
|
||||
},
|
||||
}
|
||||
|
||||
agent_type = AgentType(
|
||||
id=uuid.uuid4(),
|
||||
name="Permissions Agent",
|
||||
slug="permissions-agent",
|
||||
personality_prompt="Prompt",
|
||||
primary_model="claude-opus-4-5-20251101",
|
||||
tool_permissions=tool_permissions,
|
||||
)
|
||||
db_session.add(agent_type)
|
||||
db_session.commit()
|
||||
|
||||
retrieved = db_session.query(AgentType).filter_by(slug="permissions-agent").first()
|
||||
assert retrieved.tool_permissions == tool_permissions
|
||||
assert "file:read" in retrieved.tool_permissions["allowed"]
|
||||
assert retrieved.tool_permissions["limits"]["file:write"]["max_size_mb"] == 10
|
||||
|
||||
def test_empty_json_fields_default(self, db_session):
|
||||
"""Test that JSON fields default to empty structures."""
|
||||
agent_type = AgentType(
|
||||
id=uuid.uuid4(),
|
||||
name="Empty JSON Agent",
|
||||
slug="empty-json-agent",
|
||||
personality_prompt="Prompt",
|
||||
primary_model="claude-opus-4-5-20251101",
|
||||
)
|
||||
db_session.add(agent_type)
|
||||
db_session.commit()
|
||||
|
||||
retrieved = db_session.query(AgentType).filter_by(slug="empty-json-agent").first()
|
||||
assert retrieved.expertise == []
|
||||
assert retrieved.fallback_models == []
|
||||
assert retrieved.model_params == {}
|
||||
assert retrieved.mcp_servers == []
|
||||
assert retrieved.tool_permissions == {}
|
||||
|
||||
|
||||
class TestAgentTypeIsActive:
|
||||
"""Tests for AgentType is_active field."""
|
||||
|
||||
def test_default_is_active(self, db_session):
|
||||
"""Test that is_active defaults to True."""
|
||||
agent_type = AgentType(
|
||||
id=uuid.uuid4(),
|
||||
name="Default Active",
|
||||
slug="default-active",
|
||||
personality_prompt="Prompt",
|
||||
primary_model="claude-opus-4-5-20251101",
|
||||
)
|
||||
db_session.add(agent_type)
|
||||
db_session.commit()
|
||||
|
||||
retrieved = db_session.query(AgentType).filter_by(slug="default-active").first()
|
||||
assert retrieved.is_active is True
|
||||
|
||||
def test_deactivate_agent_type(self, db_session):
|
||||
"""Test deactivating an agent type."""
|
||||
agent_type = AgentType(
|
||||
id=uuid.uuid4(),
|
||||
name="Deactivate Me",
|
||||
slug="deactivate-me",
|
||||
personality_prompt="Prompt",
|
||||
primary_model="claude-opus-4-5-20251101",
|
||||
is_active=True,
|
||||
)
|
||||
db_session.add(agent_type)
|
||||
db_session.commit()
|
||||
|
||||
agent_type.is_active = False
|
||||
db_session.commit()
|
||||
|
||||
retrieved = db_session.query(AgentType).filter_by(slug="deactivate-me").first()
|
||||
assert retrieved.is_active is False
|
||||
463
backend/tests/models/syndarix/test_issue.py
Normal file
463
backend/tests/models/syndarix/test_issue.py
Normal file
@@ -0,0 +1,463 @@
|
||||
# tests/models/syndarix/test_issue.py
|
||||
"""
|
||||
Unit tests for the Issue model.
|
||||
"""
|
||||
|
||||
import uuid
|
||||
from datetime import UTC, datetime, timedelta
|
||||
|
||||
import pytest
|
||||
|
||||
from app.models.syndarix import (
|
||||
AgentInstance,
|
||||
AgentType,
|
||||
Issue,
|
||||
IssuePriority,
|
||||
IssueStatus,
|
||||
Project,
|
||||
Sprint,
|
||||
SprintStatus,
|
||||
SyncStatus,
|
||||
)
|
||||
|
||||
|
||||
class TestIssueModel:
|
||||
"""Tests for Issue model creation and fields."""
|
||||
|
||||
def test_create_issue_with_required_fields(self, db_session):
|
||||
"""Test creating an issue with only required fields."""
|
||||
project = Project(
|
||||
id=uuid.uuid4(),
|
||||
name="Issue Project",
|
||||
slug="issue-project",
|
||||
)
|
||||
db_session.add(project)
|
||||
db_session.commit()
|
||||
|
||||
issue = Issue(
|
||||
id=uuid.uuid4(),
|
||||
project_id=project.id,
|
||||
title="Test Issue",
|
||||
)
|
||||
db_session.add(issue)
|
||||
db_session.commit()
|
||||
|
||||
retrieved = db_session.query(Issue).filter_by(title="Test Issue").first()
|
||||
|
||||
assert retrieved is not None
|
||||
assert retrieved.title == "Test Issue"
|
||||
assert retrieved.body == "" # Default empty string
|
||||
assert retrieved.status == IssueStatus.OPEN # Default
|
||||
assert retrieved.priority == IssuePriority.MEDIUM # Default
|
||||
assert retrieved.labels == [] # Default empty list
|
||||
assert retrieved.story_points is None
|
||||
assert retrieved.assigned_agent_id is None
|
||||
assert retrieved.human_assignee is None
|
||||
assert retrieved.sprint_id is None
|
||||
assert retrieved.sync_status == SyncStatus.SYNCED # Default
|
||||
|
||||
def test_create_issue_with_all_fields(self, db_session):
|
||||
"""Test creating an issue with all optional fields."""
|
||||
project = Project(
|
||||
id=uuid.uuid4(),
|
||||
name="Full Issue Project",
|
||||
slug="full-issue-project",
|
||||
)
|
||||
db_session.add(project)
|
||||
db_session.commit()
|
||||
|
||||
issue_id = uuid.uuid4()
|
||||
now = datetime.now(UTC)
|
||||
|
||||
issue = Issue(
|
||||
id=issue_id,
|
||||
project_id=project.id,
|
||||
title="Full Issue",
|
||||
body="A complete issue with all fields set",
|
||||
status=IssueStatus.IN_PROGRESS,
|
||||
priority=IssuePriority.CRITICAL,
|
||||
labels=["bug", "security", "urgent"],
|
||||
story_points=8,
|
||||
human_assignee="john.doe@example.com",
|
||||
external_tracker="gitea",
|
||||
external_id="gitea-123",
|
||||
external_url="https://gitea.example.com/issues/123",
|
||||
external_number=123,
|
||||
sync_status=SyncStatus.SYNCED,
|
||||
last_synced_at=now,
|
||||
external_updated_at=now,
|
||||
)
|
||||
db_session.add(issue)
|
||||
db_session.commit()
|
||||
|
||||
retrieved = db_session.query(Issue).filter_by(id=issue_id).first()
|
||||
|
||||
assert retrieved.title == "Full Issue"
|
||||
assert retrieved.body == "A complete issue with all fields set"
|
||||
assert retrieved.status == IssueStatus.IN_PROGRESS
|
||||
assert retrieved.priority == IssuePriority.CRITICAL
|
||||
assert retrieved.labels == ["bug", "security", "urgent"]
|
||||
assert retrieved.story_points == 8
|
||||
assert retrieved.human_assignee == "john.doe@example.com"
|
||||
assert retrieved.external_tracker == "gitea"
|
||||
assert retrieved.external_id == "gitea-123"
|
||||
assert retrieved.external_number == 123
|
||||
assert retrieved.sync_status == SyncStatus.SYNCED
|
||||
|
||||
def test_issue_timestamps(self, db_session):
|
||||
"""Test that timestamps are automatically set."""
|
||||
project = Project(id=uuid.uuid4(), name="Timestamp Issue Project", slug="timestamp-issue-project")
|
||||
db_session.add(project)
|
||||
db_session.commit()
|
||||
|
||||
issue = Issue(
|
||||
id=uuid.uuid4(),
|
||||
project_id=project.id,
|
||||
title="Timestamp Issue",
|
||||
)
|
||||
db_session.add(issue)
|
||||
db_session.commit()
|
||||
|
||||
assert isinstance(issue.created_at, datetime)
|
||||
assert isinstance(issue.updated_at, datetime)
|
||||
|
||||
def test_issue_string_representation(self, db_session):
|
||||
"""Test the string representation of an issue."""
|
||||
project = Project(id=uuid.uuid4(), name="Repr Issue Project", slug="repr-issue-project")
|
||||
db_session.add(project)
|
||||
db_session.commit()
|
||||
|
||||
issue = Issue(
|
||||
id=uuid.uuid4(),
|
||||
project_id=project.id,
|
||||
title="This is a very long issue title that should be truncated in repr",
|
||||
status=IssueStatus.OPEN,
|
||||
priority=IssuePriority.HIGH,
|
||||
)
|
||||
|
||||
repr_str = repr(issue)
|
||||
assert "This is a very long issue tit" in repr_str # First 30 chars
|
||||
assert "open" in repr_str
|
||||
assert "high" in repr_str
|
||||
|
||||
|
||||
class TestIssueStatus:
|
||||
"""Tests for Issue status field."""
|
||||
|
||||
def test_all_issue_statuses(self, db_session):
|
||||
"""Test that all issue statuses can be stored."""
|
||||
project = Project(id=uuid.uuid4(), name="Status Issue Project", slug="status-issue-project")
|
||||
db_session.add(project)
|
||||
db_session.commit()
|
||||
|
||||
for status in IssueStatus:
|
||||
issue = Issue(
|
||||
id=uuid.uuid4(),
|
||||
project_id=project.id,
|
||||
title=f"Issue {status.value}",
|
||||
status=status,
|
||||
)
|
||||
db_session.add(issue)
|
||||
db_session.commit()
|
||||
|
||||
retrieved = db_session.query(Issue).filter_by(id=issue.id).first()
|
||||
assert retrieved.status == status
|
||||
|
||||
|
||||
class TestIssuePriority:
|
||||
"""Tests for Issue priority field."""
|
||||
|
||||
def test_all_issue_priorities(self, db_session):
|
||||
"""Test that all issue priorities can be stored."""
|
||||
project = Project(id=uuid.uuid4(), name="Priority Issue Project", slug="priority-issue-project")
|
||||
db_session.add(project)
|
||||
db_session.commit()
|
||||
|
||||
for priority in IssuePriority:
|
||||
issue = Issue(
|
||||
id=uuid.uuid4(),
|
||||
project_id=project.id,
|
||||
title=f"Issue {priority.value}",
|
||||
priority=priority,
|
||||
)
|
||||
db_session.add(issue)
|
||||
db_session.commit()
|
||||
|
||||
retrieved = db_session.query(Issue).filter_by(id=issue.id).first()
|
||||
assert retrieved.priority == priority
|
||||
|
||||
|
||||
class TestIssueSyncStatus:
|
||||
"""Tests for Issue sync status field."""
|
||||
|
||||
def test_all_sync_statuses(self, db_session):
|
||||
"""Test that all sync statuses can be stored."""
|
||||
project = Project(id=uuid.uuid4(), name="Sync Issue Project", slug="sync-issue-project")
|
||||
db_session.add(project)
|
||||
db_session.commit()
|
||||
|
||||
for sync_status in SyncStatus:
|
||||
issue = Issue(
|
||||
id=uuid.uuid4(),
|
||||
project_id=project.id,
|
||||
title=f"Issue {sync_status.value}",
|
||||
external_tracker="gitea",
|
||||
external_id=f"ext-{sync_status.value}",
|
||||
sync_status=sync_status,
|
||||
)
|
||||
db_session.add(issue)
|
||||
db_session.commit()
|
||||
|
||||
retrieved = db_session.query(Issue).filter_by(id=issue.id).first()
|
||||
assert retrieved.sync_status == sync_status
|
||||
|
||||
|
||||
class TestIssueLabels:
|
||||
"""Tests for Issue labels JSON field."""
|
||||
|
||||
def test_store_labels(self, db_session):
|
||||
"""Test storing labels list."""
|
||||
project = Project(id=uuid.uuid4(), name="Labels Issue Project", slug="labels-issue-project")
|
||||
db_session.add(project)
|
||||
db_session.commit()
|
||||
|
||||
labels = ["bug", "security", "high-priority", "needs-review"]
|
||||
|
||||
issue = Issue(
|
||||
id=uuid.uuid4(),
|
||||
project_id=project.id,
|
||||
title="Issue with Labels",
|
||||
labels=labels,
|
||||
)
|
||||
db_session.add(issue)
|
||||
db_session.commit()
|
||||
|
||||
retrieved = db_session.query(Issue).filter_by(title="Issue with Labels").first()
|
||||
assert retrieved.labels == labels
|
||||
assert "security" in retrieved.labels
|
||||
|
||||
def test_update_labels(self, db_session):
|
||||
"""Test updating labels."""
|
||||
project = Project(id=uuid.uuid4(), name="Update Labels Project", slug="update-labels-project")
|
||||
db_session.add(project)
|
||||
db_session.commit()
|
||||
|
||||
issue = Issue(
|
||||
id=uuid.uuid4(),
|
||||
project_id=project.id,
|
||||
title="Update Labels Issue",
|
||||
labels=["initial"],
|
||||
)
|
||||
db_session.add(issue)
|
||||
db_session.commit()
|
||||
|
||||
issue.labels = ["updated", "new-label"]
|
||||
db_session.commit()
|
||||
|
||||
retrieved = db_session.query(Issue).filter_by(title="Update Labels Issue").first()
|
||||
assert "initial" not in retrieved.labels
|
||||
assert "updated" in retrieved.labels
|
||||
|
||||
|
||||
class TestIssueAssignment:
|
||||
"""Tests for Issue assignment fields."""
|
||||
|
||||
def test_assign_to_agent(self, db_session):
|
||||
"""Test assigning an issue to an agent."""
|
||||
project = Project(id=uuid.uuid4(), name="Agent Assign Project", slug="agent-assign-project")
|
||||
agent_type = AgentType(
|
||||
id=uuid.uuid4(),
|
||||
name="Test Agent Type",
|
||||
slug="test-agent-type-assign",
|
||||
personality_prompt="Test",
|
||||
primary_model="claude-opus-4-5-20251101",
|
||||
)
|
||||
db_session.add(project)
|
||||
db_session.add(agent_type)
|
||||
db_session.commit()
|
||||
|
||||
agent_instance = AgentInstance(
|
||||
id=uuid.uuid4(),
|
||||
agent_type_id=agent_type.id,
|
||||
project_id=project.id,
|
||||
)
|
||||
db_session.add(agent_instance)
|
||||
db_session.commit()
|
||||
|
||||
issue = Issue(
|
||||
id=uuid.uuid4(),
|
||||
project_id=project.id,
|
||||
title="Agent Assignment Issue",
|
||||
assigned_agent_id=agent_instance.id,
|
||||
)
|
||||
db_session.add(issue)
|
||||
db_session.commit()
|
||||
|
||||
retrieved = db_session.query(Issue).filter_by(title="Agent Assignment Issue").first()
|
||||
assert retrieved.assigned_agent_id == agent_instance.id
|
||||
assert retrieved.human_assignee is None
|
||||
|
||||
def test_assign_to_human(self, db_session):
|
||||
"""Test assigning an issue to a human."""
|
||||
project = Project(id=uuid.uuid4(), name="Human Assign Project", slug="human-assign-project")
|
||||
db_session.add(project)
|
||||
db_session.commit()
|
||||
|
||||
issue = Issue(
|
||||
id=uuid.uuid4(),
|
||||
project_id=project.id,
|
||||
title="Human Assignment Issue",
|
||||
human_assignee="developer@example.com",
|
||||
)
|
||||
db_session.add(issue)
|
||||
db_session.commit()
|
||||
|
||||
retrieved = db_session.query(Issue).filter_by(title="Human Assignment Issue").first()
|
||||
assert retrieved.human_assignee == "developer@example.com"
|
||||
assert retrieved.assigned_agent_id is None
|
||||
|
||||
|
||||
class TestIssueSprintAssociation:
|
||||
"""Tests for Issue sprint association."""
|
||||
|
||||
def test_assign_issue_to_sprint(self, db_session):
|
||||
"""Test assigning an issue to a sprint."""
|
||||
project = Project(id=uuid.uuid4(), name="Sprint Assign Project", slug="sprint-assign-project")
|
||||
db_session.add(project)
|
||||
db_session.commit()
|
||||
|
||||
from datetime import date
|
||||
|
||||
sprint = Sprint(
|
||||
id=uuid.uuid4(),
|
||||
project_id=project.id,
|
||||
name="Sprint 1",
|
||||
number=1,
|
||||
start_date=date.today(),
|
||||
end_date=date.today() + timedelta(days=14),
|
||||
status=SprintStatus.ACTIVE,
|
||||
)
|
||||
db_session.add(sprint)
|
||||
db_session.commit()
|
||||
|
||||
issue = Issue(
|
||||
id=uuid.uuid4(),
|
||||
project_id=project.id,
|
||||
title="Sprint Issue",
|
||||
sprint_id=sprint.id,
|
||||
)
|
||||
db_session.add(issue)
|
||||
db_session.commit()
|
||||
|
||||
retrieved = db_session.query(Issue).filter_by(title="Sprint Issue").first()
|
||||
assert retrieved.sprint_id == sprint.id
|
||||
|
||||
|
||||
class TestIssueExternalTracker:
|
||||
"""Tests for Issue external tracker integration."""
|
||||
|
||||
def test_gitea_integration(self, db_session):
|
||||
"""Test Gitea external tracker fields."""
|
||||
project = Project(id=uuid.uuid4(), name="Gitea Project", slug="gitea-project")
|
||||
db_session.add(project)
|
||||
db_session.commit()
|
||||
|
||||
now = datetime.now(UTC)
|
||||
|
||||
issue = Issue(
|
||||
id=uuid.uuid4(),
|
||||
project_id=project.id,
|
||||
title="Gitea Synced Issue",
|
||||
external_tracker="gitea",
|
||||
external_id="abc123xyz",
|
||||
external_url="https://gitea.example.com/org/repo/issues/42",
|
||||
external_number=42,
|
||||
sync_status=SyncStatus.SYNCED,
|
||||
last_synced_at=now,
|
||||
external_updated_at=now,
|
||||
)
|
||||
db_session.add(issue)
|
||||
db_session.commit()
|
||||
|
||||
retrieved = db_session.query(Issue).filter_by(title="Gitea Synced Issue").first()
|
||||
assert retrieved.external_tracker == "gitea"
|
||||
assert retrieved.external_id == "abc123xyz"
|
||||
assert retrieved.external_number == 42
|
||||
assert "/issues/42" in retrieved.external_url
|
||||
|
||||
def test_github_integration(self, db_session):
|
||||
"""Test GitHub external tracker fields."""
|
||||
project = Project(id=uuid.uuid4(), name="GitHub Project", slug="github-project")
|
||||
db_session.add(project)
|
||||
db_session.commit()
|
||||
|
||||
issue = Issue(
|
||||
id=uuid.uuid4(),
|
||||
project_id=project.id,
|
||||
title="GitHub Synced Issue",
|
||||
external_tracker="github",
|
||||
external_id="gh-12345",
|
||||
external_url="https://github.com/org/repo/issues/100",
|
||||
external_number=100,
|
||||
)
|
||||
db_session.add(issue)
|
||||
db_session.commit()
|
||||
|
||||
retrieved = db_session.query(Issue).filter_by(title="GitHub Synced Issue").first()
|
||||
assert retrieved.external_tracker == "github"
|
||||
assert retrieved.external_number == 100
|
||||
|
||||
|
||||
class TestIssueLifecycle:
|
||||
"""Tests for Issue lifecycle operations."""
|
||||
|
||||
def test_close_issue(self, db_session):
|
||||
"""Test closing an issue."""
|
||||
project = Project(id=uuid.uuid4(), name="Close Issue Project", slug="close-issue-project")
|
||||
db_session.add(project)
|
||||
db_session.commit()
|
||||
|
||||
issue = Issue(
|
||||
id=uuid.uuid4(),
|
||||
project_id=project.id,
|
||||
title="Issue to Close",
|
||||
status=IssueStatus.OPEN,
|
||||
)
|
||||
db_session.add(issue)
|
||||
db_session.commit()
|
||||
|
||||
# Close the issue
|
||||
now = datetime.now(UTC)
|
||||
issue.status = IssueStatus.CLOSED
|
||||
issue.closed_at = now
|
||||
db_session.commit()
|
||||
|
||||
retrieved = db_session.query(Issue).filter_by(title="Issue to Close").first()
|
||||
assert retrieved.status == IssueStatus.CLOSED
|
||||
assert retrieved.closed_at is not None
|
||||
|
||||
def test_reopen_issue(self, db_session):
|
||||
"""Test reopening a closed issue."""
|
||||
project = Project(id=uuid.uuid4(), name="Reopen Issue Project", slug="reopen-issue-project")
|
||||
db_session.add(project)
|
||||
db_session.commit()
|
||||
|
||||
now = datetime.now(UTC)
|
||||
issue = Issue(
|
||||
id=uuid.uuid4(),
|
||||
project_id=project.id,
|
||||
title="Issue to Reopen",
|
||||
status=IssueStatus.CLOSED,
|
||||
closed_at=now,
|
||||
)
|
||||
db_session.add(issue)
|
||||
db_session.commit()
|
||||
|
||||
# Reopen the issue
|
||||
issue.status = IssueStatus.OPEN
|
||||
issue.closed_at = None
|
||||
db_session.commit()
|
||||
|
||||
retrieved = db_session.query(Issue).filter_by(title="Issue to Reopen").first()
|
||||
assert retrieved.status == IssueStatus.OPEN
|
||||
assert retrieved.closed_at is None
|
||||
262
backend/tests/models/syndarix/test_project.py
Normal file
262
backend/tests/models/syndarix/test_project.py
Normal file
@@ -0,0 +1,262 @@
|
||||
# tests/models/syndarix/test_project.py
|
||||
"""
|
||||
Unit tests for the Project model.
|
||||
"""
|
||||
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
|
||||
from app.models.syndarix import (
|
||||
AutonomyLevel,
|
||||
Project,
|
||||
ProjectStatus,
|
||||
)
|
||||
|
||||
|
||||
class TestProjectModel:
|
||||
"""Tests for Project model creation and fields."""
|
||||
|
||||
def test_create_project_with_required_fields(self, db_session):
|
||||
"""Test creating a project with only required fields."""
|
||||
project = Project(
|
||||
id=uuid.uuid4(),
|
||||
name="Test Project",
|
||||
slug="test-project",
|
||||
)
|
||||
db_session.add(project)
|
||||
db_session.commit()
|
||||
|
||||
retrieved = db_session.query(Project).filter_by(slug="test-project").first()
|
||||
|
||||
assert retrieved is not None
|
||||
assert retrieved.name == "Test Project"
|
||||
assert retrieved.slug == "test-project"
|
||||
assert retrieved.autonomy_level == AutonomyLevel.MILESTONE # Default
|
||||
assert retrieved.status == ProjectStatus.ACTIVE # Default
|
||||
assert retrieved.settings == {} # Default empty dict
|
||||
assert retrieved.description is None
|
||||
assert retrieved.owner_id is None
|
||||
|
||||
def test_create_project_with_all_fields(self, db_session):
|
||||
"""Test creating a project with all optional fields."""
|
||||
project_id = uuid.uuid4()
|
||||
owner_id = uuid.uuid4()
|
||||
|
||||
project = Project(
|
||||
id=project_id,
|
||||
name="Full Project",
|
||||
slug="full-project",
|
||||
description="A complete project with all fields",
|
||||
autonomy_level=AutonomyLevel.AUTONOMOUS,
|
||||
status=ProjectStatus.PAUSED,
|
||||
settings={"webhook_url": "https://example.com/webhook"},
|
||||
owner_id=owner_id,
|
||||
)
|
||||
db_session.add(project)
|
||||
db_session.commit()
|
||||
|
||||
retrieved = db_session.query(Project).filter_by(id=project_id).first()
|
||||
|
||||
assert retrieved.name == "Full Project"
|
||||
assert retrieved.slug == "full-project"
|
||||
assert retrieved.description == "A complete project with all fields"
|
||||
assert retrieved.autonomy_level == AutonomyLevel.AUTONOMOUS
|
||||
assert retrieved.status == ProjectStatus.PAUSED
|
||||
assert retrieved.settings == {"webhook_url": "https://example.com/webhook"}
|
||||
assert retrieved.owner_id == owner_id
|
||||
|
||||
def test_project_unique_slug_constraint(self, db_session):
|
||||
"""Test that projects cannot have duplicate slugs."""
|
||||
project1 = Project(
|
||||
id=uuid.uuid4(),
|
||||
name="Project One",
|
||||
slug="duplicate-slug",
|
||||
)
|
||||
db_session.add(project1)
|
||||
db_session.commit()
|
||||
|
||||
project2 = Project(
|
||||
id=uuid.uuid4(),
|
||||
name="Project Two",
|
||||
slug="duplicate-slug", # Same slug
|
||||
)
|
||||
db_session.add(project2)
|
||||
|
||||
with pytest.raises(IntegrityError):
|
||||
db_session.commit()
|
||||
|
||||
db_session.rollback()
|
||||
|
||||
def test_project_timestamps(self, db_session):
|
||||
"""Test that timestamps are automatically set."""
|
||||
project = Project(
|
||||
id=uuid.uuid4(),
|
||||
name="Timestamp Project",
|
||||
slug="timestamp-project",
|
||||
)
|
||||
db_session.add(project)
|
||||
db_session.commit()
|
||||
|
||||
retrieved = db_session.query(Project).filter_by(slug="timestamp-project").first()
|
||||
|
||||
assert isinstance(retrieved.created_at, datetime)
|
||||
assert isinstance(retrieved.updated_at, datetime)
|
||||
|
||||
def test_project_update(self, db_session):
|
||||
"""Test updating project fields."""
|
||||
project = Project(
|
||||
id=uuid.uuid4(),
|
||||
name="Original Name",
|
||||
slug="original-slug",
|
||||
status=ProjectStatus.ACTIVE,
|
||||
)
|
||||
db_session.add(project)
|
||||
db_session.commit()
|
||||
|
||||
original_created_at = project.created_at
|
||||
|
||||
# Update fields
|
||||
project.name = "Updated Name"
|
||||
project.status = ProjectStatus.COMPLETED
|
||||
project.settings = {"new_setting": "value"}
|
||||
db_session.commit()
|
||||
|
||||
retrieved = db_session.query(Project).filter_by(slug="original-slug").first()
|
||||
|
||||
assert retrieved.name == "Updated Name"
|
||||
assert retrieved.status == ProjectStatus.COMPLETED
|
||||
assert retrieved.settings == {"new_setting": "value"}
|
||||
assert retrieved.created_at == original_created_at
|
||||
assert retrieved.updated_at > original_created_at
|
||||
|
||||
def test_project_delete(self, db_session):
|
||||
"""Test deleting a project."""
|
||||
project_id = uuid.uuid4()
|
||||
project = Project(
|
||||
id=project_id,
|
||||
name="Delete Me",
|
||||
slug="delete-me",
|
||||
)
|
||||
db_session.add(project)
|
||||
db_session.commit()
|
||||
|
||||
db_session.delete(project)
|
||||
db_session.commit()
|
||||
|
||||
deleted = db_session.query(Project).filter_by(id=project_id).first()
|
||||
assert deleted is None
|
||||
|
||||
def test_project_string_representation(self, db_session):
|
||||
"""Test the string representation of a project."""
|
||||
project = Project(
|
||||
id=uuid.uuid4(),
|
||||
name="Repr Project",
|
||||
slug="repr-project",
|
||||
status=ProjectStatus.ACTIVE,
|
||||
)
|
||||
|
||||
assert str(project) == "<Project Repr Project (repr-project) status=active>"
|
||||
assert repr(project) == "<Project Repr Project (repr-project) status=active>"
|
||||
|
||||
|
||||
class TestProjectEnums:
|
||||
"""Tests for Project enum fields."""
|
||||
|
||||
def test_all_autonomy_levels(self, db_session):
|
||||
"""Test that all autonomy levels can be stored."""
|
||||
for level in AutonomyLevel:
|
||||
project = Project(
|
||||
id=uuid.uuid4(),
|
||||
name=f"Project {level.value}",
|
||||
slug=f"project-{level.value}",
|
||||
autonomy_level=level,
|
||||
)
|
||||
db_session.add(project)
|
||||
db_session.commit()
|
||||
|
||||
retrieved = db_session.query(Project).filter_by(slug=f"project-{level.value}").first()
|
||||
assert retrieved.autonomy_level == level
|
||||
|
||||
def test_all_project_statuses(self, db_session):
|
||||
"""Test that all project statuses can be stored."""
|
||||
for status in ProjectStatus:
|
||||
project = Project(
|
||||
id=uuid.uuid4(),
|
||||
name=f"Project {status.value}",
|
||||
slug=f"project-status-{status.value}",
|
||||
status=status,
|
||||
)
|
||||
db_session.add(project)
|
||||
db_session.commit()
|
||||
|
||||
retrieved = db_session.query(Project).filter_by(slug=f"project-status-{status.value}").first()
|
||||
assert retrieved.status == status
|
||||
|
||||
|
||||
class TestProjectSettings:
|
||||
"""Tests for Project JSON settings field."""
|
||||
|
||||
def test_complex_json_settings(self, db_session):
|
||||
"""Test storing complex JSON in settings."""
|
||||
complex_settings = {
|
||||
"mcp_servers": ["gitea", "slack", "file-system"],
|
||||
"webhook_urls": {
|
||||
"on_issue_created": "https://example.com/issue",
|
||||
"on_sprint_completed": "https://example.com/sprint",
|
||||
},
|
||||
"notification_settings": {
|
||||
"email": True,
|
||||
"slack_channel": "#syndarix-updates",
|
||||
},
|
||||
"tags": ["important", "client-a"],
|
||||
}
|
||||
|
||||
project = Project(
|
||||
id=uuid.uuid4(),
|
||||
name="Complex Settings Project",
|
||||
slug="complex-settings",
|
||||
settings=complex_settings,
|
||||
)
|
||||
db_session.add(project)
|
||||
db_session.commit()
|
||||
|
||||
retrieved = db_session.query(Project).filter_by(slug="complex-settings").first()
|
||||
|
||||
assert retrieved.settings == complex_settings
|
||||
assert retrieved.settings["mcp_servers"] == ["gitea", "slack", "file-system"]
|
||||
assert retrieved.settings["webhook_urls"]["on_issue_created"] == "https://example.com/issue"
|
||||
assert "important" in retrieved.settings["tags"]
|
||||
|
||||
def test_empty_settings(self, db_session):
|
||||
"""Test that empty settings defaults correctly."""
|
||||
project = Project(
|
||||
id=uuid.uuid4(),
|
||||
name="Empty Settings",
|
||||
slug="empty-settings",
|
||||
)
|
||||
db_session.add(project)
|
||||
db_session.commit()
|
||||
|
||||
retrieved = db_session.query(Project).filter_by(slug="empty-settings").first()
|
||||
assert retrieved.settings == {}
|
||||
|
||||
def test_update_settings(self, db_session):
|
||||
"""Test updating settings field."""
|
||||
project = Project(
|
||||
id=uuid.uuid4(),
|
||||
name="Update Settings",
|
||||
slug="update-settings",
|
||||
settings={"initial": "value"},
|
||||
)
|
||||
db_session.add(project)
|
||||
db_session.commit()
|
||||
|
||||
# Update settings
|
||||
project.settings = {"updated": "new_value", "additional": "data"}
|
||||
db_session.commit()
|
||||
|
||||
retrieved = db_session.query(Project).filter_by(slug="update-settings").first()
|
||||
assert retrieved.settings == {"updated": "new_value", "additional": "data"}
|
||||
507
backend/tests/models/syndarix/test_sprint.py
Normal file
507
backend/tests/models/syndarix/test_sprint.py
Normal file
@@ -0,0 +1,507 @@
|
||||
# tests/models/syndarix/test_sprint.py
|
||||
"""
|
||||
Unit tests for the Sprint model.
|
||||
"""
|
||||
|
||||
import uuid
|
||||
from datetime import date, datetime, timedelta
|
||||
|
||||
import pytest
|
||||
|
||||
from app.models.syndarix import (
|
||||
Project,
|
||||
Sprint,
|
||||
SprintStatus,
|
||||
)
|
||||
|
||||
|
||||
class TestSprintModel:
|
||||
"""Tests for Sprint model creation and fields."""
|
||||
|
||||
def test_create_sprint_with_required_fields(self, db_session):
|
||||
"""Test creating a sprint with only required fields."""
|
||||
project = Project(
|
||||
id=uuid.uuid4(),
|
||||
name="Sprint Project",
|
||||
slug="sprint-project",
|
||||
)
|
||||
db_session.add(project)
|
||||
db_session.commit()
|
||||
|
||||
today = date.today()
|
||||
sprint = Sprint(
|
||||
id=uuid.uuid4(),
|
||||
project_id=project.id,
|
||||
name="Sprint 1",
|
||||
number=1,
|
||||
start_date=today,
|
||||
end_date=today + timedelta(days=14),
|
||||
)
|
||||
db_session.add(sprint)
|
||||
db_session.commit()
|
||||
|
||||
retrieved = db_session.query(Sprint).filter_by(name="Sprint 1").first()
|
||||
|
||||
assert retrieved is not None
|
||||
assert retrieved.name == "Sprint 1"
|
||||
assert retrieved.number == 1
|
||||
assert retrieved.start_date == today
|
||||
assert retrieved.end_date == today + timedelta(days=14)
|
||||
assert retrieved.status == SprintStatus.PLANNED # Default
|
||||
assert retrieved.goal is None
|
||||
assert retrieved.planned_points is None
|
||||
assert retrieved.completed_points is None
|
||||
|
||||
def test_create_sprint_with_all_fields(self, db_session):
|
||||
"""Test creating a sprint with all optional fields."""
|
||||
project = Project(
|
||||
id=uuid.uuid4(),
|
||||
name="Full Sprint Project",
|
||||
slug="full-sprint-project",
|
||||
)
|
||||
db_session.add(project)
|
||||
db_session.commit()
|
||||
|
||||
today = date.today()
|
||||
sprint_id = uuid.uuid4()
|
||||
|
||||
sprint = Sprint(
|
||||
id=sprint_id,
|
||||
project_id=project.id,
|
||||
name="Full Sprint",
|
||||
number=5,
|
||||
goal="Complete all authentication features",
|
||||
start_date=today,
|
||||
end_date=today + timedelta(days=14),
|
||||
status=SprintStatus.ACTIVE,
|
||||
planned_points=34,
|
||||
completed_points=21,
|
||||
)
|
||||
db_session.add(sprint)
|
||||
db_session.commit()
|
||||
|
||||
retrieved = db_session.query(Sprint).filter_by(id=sprint_id).first()
|
||||
|
||||
assert retrieved.name == "Full Sprint"
|
||||
assert retrieved.number == 5
|
||||
assert retrieved.goal == "Complete all authentication features"
|
||||
assert retrieved.status == SprintStatus.ACTIVE
|
||||
assert retrieved.planned_points == 34
|
||||
assert retrieved.completed_points == 21
|
||||
|
||||
def test_sprint_timestamps(self, db_session):
|
||||
"""Test that timestamps are automatically set."""
|
||||
project = Project(id=uuid.uuid4(), name="Timestamp Sprint Project", slug="timestamp-sprint-project")
|
||||
db_session.add(project)
|
||||
db_session.commit()
|
||||
|
||||
today = date.today()
|
||||
sprint = Sprint(
|
||||
id=uuid.uuid4(),
|
||||
project_id=project.id,
|
||||
name="Timestamp Sprint",
|
||||
number=1,
|
||||
start_date=today,
|
||||
end_date=today + timedelta(days=14),
|
||||
)
|
||||
db_session.add(sprint)
|
||||
db_session.commit()
|
||||
|
||||
assert isinstance(sprint.created_at, datetime)
|
||||
assert isinstance(sprint.updated_at, datetime)
|
||||
|
||||
def test_sprint_string_representation(self, db_session):
|
||||
"""Test the string representation of a sprint."""
|
||||
project = Project(id=uuid.uuid4(), name="Repr Sprint Project", slug="repr-sprint-project")
|
||||
db_session.add(project)
|
||||
db_session.commit()
|
||||
|
||||
today = date.today()
|
||||
sprint = Sprint(
|
||||
id=uuid.uuid4(),
|
||||
project_id=project.id,
|
||||
name="Sprint Alpha",
|
||||
number=3,
|
||||
start_date=today,
|
||||
end_date=today + timedelta(days=14),
|
||||
status=SprintStatus.ACTIVE,
|
||||
)
|
||||
|
||||
repr_str = repr(sprint)
|
||||
assert "Sprint Alpha" in repr_str
|
||||
assert "#3" in repr_str
|
||||
assert str(project.id) in repr_str
|
||||
assert "active" in repr_str
|
||||
|
||||
|
||||
class TestSprintStatus:
|
||||
"""Tests for Sprint status field."""
|
||||
|
||||
def test_all_sprint_statuses(self, db_session):
|
||||
"""Test that all sprint statuses can be stored."""
|
||||
project = Project(id=uuid.uuid4(), name="Status Sprint Project", slug="status-sprint-project")
|
||||
db_session.add(project)
|
||||
db_session.commit()
|
||||
|
||||
today = date.today()
|
||||
for idx, status in enumerate(SprintStatus):
|
||||
sprint = Sprint(
|
||||
id=uuid.uuid4(),
|
||||
project_id=project.id,
|
||||
name=f"Sprint {status.value}",
|
||||
number=idx + 1,
|
||||
start_date=today,
|
||||
end_date=today + timedelta(days=14),
|
||||
status=status,
|
||||
)
|
||||
db_session.add(sprint)
|
||||
db_session.commit()
|
||||
|
||||
retrieved = db_session.query(Sprint).filter_by(id=sprint.id).first()
|
||||
assert retrieved.status == status
|
||||
|
||||
|
||||
class TestSprintLifecycle:
|
||||
"""Tests for Sprint lifecycle operations."""
|
||||
|
||||
def test_start_sprint(self, db_session):
|
||||
"""Test starting a planned sprint."""
|
||||
project = Project(id=uuid.uuid4(), name="Start Sprint Project", slug="start-sprint-project")
|
||||
db_session.add(project)
|
||||
db_session.commit()
|
||||
|
||||
today = date.today()
|
||||
sprint = Sprint(
|
||||
id=uuid.uuid4(),
|
||||
project_id=project.id,
|
||||
name="Sprint to Start",
|
||||
number=1,
|
||||
start_date=today,
|
||||
end_date=today + timedelta(days=14),
|
||||
status=SprintStatus.PLANNED,
|
||||
)
|
||||
db_session.add(sprint)
|
||||
db_session.commit()
|
||||
|
||||
# Start the sprint
|
||||
sprint.status = SprintStatus.ACTIVE
|
||||
sprint.planned_points = 21
|
||||
db_session.commit()
|
||||
|
||||
retrieved = db_session.query(Sprint).filter_by(name="Sprint to Start").first()
|
||||
assert retrieved.status == SprintStatus.ACTIVE
|
||||
assert retrieved.planned_points == 21
|
||||
|
||||
def test_complete_sprint(self, db_session):
|
||||
"""Test completing an active sprint."""
|
||||
project = Project(id=uuid.uuid4(), name="Complete Sprint Project", slug="complete-sprint-project")
|
||||
db_session.add(project)
|
||||
db_session.commit()
|
||||
|
||||
today = date.today()
|
||||
sprint = Sprint(
|
||||
id=uuid.uuid4(),
|
||||
project_id=project.id,
|
||||
name="Sprint to Complete",
|
||||
number=1,
|
||||
start_date=today - timedelta(days=14),
|
||||
end_date=today,
|
||||
status=SprintStatus.ACTIVE,
|
||||
planned_points=21,
|
||||
)
|
||||
db_session.add(sprint)
|
||||
db_session.commit()
|
||||
|
||||
# Complete the sprint
|
||||
sprint.status = SprintStatus.COMPLETED
|
||||
sprint.completed_points = 18
|
||||
db_session.commit()
|
||||
|
||||
retrieved = db_session.query(Sprint).filter_by(name="Sprint to Complete").first()
|
||||
assert retrieved.status == SprintStatus.COMPLETED
|
||||
assert retrieved.completed_points == 18
|
||||
|
||||
def test_cancel_sprint(self, db_session):
|
||||
"""Test cancelling a sprint."""
|
||||
project = Project(id=uuid.uuid4(), name="Cancel Sprint Project", slug="cancel-sprint-project")
|
||||
db_session.add(project)
|
||||
db_session.commit()
|
||||
|
||||
today = date.today()
|
||||
sprint = Sprint(
|
||||
id=uuid.uuid4(),
|
||||
project_id=project.id,
|
||||
name="Sprint to Cancel",
|
||||
number=1,
|
||||
start_date=today,
|
||||
end_date=today + timedelta(days=14),
|
||||
status=SprintStatus.ACTIVE,
|
||||
planned_points=21,
|
||||
)
|
||||
db_session.add(sprint)
|
||||
db_session.commit()
|
||||
|
||||
# Cancel the sprint
|
||||
sprint.status = SprintStatus.CANCELLED
|
||||
db_session.commit()
|
||||
|
||||
retrieved = db_session.query(Sprint).filter_by(name="Sprint to Cancel").first()
|
||||
assert retrieved.status == SprintStatus.CANCELLED
|
||||
|
||||
|
||||
class TestSprintDates:
|
||||
"""Tests for Sprint date fields."""
|
||||
|
||||
def test_sprint_date_range(self, db_session):
|
||||
"""Test storing sprint date range."""
|
||||
project = Project(id=uuid.uuid4(), name="Date Range Project", slug="date-range-project")
|
||||
db_session.add(project)
|
||||
db_session.commit()
|
||||
|
||||
start = date(2024, 1, 1)
|
||||
end = date(2024, 1, 14)
|
||||
|
||||
sprint = Sprint(
|
||||
id=uuid.uuid4(),
|
||||
project_id=project.id,
|
||||
name="Date Range Sprint",
|
||||
number=1,
|
||||
start_date=start,
|
||||
end_date=end,
|
||||
)
|
||||
db_session.add(sprint)
|
||||
db_session.commit()
|
||||
|
||||
retrieved = db_session.query(Sprint).filter_by(name="Date Range Sprint").first()
|
||||
assert retrieved.start_date == start
|
||||
assert retrieved.end_date == end
|
||||
|
||||
def test_one_day_sprint(self, db_session):
|
||||
"""Test creating a one-day sprint."""
|
||||
project = Project(id=uuid.uuid4(), name="One Day Project", slug="one-day-project")
|
||||
db_session.add(project)
|
||||
db_session.commit()
|
||||
|
||||
today = date.today()
|
||||
sprint = Sprint(
|
||||
id=uuid.uuid4(),
|
||||
project_id=project.id,
|
||||
name="One Day Sprint",
|
||||
number=1,
|
||||
start_date=today,
|
||||
end_date=today, # Same day
|
||||
)
|
||||
db_session.add(sprint)
|
||||
db_session.commit()
|
||||
|
||||
retrieved = db_session.query(Sprint).filter_by(name="One Day Sprint").first()
|
||||
assert retrieved.start_date == retrieved.end_date
|
||||
|
||||
def test_long_sprint(self, db_session):
|
||||
"""Test creating a long sprint (e.g., 4 weeks)."""
|
||||
project = Project(id=uuid.uuid4(), name="Long Sprint Project", slug="long-sprint-project")
|
||||
db_session.add(project)
|
||||
db_session.commit()
|
||||
|
||||
today = date.today()
|
||||
sprint = Sprint(
|
||||
id=uuid.uuid4(),
|
||||
project_id=project.id,
|
||||
name="Long Sprint",
|
||||
number=1,
|
||||
start_date=today,
|
||||
end_date=today + timedelta(days=28), # 4 weeks
|
||||
)
|
||||
db_session.add(sprint)
|
||||
db_session.commit()
|
||||
|
||||
retrieved = db_session.query(Sprint).filter_by(name="Long Sprint").first()
|
||||
delta = retrieved.end_date - retrieved.start_date
|
||||
assert delta.days == 28
|
||||
|
||||
|
||||
class TestSprintPoints:
|
||||
"""Tests for Sprint story points fields."""
|
||||
|
||||
def test_sprint_with_zero_points(self, db_session):
|
||||
"""Test sprint with zero planned points."""
|
||||
project = Project(id=uuid.uuid4(), name="Zero Points Project", slug="zero-points-project")
|
||||
db_session.add(project)
|
||||
db_session.commit()
|
||||
|
||||
today = date.today()
|
||||
sprint = Sprint(
|
||||
id=uuid.uuid4(),
|
||||
project_id=project.id,
|
||||
name="Zero Points Sprint",
|
||||
number=1,
|
||||
start_date=today,
|
||||
end_date=today + timedelta(days=14),
|
||||
planned_points=0,
|
||||
completed_points=0,
|
||||
)
|
||||
db_session.add(sprint)
|
||||
db_session.commit()
|
||||
|
||||
retrieved = db_session.query(Sprint).filter_by(name="Zero Points Sprint").first()
|
||||
assert retrieved.planned_points == 0
|
||||
assert retrieved.completed_points == 0
|
||||
|
||||
def test_sprint_velocity_calculation(self, db_session):
|
||||
"""Test that we can calculate velocity from points."""
|
||||
project = Project(id=uuid.uuid4(), name="Velocity Project", slug="velocity-project")
|
||||
db_session.add(project)
|
||||
db_session.commit()
|
||||
|
||||
today = date.today()
|
||||
sprint = Sprint(
|
||||
id=uuid.uuid4(),
|
||||
project_id=project.id,
|
||||
name="Velocity Sprint",
|
||||
number=1,
|
||||
start_date=today,
|
||||
end_date=today + timedelta(days=14),
|
||||
status=SprintStatus.COMPLETED,
|
||||
planned_points=21,
|
||||
completed_points=18,
|
||||
)
|
||||
db_session.add(sprint)
|
||||
db_session.commit()
|
||||
|
||||
retrieved = db_session.query(Sprint).filter_by(name="Velocity Sprint").first()
|
||||
|
||||
# Calculate velocity
|
||||
velocity = retrieved.completed_points / retrieved.planned_points
|
||||
assert velocity == pytest.approx(18 / 21, rel=0.01)
|
||||
|
||||
def test_sprint_overdelivery(self, db_session):
|
||||
"""Test sprint where completed > planned (stretch goals)."""
|
||||
project = Project(id=uuid.uuid4(), name="Overdelivery Project", slug="overdelivery-project")
|
||||
db_session.add(project)
|
||||
db_session.commit()
|
||||
|
||||
today = date.today()
|
||||
sprint = Sprint(
|
||||
id=uuid.uuid4(),
|
||||
project_id=project.id,
|
||||
name="Overdelivery Sprint",
|
||||
number=1,
|
||||
start_date=today,
|
||||
end_date=today + timedelta(days=14),
|
||||
status=SprintStatus.COMPLETED,
|
||||
planned_points=20,
|
||||
completed_points=25, # Completed more than planned
|
||||
)
|
||||
db_session.add(sprint)
|
||||
db_session.commit()
|
||||
|
||||
retrieved = db_session.query(Sprint).filter_by(name="Overdelivery Sprint").first()
|
||||
assert retrieved.completed_points > retrieved.planned_points
|
||||
|
||||
|
||||
class TestSprintNumber:
|
||||
"""Tests for Sprint number field."""
|
||||
|
||||
def test_sequential_sprint_numbers(self, db_session):
|
||||
"""Test creating sprints with sequential numbers."""
|
||||
project = Project(id=uuid.uuid4(), name="Sequential Project", slug="sequential-project")
|
||||
db_session.add(project)
|
||||
db_session.commit()
|
||||
|
||||
today = date.today()
|
||||
for i in range(1, 6):
|
||||
sprint = Sprint(
|
||||
id=uuid.uuid4(),
|
||||
project_id=project.id,
|
||||
name=f"Sprint {i}",
|
||||
number=i,
|
||||
start_date=today + timedelta(days=(i - 1) * 14),
|
||||
end_date=today + timedelta(days=i * 14 - 1),
|
||||
)
|
||||
db_session.add(sprint)
|
||||
db_session.commit()
|
||||
|
||||
sprints = db_session.query(Sprint).filter_by(project_id=project.id).order_by(Sprint.number).all()
|
||||
assert len(sprints) == 5
|
||||
for i, sprint in enumerate(sprints, 1):
|
||||
assert sprint.number == i
|
||||
|
||||
def test_large_sprint_number(self, db_session):
|
||||
"""Test sprint with large number (e.g., long-running project)."""
|
||||
project = Project(id=uuid.uuid4(), name="Large Number Project", slug="large-number-project")
|
||||
db_session.add(project)
|
||||
db_session.commit()
|
||||
|
||||
today = date.today()
|
||||
sprint = Sprint(
|
||||
id=uuid.uuid4(),
|
||||
project_id=project.id,
|
||||
name="Sprint 100",
|
||||
number=100,
|
||||
start_date=today,
|
||||
end_date=today + timedelta(days=14),
|
||||
)
|
||||
db_session.add(sprint)
|
||||
db_session.commit()
|
||||
|
||||
retrieved = db_session.query(Sprint).filter_by(name="Sprint 100").first()
|
||||
assert retrieved.number == 100
|
||||
|
||||
|
||||
class TestSprintUpdate:
|
||||
"""Tests for Sprint update operations."""
|
||||
|
||||
def test_update_sprint_goal(self, db_session):
|
||||
"""Test updating sprint goal."""
|
||||
project = Project(id=uuid.uuid4(), name="Update Goal Project", slug="update-goal-project")
|
||||
db_session.add(project)
|
||||
db_session.commit()
|
||||
|
||||
today = date.today()
|
||||
sprint = Sprint(
|
||||
id=uuid.uuid4(),
|
||||
project_id=project.id,
|
||||
name="Update Goal Sprint",
|
||||
number=1,
|
||||
start_date=today,
|
||||
end_date=today + timedelta(days=14),
|
||||
goal="Original goal",
|
||||
)
|
||||
db_session.add(sprint)
|
||||
db_session.commit()
|
||||
|
||||
original_created_at = sprint.created_at
|
||||
|
||||
sprint.goal = "Updated goal with more detail"
|
||||
db_session.commit()
|
||||
|
||||
retrieved = db_session.query(Sprint).filter_by(name="Update Goal Sprint").first()
|
||||
assert retrieved.goal == "Updated goal with more detail"
|
||||
assert retrieved.created_at == original_created_at
|
||||
assert retrieved.updated_at > original_created_at
|
||||
|
||||
def test_update_sprint_dates(self, db_session):
|
||||
"""Test updating sprint dates."""
|
||||
project = Project(id=uuid.uuid4(), name="Update Dates Project", slug="update-dates-project")
|
||||
db_session.add(project)
|
||||
db_session.commit()
|
||||
|
||||
today = date.today()
|
||||
sprint = Sprint(
|
||||
id=uuid.uuid4(),
|
||||
project_id=project.id,
|
||||
name="Update Dates Sprint",
|
||||
number=1,
|
||||
start_date=today,
|
||||
end_date=today + timedelta(days=14),
|
||||
)
|
||||
db_session.add(sprint)
|
||||
db_session.commit()
|
||||
|
||||
# Extend sprint by a week
|
||||
sprint.end_date = today + timedelta(days=21)
|
||||
db_session.commit()
|
||||
|
||||
retrieved = db_session.query(Sprint).filter_by(name="Update Dates Sprint").first()
|
||||
delta = retrieved.end_date - retrieved.start_date
|
||||
assert delta.days == 21
|
||||
2
backend/tests/schemas/syndarix/__init__.py
Normal file
2
backend/tests/schemas/syndarix/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
||||
# tests/schemas/syndarix/__init__.py
|
||||
"""Syndarix schema validation tests."""
|
||||
68
backend/tests/schemas/syndarix/conftest.py
Normal file
68
backend/tests/schemas/syndarix/conftest.py
Normal file
@@ -0,0 +1,68 @@
|
||||
# tests/schemas/syndarix/conftest.py
|
||||
"""
|
||||
Shared fixtures for Syndarix schema tests.
|
||||
"""
|
||||
|
||||
import uuid
|
||||
from datetime import date, timedelta
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def valid_uuid():
|
||||
"""Return a valid UUID for testing."""
|
||||
return uuid.uuid4()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def valid_project_data():
|
||||
"""Return valid project data for schema testing."""
|
||||
return {
|
||||
"name": "Test Project",
|
||||
"slug": "test-project",
|
||||
"description": "A test project",
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def valid_agent_type_data():
|
||||
"""Return valid agent type data for schema testing."""
|
||||
return {
|
||||
"name": "Backend Engineer",
|
||||
"slug": "backend-engineer",
|
||||
"personality_prompt": "You are an expert backend engineer.",
|
||||
"primary_model": "claude-opus-4-5-20251101",
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def valid_sprint_data(valid_uuid):
|
||||
"""Return valid sprint data for schema testing."""
|
||||
today = date.today()
|
||||
return {
|
||||
"project_id": valid_uuid,
|
||||
"name": "Sprint 1",
|
||||
"number": 1,
|
||||
"start_date": today,
|
||||
"end_date": today + timedelta(days=14),
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def valid_issue_data(valid_uuid):
|
||||
"""Return valid issue data for schema testing."""
|
||||
return {
|
||||
"project_id": valid_uuid,
|
||||
"title": "Test Issue",
|
||||
"body": "Issue description",
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def valid_agent_instance_data(valid_uuid):
|
||||
"""Return valid agent instance data for schema testing."""
|
||||
return {
|
||||
"agent_type_id": valid_uuid,
|
||||
"project_id": valid_uuid,
|
||||
}
|
||||
244
backend/tests/schemas/syndarix/test_agent_instance_schemas.py
Normal file
244
backend/tests/schemas/syndarix/test_agent_instance_schemas.py
Normal file
@@ -0,0 +1,244 @@
|
||||
# tests/schemas/syndarix/test_agent_instance_schemas.py
|
||||
"""
|
||||
Tests for AgentInstance schema validation.
|
||||
"""
|
||||
|
||||
import uuid
|
||||
from decimal import Decimal
|
||||
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
|
||||
from app.schemas.syndarix import (
|
||||
AgentInstanceCreate,
|
||||
AgentInstanceUpdate,
|
||||
AgentStatus,
|
||||
)
|
||||
|
||||
|
||||
class TestAgentInstanceCreateValidation:
|
||||
"""Tests for AgentInstanceCreate schema validation."""
|
||||
|
||||
def test_valid_agent_instance_create(self, valid_agent_instance_data):
|
||||
"""Test creating agent instance with valid data."""
|
||||
instance = AgentInstanceCreate(**valid_agent_instance_data)
|
||||
|
||||
assert instance.agent_type_id is not None
|
||||
assert instance.project_id is not None
|
||||
|
||||
def test_agent_instance_create_defaults(self, valid_agent_instance_data):
|
||||
"""Test that defaults are applied correctly."""
|
||||
instance = AgentInstanceCreate(**valid_agent_instance_data)
|
||||
|
||||
assert instance.status == AgentStatus.IDLE
|
||||
assert instance.current_task is None
|
||||
assert instance.short_term_memory == {}
|
||||
assert instance.long_term_memory_ref is None
|
||||
assert instance.session_id is None
|
||||
|
||||
def test_agent_instance_create_with_all_fields(self, valid_uuid):
|
||||
"""Test creating agent instance with all optional fields."""
|
||||
instance = AgentInstanceCreate(
|
||||
agent_type_id=valid_uuid,
|
||||
project_id=valid_uuid,
|
||||
status=AgentStatus.WORKING,
|
||||
current_task="Processing feature request",
|
||||
short_term_memory={"context": "working"},
|
||||
long_term_memory_ref="project-123/agent-456",
|
||||
session_id="session-abc",
|
||||
)
|
||||
|
||||
assert instance.status == AgentStatus.WORKING
|
||||
assert instance.current_task == "Processing feature request"
|
||||
assert instance.short_term_memory == {"context": "working"}
|
||||
assert instance.long_term_memory_ref == "project-123/agent-456"
|
||||
assert instance.session_id == "session-abc"
|
||||
|
||||
def test_agent_instance_create_agent_type_id_required(self, valid_uuid):
|
||||
"""Test that agent_type_id is required."""
|
||||
with pytest.raises(ValidationError) as exc_info:
|
||||
AgentInstanceCreate(
|
||||
project_id=valid_uuid,
|
||||
)
|
||||
|
||||
errors = exc_info.value.errors()
|
||||
assert any("agent_type_id" in str(e).lower() for e in errors)
|
||||
|
||||
def test_agent_instance_create_project_id_required(self, valid_uuid):
|
||||
"""Test that project_id is required."""
|
||||
with pytest.raises(ValidationError) as exc_info:
|
||||
AgentInstanceCreate(
|
||||
agent_type_id=valid_uuid,
|
||||
)
|
||||
|
||||
errors = exc_info.value.errors()
|
||||
assert any("project_id" in str(e).lower() for e in errors)
|
||||
|
||||
|
||||
class TestAgentInstanceUpdateValidation:
|
||||
"""Tests for AgentInstanceUpdate schema validation."""
|
||||
|
||||
def test_agent_instance_update_partial(self):
|
||||
"""Test updating only some fields."""
|
||||
update = AgentInstanceUpdate(
|
||||
status=AgentStatus.WORKING,
|
||||
)
|
||||
|
||||
assert update.status == AgentStatus.WORKING
|
||||
assert update.current_task is None
|
||||
assert update.short_term_memory is None
|
||||
|
||||
def test_agent_instance_update_all_fields(self):
|
||||
"""Test updating all fields."""
|
||||
from datetime import UTC, datetime
|
||||
|
||||
now = datetime.now(UTC)
|
||||
update = AgentInstanceUpdate(
|
||||
status=AgentStatus.WORKING,
|
||||
current_task="New task",
|
||||
short_term_memory={"new": "context"},
|
||||
long_term_memory_ref="new-ref",
|
||||
session_id="new-session",
|
||||
last_activity_at=now,
|
||||
tasks_completed=5,
|
||||
tokens_used=10000,
|
||||
cost_incurred=Decimal("1.5000"),
|
||||
)
|
||||
|
||||
assert update.status == AgentStatus.WORKING
|
||||
assert update.current_task == "New task"
|
||||
assert update.tasks_completed == 5
|
||||
assert update.tokens_used == 10000
|
||||
assert update.cost_incurred == Decimal("1.5000")
|
||||
|
||||
def test_agent_instance_update_tasks_completed_negative_fails(self):
|
||||
"""Test that negative tasks_completed raises ValidationError."""
|
||||
with pytest.raises(ValidationError) as exc_info:
|
||||
AgentInstanceUpdate(tasks_completed=-1)
|
||||
|
||||
errors = exc_info.value.errors()
|
||||
assert any("tasks_completed" in str(e).lower() for e in errors)
|
||||
|
||||
def test_agent_instance_update_tokens_used_negative_fails(self):
|
||||
"""Test that negative tokens_used raises ValidationError."""
|
||||
with pytest.raises(ValidationError) as exc_info:
|
||||
AgentInstanceUpdate(tokens_used=-1)
|
||||
|
||||
errors = exc_info.value.errors()
|
||||
assert any("tokens_used" in str(e).lower() for e in errors)
|
||||
|
||||
def test_agent_instance_update_cost_incurred_negative_fails(self):
|
||||
"""Test that negative cost_incurred raises ValidationError."""
|
||||
with pytest.raises(ValidationError) as exc_info:
|
||||
AgentInstanceUpdate(cost_incurred=Decimal("-0.01"))
|
||||
|
||||
errors = exc_info.value.errors()
|
||||
assert any("cost_incurred" in str(e).lower() for e in errors)
|
||||
|
||||
|
||||
class TestAgentStatusEnum:
|
||||
"""Tests for AgentStatus enum validation."""
|
||||
|
||||
def test_valid_agent_statuses(self, valid_uuid):
|
||||
"""Test all valid agent statuses."""
|
||||
for status in AgentStatus:
|
||||
instance = AgentInstanceCreate(
|
||||
agent_type_id=valid_uuid,
|
||||
project_id=valid_uuid,
|
||||
status=status,
|
||||
)
|
||||
assert instance.status == status
|
||||
|
||||
def test_invalid_agent_status(self, valid_uuid):
|
||||
"""Test that invalid agent status raises ValidationError."""
|
||||
with pytest.raises(ValidationError):
|
||||
AgentInstanceCreate(
|
||||
agent_type_id=valid_uuid,
|
||||
project_id=valid_uuid,
|
||||
status="invalid", # type: ignore
|
||||
)
|
||||
|
||||
|
||||
class TestAgentInstanceShortTermMemory:
|
||||
"""Tests for AgentInstance short_term_memory validation."""
|
||||
|
||||
def test_short_term_memory_empty_dict(self, valid_uuid):
|
||||
"""Test that empty short_term_memory is valid."""
|
||||
instance = AgentInstanceCreate(
|
||||
agent_type_id=valid_uuid,
|
||||
project_id=valid_uuid,
|
||||
short_term_memory={},
|
||||
)
|
||||
assert instance.short_term_memory == {}
|
||||
|
||||
def test_short_term_memory_complex(self, valid_uuid):
|
||||
"""Test complex short_term_memory structure."""
|
||||
memory = {
|
||||
"conversation_history": [
|
||||
{"role": "user", "content": "Hello"},
|
||||
{"role": "assistant", "content": "Hi there"},
|
||||
],
|
||||
"recent_files": ["file1.py", "file2.py"],
|
||||
"decisions": {"key": "value"},
|
||||
"context_tokens": 1024,
|
||||
}
|
||||
instance = AgentInstanceCreate(
|
||||
agent_type_id=valid_uuid,
|
||||
project_id=valid_uuid,
|
||||
short_term_memory=memory,
|
||||
)
|
||||
assert instance.short_term_memory == memory
|
||||
|
||||
|
||||
class TestAgentInstanceStringFields:
|
||||
"""Tests for AgentInstance string field validation."""
|
||||
|
||||
def test_long_term_memory_ref_max_length(self, valid_uuid):
|
||||
"""Test long_term_memory_ref max length."""
|
||||
long_ref = "a" * 500 # Max length is 500
|
||||
|
||||
instance = AgentInstanceCreate(
|
||||
agent_type_id=valid_uuid,
|
||||
project_id=valid_uuid,
|
||||
long_term_memory_ref=long_ref,
|
||||
)
|
||||
assert instance.long_term_memory_ref == long_ref
|
||||
|
||||
def test_long_term_memory_ref_too_long(self, valid_uuid):
|
||||
"""Test that too long long_term_memory_ref raises ValidationError."""
|
||||
too_long = "a" * 501
|
||||
|
||||
with pytest.raises(ValidationError) as exc_info:
|
||||
AgentInstanceCreate(
|
||||
agent_type_id=valid_uuid,
|
||||
project_id=valid_uuid,
|
||||
long_term_memory_ref=too_long,
|
||||
)
|
||||
|
||||
errors = exc_info.value.errors()
|
||||
assert any("long_term_memory_ref" in str(e).lower() for e in errors)
|
||||
|
||||
def test_session_id_max_length(self, valid_uuid):
|
||||
"""Test session_id max length."""
|
||||
long_session = "a" * 255 # Max length is 255
|
||||
|
||||
instance = AgentInstanceCreate(
|
||||
agent_type_id=valid_uuid,
|
||||
project_id=valid_uuid,
|
||||
session_id=long_session,
|
||||
)
|
||||
assert instance.session_id == long_session
|
||||
|
||||
def test_session_id_too_long(self, valid_uuid):
|
||||
"""Test that too long session_id raises ValidationError."""
|
||||
too_long = "a" * 256
|
||||
|
||||
with pytest.raises(ValidationError) as exc_info:
|
||||
AgentInstanceCreate(
|
||||
agent_type_id=valid_uuid,
|
||||
project_id=valid_uuid,
|
||||
session_id=too_long,
|
||||
)
|
||||
|
||||
errors = exc_info.value.errors()
|
||||
assert any("session_id" in str(e).lower() for e in errors)
|
||||
318
backend/tests/schemas/syndarix/test_agent_type_schemas.py
Normal file
318
backend/tests/schemas/syndarix/test_agent_type_schemas.py
Normal file
@@ -0,0 +1,318 @@
|
||||
# tests/schemas/syndarix/test_agent_type_schemas.py
|
||||
"""
|
||||
Tests for AgentType schema validation.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
|
||||
from app.schemas.syndarix import (
|
||||
AgentTypeCreate,
|
||||
AgentTypeUpdate,
|
||||
)
|
||||
|
||||
|
||||
class TestAgentTypeCreateValidation:
|
||||
"""Tests for AgentTypeCreate schema validation."""
|
||||
|
||||
def test_valid_agent_type_create(self, valid_agent_type_data):
|
||||
"""Test creating agent type with valid data."""
|
||||
agent_type = AgentTypeCreate(**valid_agent_type_data)
|
||||
|
||||
assert agent_type.name == "Backend Engineer"
|
||||
assert agent_type.slug == "backend-engineer"
|
||||
assert agent_type.personality_prompt == "You are an expert backend engineer."
|
||||
assert agent_type.primary_model == "claude-opus-4-5-20251101"
|
||||
|
||||
def test_agent_type_create_defaults(self, valid_agent_type_data):
|
||||
"""Test that defaults are applied correctly."""
|
||||
agent_type = AgentTypeCreate(**valid_agent_type_data)
|
||||
|
||||
assert agent_type.expertise == []
|
||||
assert agent_type.fallback_models == []
|
||||
assert agent_type.model_params == {}
|
||||
assert agent_type.mcp_servers == []
|
||||
assert agent_type.tool_permissions == {}
|
||||
assert agent_type.is_active is True
|
||||
|
||||
def test_agent_type_create_with_all_fields(self, valid_agent_type_data):
|
||||
"""Test creating agent type with all optional fields."""
|
||||
agent_type = AgentTypeCreate(
|
||||
**valid_agent_type_data,
|
||||
description="Detailed description",
|
||||
expertise=["python", "fastapi"],
|
||||
fallback_models=["claude-sonnet-4-20250514"],
|
||||
model_params={"temperature": 0.7},
|
||||
mcp_servers=["gitea", "slack"],
|
||||
tool_permissions={"allowed": ["*"]},
|
||||
is_active=True,
|
||||
)
|
||||
|
||||
assert agent_type.description == "Detailed description"
|
||||
assert agent_type.expertise == ["python", "fastapi"]
|
||||
assert agent_type.fallback_models == ["claude-sonnet-4-20250514"]
|
||||
|
||||
def test_agent_type_create_name_empty_fails(self):
|
||||
"""Test that empty name raises ValidationError."""
|
||||
with pytest.raises(ValidationError) as exc_info:
|
||||
AgentTypeCreate(
|
||||
name="",
|
||||
slug="valid-slug",
|
||||
personality_prompt="Test prompt",
|
||||
primary_model="claude-opus-4-5-20251101",
|
||||
)
|
||||
|
||||
errors = exc_info.value.errors()
|
||||
assert any("name" in str(e) for e in errors)
|
||||
|
||||
def test_agent_type_create_name_stripped(self):
|
||||
"""Test that name is stripped of whitespace."""
|
||||
agent_type = AgentTypeCreate(
|
||||
name=" Padded Name ",
|
||||
slug="padded-slug",
|
||||
personality_prompt="Test",
|
||||
primary_model="claude-opus-4-5-20251101",
|
||||
)
|
||||
|
||||
assert agent_type.name == "Padded Name"
|
||||
|
||||
def test_agent_type_create_personality_prompt_required(self):
|
||||
"""Test that personality_prompt is required."""
|
||||
with pytest.raises(ValidationError) as exc_info:
|
||||
AgentTypeCreate(
|
||||
name="Test Agent",
|
||||
slug="test-agent",
|
||||
primary_model="claude-opus-4-5-20251101",
|
||||
)
|
||||
|
||||
errors = exc_info.value.errors()
|
||||
assert any("personality_prompt" in str(e).lower() for e in errors)
|
||||
|
||||
def test_agent_type_create_primary_model_required(self):
|
||||
"""Test that primary_model is required."""
|
||||
with pytest.raises(ValidationError) as exc_info:
|
||||
AgentTypeCreate(
|
||||
name="Test Agent",
|
||||
slug="test-agent",
|
||||
personality_prompt="Test prompt",
|
||||
)
|
||||
|
||||
errors = exc_info.value.errors()
|
||||
assert any("primary_model" in str(e).lower() for e in errors)
|
||||
|
||||
|
||||
class TestAgentTypeSlugValidation:
|
||||
"""Tests for AgentType slug validation."""
|
||||
|
||||
def test_valid_slugs(self):
|
||||
"""Test various valid slug formats."""
|
||||
valid_slugs = [
|
||||
"simple",
|
||||
"with-hyphens",
|
||||
"has123numbers",
|
||||
]
|
||||
|
||||
for slug in valid_slugs:
|
||||
agent_type = AgentTypeCreate(
|
||||
name="Test Agent",
|
||||
slug=slug,
|
||||
personality_prompt="Test",
|
||||
primary_model="claude-opus-4-5-20251101",
|
||||
)
|
||||
assert agent_type.slug == slug
|
||||
|
||||
def test_invalid_slug_uppercase(self):
|
||||
"""Test that uppercase letters in slug raise ValidationError."""
|
||||
with pytest.raises(ValidationError):
|
||||
AgentTypeCreate(
|
||||
name="Test Agent",
|
||||
slug="Invalid-Uppercase",
|
||||
personality_prompt="Test",
|
||||
primary_model="claude-opus-4-5-20251101",
|
||||
)
|
||||
|
||||
def test_invalid_slug_special_chars(self):
|
||||
"""Test that special characters raise ValidationError."""
|
||||
with pytest.raises(ValidationError):
|
||||
AgentTypeCreate(
|
||||
name="Test Agent",
|
||||
slug="has_underscore",
|
||||
personality_prompt="Test",
|
||||
primary_model="claude-opus-4-5-20251101",
|
||||
)
|
||||
|
||||
|
||||
class TestAgentTypeExpertiseValidation:
|
||||
"""Tests for AgentType expertise validation."""
|
||||
|
||||
def test_expertise_normalized_lowercase(self):
|
||||
"""Test that expertise is normalized to lowercase."""
|
||||
agent_type = AgentTypeCreate(
|
||||
name="Test Agent",
|
||||
slug="test-agent",
|
||||
personality_prompt="Test",
|
||||
primary_model="claude-opus-4-5-20251101",
|
||||
expertise=["Python", "FastAPI", "PostgreSQL"],
|
||||
)
|
||||
|
||||
assert agent_type.expertise == ["python", "fastapi", "postgresql"]
|
||||
|
||||
def test_expertise_stripped(self):
|
||||
"""Test that expertise items are stripped."""
|
||||
agent_type = AgentTypeCreate(
|
||||
name="Test Agent",
|
||||
slug="test-agent",
|
||||
personality_prompt="Test",
|
||||
primary_model="claude-opus-4-5-20251101",
|
||||
expertise=[" python ", " fastapi "],
|
||||
)
|
||||
|
||||
assert agent_type.expertise == ["python", "fastapi"]
|
||||
|
||||
def test_expertise_empty_strings_removed(self):
|
||||
"""Test that empty expertise strings are removed."""
|
||||
agent_type = AgentTypeCreate(
|
||||
name="Test Agent",
|
||||
slug="test-agent",
|
||||
personality_prompt="Test",
|
||||
primary_model="claude-opus-4-5-20251101",
|
||||
expertise=["python", "", " ", "fastapi"],
|
||||
)
|
||||
|
||||
assert agent_type.expertise == ["python", "fastapi"]
|
||||
|
||||
|
||||
class TestAgentTypeMcpServersValidation:
|
||||
"""Tests for AgentType MCP servers validation."""
|
||||
|
||||
def test_mcp_servers_stripped(self):
|
||||
"""Test that MCP server names are stripped."""
|
||||
agent_type = AgentTypeCreate(
|
||||
name="Test Agent",
|
||||
slug="test-agent",
|
||||
personality_prompt="Test",
|
||||
primary_model="claude-opus-4-5-20251101",
|
||||
mcp_servers=[" gitea ", " slack "],
|
||||
)
|
||||
|
||||
assert agent_type.mcp_servers == ["gitea", "slack"]
|
||||
|
||||
def test_mcp_servers_empty_strings_removed(self):
|
||||
"""Test that empty MCP server strings are removed."""
|
||||
agent_type = AgentTypeCreate(
|
||||
name="Test Agent",
|
||||
slug="test-agent",
|
||||
personality_prompt="Test",
|
||||
primary_model="claude-opus-4-5-20251101",
|
||||
mcp_servers=["gitea", "", " ", "slack"],
|
||||
)
|
||||
|
||||
assert agent_type.mcp_servers == ["gitea", "slack"]
|
||||
|
||||
|
||||
class TestAgentTypeUpdateValidation:
|
||||
"""Tests for AgentTypeUpdate schema validation."""
|
||||
|
||||
def test_agent_type_update_partial(self):
|
||||
"""Test updating only some fields."""
|
||||
update = AgentTypeUpdate(
|
||||
name="Updated Name",
|
||||
)
|
||||
|
||||
assert update.name == "Updated Name"
|
||||
assert update.slug is None
|
||||
assert update.description is None
|
||||
assert update.expertise is None
|
||||
|
||||
def test_agent_type_update_all_fields(self):
|
||||
"""Test updating all fields."""
|
||||
update = AgentTypeUpdate(
|
||||
name="Updated Name",
|
||||
slug="updated-slug",
|
||||
description="Updated description",
|
||||
expertise=["new-skill"],
|
||||
personality_prompt="Updated prompt",
|
||||
primary_model="new-model",
|
||||
fallback_models=["fallback-1"],
|
||||
model_params={"temp": 0.5},
|
||||
mcp_servers=["server-1"],
|
||||
tool_permissions={"key": "value"},
|
||||
is_active=False,
|
||||
)
|
||||
|
||||
assert update.name == "Updated Name"
|
||||
assert update.slug == "updated-slug"
|
||||
assert update.is_active is False
|
||||
|
||||
def test_agent_type_update_empty_name_fails(self):
|
||||
"""Test that empty name in update raises ValidationError."""
|
||||
with pytest.raises(ValidationError) as exc_info:
|
||||
AgentTypeUpdate(name="")
|
||||
|
||||
errors = exc_info.value.errors()
|
||||
assert any("name" in str(e) for e in errors)
|
||||
|
||||
def test_agent_type_update_slug_validation(self):
|
||||
"""Test that slug validation applies to updates."""
|
||||
with pytest.raises(ValidationError):
|
||||
AgentTypeUpdate(slug="Invalid-Slug")
|
||||
|
||||
def test_agent_type_update_expertise_normalized(self):
|
||||
"""Test that expertise is normalized in updates."""
|
||||
update = AgentTypeUpdate(
|
||||
expertise=["Python", "FastAPI"],
|
||||
)
|
||||
|
||||
assert update.expertise == ["python", "fastapi"]
|
||||
|
||||
|
||||
class TestAgentTypeJsonFields:
|
||||
"""Tests for AgentType JSON field validation."""
|
||||
|
||||
def test_model_params_complex(self):
|
||||
"""Test complex model_params structure."""
|
||||
params = {
|
||||
"temperature": 0.7,
|
||||
"max_tokens": 4096,
|
||||
"top_p": 0.9,
|
||||
"stop_sequences": ["###"],
|
||||
}
|
||||
agent_type = AgentTypeCreate(
|
||||
name="Test Agent",
|
||||
slug="test-agent",
|
||||
personality_prompt="Test",
|
||||
primary_model="claude-opus-4-5-20251101",
|
||||
model_params=params,
|
||||
)
|
||||
|
||||
assert agent_type.model_params == params
|
||||
|
||||
def test_tool_permissions_complex(self):
|
||||
"""Test complex tool_permissions structure."""
|
||||
permissions = {
|
||||
"allowed": ["file:read", "git:commit"],
|
||||
"denied": ["file:delete"],
|
||||
"require_approval": ["git:push"],
|
||||
}
|
||||
agent_type = AgentTypeCreate(
|
||||
name="Test Agent",
|
||||
slug="test-agent",
|
||||
personality_prompt="Test",
|
||||
primary_model="claude-opus-4-5-20251101",
|
||||
tool_permissions=permissions,
|
||||
)
|
||||
|
||||
assert agent_type.tool_permissions == permissions
|
||||
|
||||
def test_fallback_models_list(self):
|
||||
"""Test fallback_models as a list."""
|
||||
models = ["claude-sonnet-4-20250514", "gpt-4o", "mistral-large"]
|
||||
agent_type = AgentTypeCreate(
|
||||
name="Test Agent",
|
||||
slug="test-agent",
|
||||
personality_prompt="Test",
|
||||
primary_model="claude-opus-4-5-20251101",
|
||||
fallback_models=models,
|
||||
)
|
||||
|
||||
assert agent_type.fallback_models == models
|
||||
342
backend/tests/schemas/syndarix/test_issue_schemas.py
Normal file
342
backend/tests/schemas/syndarix/test_issue_schemas.py
Normal file
@@ -0,0 +1,342 @@
|
||||
# tests/schemas/syndarix/test_issue_schemas.py
|
||||
"""
|
||||
Tests for Issue schema validation.
|
||||
"""
|
||||
|
||||
import uuid
|
||||
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
|
||||
from app.schemas.syndarix import (
|
||||
IssueAssign,
|
||||
IssueCreate,
|
||||
IssuePriority,
|
||||
IssueStatus,
|
||||
IssueUpdate,
|
||||
SyncStatus,
|
||||
)
|
||||
|
||||
|
||||
class TestIssueCreateValidation:
|
||||
"""Tests for IssueCreate schema validation."""
|
||||
|
||||
def test_valid_issue_create(self, valid_issue_data):
|
||||
"""Test creating issue with valid data."""
|
||||
issue = IssueCreate(**valid_issue_data)
|
||||
|
||||
assert issue.title == "Test Issue"
|
||||
assert issue.body == "Issue description"
|
||||
|
||||
def test_issue_create_defaults(self, valid_issue_data):
|
||||
"""Test that defaults are applied correctly."""
|
||||
issue = IssueCreate(**valid_issue_data)
|
||||
|
||||
assert issue.status == IssueStatus.OPEN
|
||||
assert issue.priority == IssuePriority.MEDIUM
|
||||
assert issue.labels == []
|
||||
assert issue.story_points is None
|
||||
assert issue.assigned_agent_id is None
|
||||
assert issue.human_assignee is None
|
||||
assert issue.sprint_id is None
|
||||
|
||||
def test_issue_create_with_all_fields(self, valid_uuid):
|
||||
"""Test creating issue with all optional fields."""
|
||||
agent_id = uuid.uuid4()
|
||||
sprint_id = uuid.uuid4()
|
||||
|
||||
issue = IssueCreate(
|
||||
project_id=valid_uuid,
|
||||
title="Full Issue",
|
||||
body="Detailed body",
|
||||
status=IssueStatus.IN_PROGRESS,
|
||||
priority=IssuePriority.HIGH,
|
||||
labels=["bug", "security"],
|
||||
story_points=5,
|
||||
assigned_agent_id=agent_id,
|
||||
sprint_id=sprint_id,
|
||||
external_tracker="gitea",
|
||||
external_id="gitea-123",
|
||||
external_url="https://gitea.example.com/issues/123",
|
||||
external_number=123,
|
||||
)
|
||||
|
||||
assert issue.status == IssueStatus.IN_PROGRESS
|
||||
assert issue.priority == IssuePriority.HIGH
|
||||
assert issue.labels == ["bug", "security"]
|
||||
assert issue.story_points == 5
|
||||
assert issue.external_tracker == "gitea"
|
||||
|
||||
def test_issue_create_title_empty_fails(self, valid_uuid):
|
||||
"""Test that empty title raises ValidationError."""
|
||||
with pytest.raises(ValidationError) as exc_info:
|
||||
IssueCreate(
|
||||
project_id=valid_uuid,
|
||||
title="",
|
||||
)
|
||||
|
||||
errors = exc_info.value.errors()
|
||||
assert any("title" in str(e) for e in errors)
|
||||
|
||||
def test_issue_create_title_whitespace_only_fails(self, valid_uuid):
|
||||
"""Test that whitespace-only title raises ValidationError."""
|
||||
with pytest.raises(ValidationError) as exc_info:
|
||||
IssueCreate(
|
||||
project_id=valid_uuid,
|
||||
title=" ",
|
||||
)
|
||||
|
||||
errors = exc_info.value.errors()
|
||||
assert any("title" in str(e) for e in errors)
|
||||
|
||||
def test_issue_create_title_stripped(self, valid_uuid):
|
||||
"""Test that title is stripped."""
|
||||
issue = IssueCreate(
|
||||
project_id=valid_uuid,
|
||||
title=" Padded Title ",
|
||||
)
|
||||
|
||||
assert issue.title == "Padded Title"
|
||||
|
||||
def test_issue_create_project_id_required(self):
|
||||
"""Test that project_id is required."""
|
||||
with pytest.raises(ValidationError) as exc_info:
|
||||
IssueCreate(title="No Project Issue")
|
||||
|
||||
errors = exc_info.value.errors()
|
||||
assert any("project_id" in str(e).lower() for e in errors)
|
||||
|
||||
|
||||
class TestIssueLabelsValidation:
|
||||
"""Tests for Issue labels validation."""
|
||||
|
||||
def test_labels_normalized_lowercase(self, valid_uuid):
|
||||
"""Test that labels are normalized to lowercase."""
|
||||
issue = IssueCreate(
|
||||
project_id=valid_uuid,
|
||||
title="Test Issue",
|
||||
labels=["Bug", "SECURITY", "FrontEnd"],
|
||||
)
|
||||
|
||||
assert issue.labels == ["bug", "security", "frontend"]
|
||||
|
||||
def test_labels_stripped(self, valid_uuid):
|
||||
"""Test that labels are stripped."""
|
||||
issue = IssueCreate(
|
||||
project_id=valid_uuid,
|
||||
title="Test Issue",
|
||||
labels=[" bug ", " security "],
|
||||
)
|
||||
|
||||
assert issue.labels == ["bug", "security"]
|
||||
|
||||
def test_labels_empty_strings_removed(self, valid_uuid):
|
||||
"""Test that empty label strings are removed."""
|
||||
issue = IssueCreate(
|
||||
project_id=valid_uuid,
|
||||
title="Test Issue",
|
||||
labels=["bug", "", " ", "security"],
|
||||
)
|
||||
|
||||
assert issue.labels == ["bug", "security"]
|
||||
|
||||
|
||||
class TestIssueStoryPointsValidation:
|
||||
"""Tests for Issue story_points validation."""
|
||||
|
||||
def test_story_points_valid_range(self, valid_uuid):
|
||||
"""Test valid story_points values."""
|
||||
for points in [0, 1, 5, 13, 21, 100]:
|
||||
issue = IssueCreate(
|
||||
project_id=valid_uuid,
|
||||
title="Test Issue",
|
||||
story_points=points,
|
||||
)
|
||||
assert issue.story_points == points
|
||||
|
||||
def test_story_points_negative_fails(self, valid_uuid):
|
||||
"""Test that negative story_points raises ValidationError."""
|
||||
with pytest.raises(ValidationError) as exc_info:
|
||||
IssueCreate(
|
||||
project_id=valid_uuid,
|
||||
title="Test Issue",
|
||||
story_points=-1,
|
||||
)
|
||||
|
||||
errors = exc_info.value.errors()
|
||||
assert any("story_points" in str(e).lower() for e in errors)
|
||||
|
||||
def test_story_points_over_100_fails(self, valid_uuid):
|
||||
"""Test that story_points > 100 raises ValidationError."""
|
||||
with pytest.raises(ValidationError) as exc_info:
|
||||
IssueCreate(
|
||||
project_id=valid_uuid,
|
||||
title="Test Issue",
|
||||
story_points=101,
|
||||
)
|
||||
|
||||
errors = exc_info.value.errors()
|
||||
assert any("story_points" in str(e).lower() for e in errors)
|
||||
|
||||
|
||||
class TestIssueExternalTrackerValidation:
|
||||
"""Tests for Issue external tracker validation."""
|
||||
|
||||
def test_valid_external_trackers(self, valid_uuid):
|
||||
"""Test valid external tracker values."""
|
||||
for tracker in ["gitea", "github", "gitlab"]:
|
||||
issue = IssueCreate(
|
||||
project_id=valid_uuid,
|
||||
title="Test Issue",
|
||||
external_tracker=tracker,
|
||||
external_id="ext-123",
|
||||
)
|
||||
assert issue.external_tracker == tracker
|
||||
|
||||
def test_invalid_external_tracker(self, valid_uuid):
|
||||
"""Test that invalid external tracker raises ValidationError."""
|
||||
with pytest.raises(ValidationError):
|
||||
IssueCreate(
|
||||
project_id=valid_uuid,
|
||||
title="Test Issue",
|
||||
external_tracker="invalid", # type: ignore
|
||||
external_id="ext-123",
|
||||
)
|
||||
|
||||
|
||||
class TestIssueUpdateValidation:
|
||||
"""Tests for IssueUpdate schema validation."""
|
||||
|
||||
def test_issue_update_partial(self):
|
||||
"""Test updating only some fields."""
|
||||
update = IssueUpdate(
|
||||
title="Updated Title",
|
||||
)
|
||||
|
||||
assert update.title == "Updated Title"
|
||||
assert update.body is None
|
||||
assert update.status is None
|
||||
|
||||
def test_issue_update_all_fields(self):
|
||||
"""Test updating all fields."""
|
||||
agent_id = uuid.uuid4()
|
||||
sprint_id = uuid.uuid4()
|
||||
|
||||
update = IssueUpdate(
|
||||
title="Updated Title",
|
||||
body="Updated body",
|
||||
status=IssueStatus.CLOSED,
|
||||
priority=IssuePriority.CRITICAL,
|
||||
labels=["updated"],
|
||||
assigned_agent_id=agent_id,
|
||||
human_assignee=None,
|
||||
sprint_id=sprint_id,
|
||||
story_points=8,
|
||||
sync_status=SyncStatus.PENDING,
|
||||
)
|
||||
|
||||
assert update.title == "Updated Title"
|
||||
assert update.status == IssueStatus.CLOSED
|
||||
assert update.priority == IssuePriority.CRITICAL
|
||||
|
||||
def test_issue_update_empty_title_fails(self):
|
||||
"""Test that empty title in update raises ValidationError."""
|
||||
with pytest.raises(ValidationError) as exc_info:
|
||||
IssueUpdate(title="")
|
||||
|
||||
errors = exc_info.value.errors()
|
||||
assert any("title" in str(e) for e in errors)
|
||||
|
||||
def test_issue_update_labels_normalized(self):
|
||||
"""Test that labels are normalized in updates."""
|
||||
update = IssueUpdate(
|
||||
labels=["Bug", "SECURITY"],
|
||||
)
|
||||
|
||||
assert update.labels == ["bug", "security"]
|
||||
|
||||
|
||||
class TestIssueAssignValidation:
|
||||
"""Tests for IssueAssign schema validation."""
|
||||
|
||||
def test_assign_to_agent(self):
|
||||
"""Test assigning to an agent."""
|
||||
agent_id = uuid.uuid4()
|
||||
assign = IssueAssign(assigned_agent_id=agent_id)
|
||||
|
||||
assert assign.assigned_agent_id == agent_id
|
||||
assert assign.human_assignee is None
|
||||
|
||||
def test_assign_to_human(self):
|
||||
"""Test assigning to a human."""
|
||||
assign = IssueAssign(human_assignee="developer@example.com")
|
||||
|
||||
assert assign.human_assignee == "developer@example.com"
|
||||
assert assign.assigned_agent_id is None
|
||||
|
||||
def test_unassign(self):
|
||||
"""Test unassigning (both None)."""
|
||||
assign = IssueAssign()
|
||||
|
||||
assert assign.assigned_agent_id is None
|
||||
assert assign.human_assignee is None
|
||||
|
||||
def test_assign_both_fails(self):
|
||||
"""Test that assigning to both agent and human raises ValidationError."""
|
||||
with pytest.raises(ValidationError) as exc_info:
|
||||
IssueAssign(
|
||||
assigned_agent_id=uuid.uuid4(),
|
||||
human_assignee="developer@example.com",
|
||||
)
|
||||
|
||||
errors = exc_info.value.errors()
|
||||
# Check for the validation error message
|
||||
assert len(errors) > 0
|
||||
|
||||
|
||||
class TestIssueEnums:
|
||||
"""Tests for Issue enum validation."""
|
||||
|
||||
def test_valid_issue_statuses(self, valid_uuid):
|
||||
"""Test all valid issue statuses."""
|
||||
for status in IssueStatus:
|
||||
issue = IssueCreate(
|
||||
project_id=valid_uuid,
|
||||
title=f"Issue {status.value}",
|
||||
status=status,
|
||||
)
|
||||
assert issue.status == status
|
||||
|
||||
def test_invalid_issue_status(self, valid_uuid):
|
||||
"""Test that invalid issue status raises ValidationError."""
|
||||
with pytest.raises(ValidationError):
|
||||
IssueCreate(
|
||||
project_id=valid_uuid,
|
||||
title="Test Issue",
|
||||
status="invalid", # type: ignore
|
||||
)
|
||||
|
||||
def test_valid_issue_priorities(self, valid_uuid):
|
||||
"""Test all valid issue priorities."""
|
||||
for priority in IssuePriority:
|
||||
issue = IssueCreate(
|
||||
project_id=valid_uuid,
|
||||
title=f"Issue {priority.value}",
|
||||
priority=priority,
|
||||
)
|
||||
assert issue.priority == priority
|
||||
|
||||
def test_invalid_issue_priority(self, valid_uuid):
|
||||
"""Test that invalid issue priority raises ValidationError."""
|
||||
with pytest.raises(ValidationError):
|
||||
IssueCreate(
|
||||
project_id=valid_uuid,
|
||||
title="Test Issue",
|
||||
priority="invalid", # type: ignore
|
||||
)
|
||||
|
||||
def test_valid_sync_statuses(self):
|
||||
"""Test all valid sync statuses in update."""
|
||||
for status in SyncStatus:
|
||||
update = IssueUpdate(sync_status=status)
|
||||
assert update.sync_status == status
|
||||
300
backend/tests/schemas/syndarix/test_project_schemas.py
Normal file
300
backend/tests/schemas/syndarix/test_project_schemas.py
Normal file
@@ -0,0 +1,300 @@
|
||||
# tests/schemas/syndarix/test_project_schemas.py
|
||||
"""
|
||||
Tests for Project schema validation.
|
||||
"""
|
||||
|
||||
import uuid
|
||||
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
|
||||
from app.schemas.syndarix import (
|
||||
AutonomyLevel,
|
||||
ProjectCreate,
|
||||
ProjectStatus,
|
||||
ProjectUpdate,
|
||||
)
|
||||
|
||||
|
||||
class TestProjectCreateValidation:
|
||||
"""Tests for ProjectCreate schema validation."""
|
||||
|
||||
def test_valid_project_create(self, valid_project_data):
|
||||
"""Test creating project with valid data."""
|
||||
project = ProjectCreate(**valid_project_data)
|
||||
|
||||
assert project.name == "Test Project"
|
||||
assert project.slug == "test-project"
|
||||
assert project.description == "A test project"
|
||||
|
||||
def test_project_create_defaults(self):
|
||||
"""Test that defaults are applied correctly."""
|
||||
project = ProjectCreate(
|
||||
name="Minimal Project",
|
||||
slug="minimal-project",
|
||||
)
|
||||
|
||||
assert project.autonomy_level == AutonomyLevel.MILESTONE
|
||||
assert project.status == ProjectStatus.ACTIVE
|
||||
assert project.settings == {}
|
||||
assert project.owner_id is None
|
||||
|
||||
def test_project_create_with_owner(self, valid_project_data):
|
||||
"""Test creating project with owner ID."""
|
||||
owner_id = uuid.uuid4()
|
||||
project = ProjectCreate(
|
||||
**valid_project_data,
|
||||
owner_id=owner_id,
|
||||
)
|
||||
|
||||
assert project.owner_id == owner_id
|
||||
|
||||
def test_project_create_name_empty_fails(self):
|
||||
"""Test that empty name raises ValidationError."""
|
||||
with pytest.raises(ValidationError) as exc_info:
|
||||
ProjectCreate(
|
||||
name="",
|
||||
slug="valid-slug",
|
||||
)
|
||||
|
||||
errors = exc_info.value.errors()
|
||||
assert any("name" in str(e) for e in errors)
|
||||
|
||||
def test_project_create_name_whitespace_only_fails(self):
|
||||
"""Test that whitespace-only name raises ValidationError."""
|
||||
with pytest.raises(ValidationError) as exc_info:
|
||||
ProjectCreate(
|
||||
name=" ",
|
||||
slug="valid-slug",
|
||||
)
|
||||
|
||||
errors = exc_info.value.errors()
|
||||
assert any("name" in str(e) for e in errors)
|
||||
|
||||
def test_project_create_name_stripped(self):
|
||||
"""Test that name is stripped of leading/trailing whitespace."""
|
||||
project = ProjectCreate(
|
||||
name=" Padded Name ",
|
||||
slug="padded-slug",
|
||||
)
|
||||
|
||||
assert project.name == "Padded Name"
|
||||
|
||||
def test_project_create_slug_required(self):
|
||||
"""Test that slug is required for create."""
|
||||
with pytest.raises(ValidationError) as exc_info:
|
||||
ProjectCreate(name="No Slug Project")
|
||||
|
||||
errors = exc_info.value.errors()
|
||||
assert any("slug" in str(e).lower() for e in errors)
|
||||
|
||||
|
||||
class TestProjectSlugValidation:
|
||||
"""Tests for Project slug validation."""
|
||||
|
||||
def test_valid_slugs(self):
|
||||
"""Test various valid slug formats."""
|
||||
valid_slugs = [
|
||||
"simple",
|
||||
"with-hyphens",
|
||||
"has123numbers",
|
||||
"mix3d-with-hyphen5",
|
||||
"a", # Single character
|
||||
]
|
||||
|
||||
for slug in valid_slugs:
|
||||
project = ProjectCreate(
|
||||
name="Test Project",
|
||||
slug=slug,
|
||||
)
|
||||
assert project.slug == slug
|
||||
|
||||
def test_invalid_slug_uppercase(self):
|
||||
"""Test that uppercase letters in slug raise ValidationError."""
|
||||
with pytest.raises(ValidationError) as exc_info:
|
||||
ProjectCreate(
|
||||
name="Test Project",
|
||||
slug="Invalid-Uppercase",
|
||||
)
|
||||
|
||||
errors = exc_info.value.errors()
|
||||
assert any("slug" in str(e).lower() for e in errors)
|
||||
|
||||
def test_invalid_slug_special_chars(self):
|
||||
"""Test that special characters in slug raise ValidationError."""
|
||||
invalid_slugs = [
|
||||
"has_underscore",
|
||||
"has.dot",
|
||||
"has@symbol",
|
||||
"has space",
|
||||
"has/slash",
|
||||
]
|
||||
|
||||
for slug in invalid_slugs:
|
||||
with pytest.raises(ValidationError):
|
||||
ProjectCreate(
|
||||
name="Test Project",
|
||||
slug=slug,
|
||||
)
|
||||
|
||||
def test_invalid_slug_starts_with_hyphen(self):
|
||||
"""Test that slug starting with hyphen raises ValidationError."""
|
||||
with pytest.raises(ValidationError) as exc_info:
|
||||
ProjectCreate(
|
||||
name="Test Project",
|
||||
slug="-invalid-start",
|
||||
)
|
||||
|
||||
errors = exc_info.value.errors()
|
||||
assert any("hyphen" in str(e).lower() for e in errors)
|
||||
|
||||
def test_invalid_slug_ends_with_hyphen(self):
|
||||
"""Test that slug ending with hyphen raises ValidationError."""
|
||||
with pytest.raises(ValidationError) as exc_info:
|
||||
ProjectCreate(
|
||||
name="Test Project",
|
||||
slug="invalid-end-",
|
||||
)
|
||||
|
||||
errors = exc_info.value.errors()
|
||||
assert any("hyphen" in str(e).lower() for e in errors)
|
||||
|
||||
def test_invalid_slug_consecutive_hyphens(self):
|
||||
"""Test that consecutive hyphens in slug raise ValidationError."""
|
||||
with pytest.raises(ValidationError) as exc_info:
|
||||
ProjectCreate(
|
||||
name="Test Project",
|
||||
slug="invalid--consecutive",
|
||||
)
|
||||
|
||||
errors = exc_info.value.errors()
|
||||
assert any("consecutive" in str(e).lower() for e in errors)
|
||||
|
||||
|
||||
class TestProjectUpdateValidation:
|
||||
"""Tests for ProjectUpdate schema validation."""
|
||||
|
||||
def test_project_update_partial(self):
|
||||
"""Test updating only some fields."""
|
||||
update = ProjectUpdate(
|
||||
name="Updated Name",
|
||||
)
|
||||
|
||||
assert update.name == "Updated Name"
|
||||
assert update.slug is None
|
||||
assert update.description is None
|
||||
assert update.autonomy_level is None
|
||||
assert update.status is None
|
||||
|
||||
def test_project_update_all_fields(self):
|
||||
"""Test updating all fields."""
|
||||
owner_id = uuid.uuid4()
|
||||
update = ProjectUpdate(
|
||||
name="Updated Name",
|
||||
slug="updated-slug",
|
||||
description="Updated description",
|
||||
autonomy_level=AutonomyLevel.AUTONOMOUS,
|
||||
status=ProjectStatus.PAUSED,
|
||||
settings={"key": "value"},
|
||||
owner_id=owner_id,
|
||||
)
|
||||
|
||||
assert update.name == "Updated Name"
|
||||
assert update.slug == "updated-slug"
|
||||
assert update.autonomy_level == AutonomyLevel.AUTONOMOUS
|
||||
assert update.status == ProjectStatus.PAUSED
|
||||
|
||||
def test_project_update_empty_name_fails(self):
|
||||
"""Test that empty name in update raises ValidationError."""
|
||||
with pytest.raises(ValidationError) as exc_info:
|
||||
ProjectUpdate(name="")
|
||||
|
||||
errors = exc_info.value.errors()
|
||||
assert any("name" in str(e) for e in errors)
|
||||
|
||||
def test_project_update_slug_validation(self):
|
||||
"""Test that slug validation applies to updates too."""
|
||||
with pytest.raises(ValidationError):
|
||||
ProjectUpdate(slug="Invalid-Slug")
|
||||
|
||||
|
||||
class TestProjectEnums:
|
||||
"""Tests for Project enum validation."""
|
||||
|
||||
def test_valid_autonomy_levels(self):
|
||||
"""Test all valid autonomy levels."""
|
||||
for level in AutonomyLevel:
|
||||
# Replace underscores with hyphens for valid slug
|
||||
slug_suffix = level.value.replace("_", "-")
|
||||
project = ProjectCreate(
|
||||
name="Test Project",
|
||||
slug=f"project-{slug_suffix}",
|
||||
autonomy_level=level,
|
||||
)
|
||||
assert project.autonomy_level == level
|
||||
|
||||
def test_invalid_autonomy_level(self):
|
||||
"""Test that invalid autonomy level raises ValidationError."""
|
||||
with pytest.raises(ValidationError):
|
||||
ProjectCreate(
|
||||
name="Test Project",
|
||||
slug="invalid-autonomy",
|
||||
autonomy_level="invalid", # type: ignore
|
||||
)
|
||||
|
||||
def test_valid_project_statuses(self):
|
||||
"""Test all valid project statuses."""
|
||||
for status in ProjectStatus:
|
||||
project = ProjectCreate(
|
||||
name="Test Project",
|
||||
slug=f"project-status-{status.value}",
|
||||
status=status,
|
||||
)
|
||||
assert project.status == status
|
||||
|
||||
def test_invalid_project_status(self):
|
||||
"""Test that invalid project status raises ValidationError."""
|
||||
with pytest.raises(ValidationError):
|
||||
ProjectCreate(
|
||||
name="Test Project",
|
||||
slug="invalid-status",
|
||||
status="invalid", # type: ignore
|
||||
)
|
||||
|
||||
|
||||
class TestProjectSettings:
|
||||
"""Tests for Project settings validation."""
|
||||
|
||||
def test_settings_empty_dict(self):
|
||||
"""Test that empty settings dict is valid."""
|
||||
project = ProjectCreate(
|
||||
name="Test Project",
|
||||
slug="empty-settings",
|
||||
settings={},
|
||||
)
|
||||
assert project.settings == {}
|
||||
|
||||
def test_settings_complex_structure(self):
|
||||
"""Test that complex settings structure is valid."""
|
||||
complex_settings = {
|
||||
"mcp_servers": ["gitea", "slack"],
|
||||
"webhooks": {
|
||||
"on_issue_created": "https://example.com",
|
||||
},
|
||||
"flags": True,
|
||||
"count": 42,
|
||||
}
|
||||
project = ProjectCreate(
|
||||
name="Test Project",
|
||||
slug="complex-settings",
|
||||
settings=complex_settings,
|
||||
)
|
||||
assert project.settings == complex_settings
|
||||
|
||||
def test_settings_default_to_empty_dict(self):
|
||||
"""Test that settings default to empty dict when not provided."""
|
||||
project = ProjectCreate(
|
||||
name="Test Project",
|
||||
slug="default-settings",
|
||||
)
|
||||
assert project.settings == {}
|
||||
366
backend/tests/schemas/syndarix/test_sprint_schemas.py
Normal file
366
backend/tests/schemas/syndarix/test_sprint_schemas.py
Normal file
@@ -0,0 +1,366 @@
|
||||
# tests/schemas/syndarix/test_sprint_schemas.py
|
||||
"""
|
||||
Tests for Sprint schema validation.
|
||||
"""
|
||||
|
||||
from datetime import date, timedelta
|
||||
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
|
||||
from app.schemas.syndarix import (
|
||||
SprintCreate,
|
||||
SprintStatus,
|
||||
SprintUpdate,
|
||||
)
|
||||
|
||||
|
||||
class TestSprintCreateValidation:
|
||||
"""Tests for SprintCreate schema validation."""
|
||||
|
||||
def test_valid_sprint_create(self, valid_sprint_data):
|
||||
"""Test creating sprint with valid data."""
|
||||
sprint = SprintCreate(**valid_sprint_data)
|
||||
|
||||
assert sprint.name == "Sprint 1"
|
||||
assert sprint.number == 1
|
||||
assert sprint.start_date is not None
|
||||
assert sprint.end_date is not None
|
||||
|
||||
def test_sprint_create_defaults(self, valid_sprint_data):
|
||||
"""Test that defaults are applied correctly."""
|
||||
sprint = SprintCreate(**valid_sprint_data)
|
||||
|
||||
assert sprint.status == SprintStatus.PLANNED
|
||||
assert sprint.goal is None
|
||||
assert sprint.planned_points is None
|
||||
assert sprint.completed_points is None
|
||||
|
||||
def test_sprint_create_with_all_fields(self, valid_uuid):
|
||||
"""Test creating sprint with all optional fields."""
|
||||
today = date.today()
|
||||
|
||||
sprint = SprintCreate(
|
||||
project_id=valid_uuid,
|
||||
name="Full Sprint",
|
||||
number=5,
|
||||
goal="Complete all features",
|
||||
start_date=today,
|
||||
end_date=today + timedelta(days=14),
|
||||
status=SprintStatus.PLANNED,
|
||||
planned_points=21,
|
||||
completed_points=0,
|
||||
)
|
||||
|
||||
assert sprint.name == "Full Sprint"
|
||||
assert sprint.number == 5
|
||||
assert sprint.goal == "Complete all features"
|
||||
assert sprint.planned_points == 21
|
||||
|
||||
def test_sprint_create_name_empty_fails(self, valid_uuid):
|
||||
"""Test that empty name raises ValidationError."""
|
||||
today = date.today()
|
||||
|
||||
with pytest.raises(ValidationError) as exc_info:
|
||||
SprintCreate(
|
||||
project_id=valid_uuid,
|
||||
name="",
|
||||
number=1,
|
||||
start_date=today,
|
||||
end_date=today + timedelta(days=14),
|
||||
)
|
||||
|
||||
errors = exc_info.value.errors()
|
||||
assert any("name" in str(e) for e in errors)
|
||||
|
||||
def test_sprint_create_name_whitespace_only_fails(self, valid_uuid):
|
||||
"""Test that whitespace-only name raises ValidationError."""
|
||||
today = date.today()
|
||||
|
||||
with pytest.raises(ValidationError) as exc_info:
|
||||
SprintCreate(
|
||||
project_id=valid_uuid,
|
||||
name=" ",
|
||||
number=1,
|
||||
start_date=today,
|
||||
end_date=today + timedelta(days=14),
|
||||
)
|
||||
|
||||
errors = exc_info.value.errors()
|
||||
assert any("name" in str(e) for e in errors)
|
||||
|
||||
def test_sprint_create_name_stripped(self, valid_uuid):
|
||||
"""Test that name is stripped."""
|
||||
today = date.today()
|
||||
|
||||
sprint = SprintCreate(
|
||||
project_id=valid_uuid,
|
||||
name=" Padded Sprint Name ",
|
||||
number=1,
|
||||
start_date=today,
|
||||
end_date=today + timedelta(days=14),
|
||||
)
|
||||
|
||||
assert sprint.name == "Padded Sprint Name"
|
||||
|
||||
def test_sprint_create_project_id_required(self):
|
||||
"""Test that project_id is required."""
|
||||
today = date.today()
|
||||
|
||||
with pytest.raises(ValidationError) as exc_info:
|
||||
SprintCreate(
|
||||
name="Sprint 1",
|
||||
number=1,
|
||||
start_date=today,
|
||||
end_date=today + timedelta(days=14),
|
||||
)
|
||||
|
||||
errors = exc_info.value.errors()
|
||||
assert any("project_id" in str(e).lower() for e in errors)
|
||||
|
||||
|
||||
class TestSprintNumberValidation:
|
||||
"""Tests for Sprint number validation."""
|
||||
|
||||
def test_sprint_number_valid(self, valid_uuid):
|
||||
"""Test valid sprint numbers."""
|
||||
today = date.today()
|
||||
|
||||
for number in [1, 10, 100]:
|
||||
sprint = SprintCreate(
|
||||
project_id=valid_uuid,
|
||||
name=f"Sprint {number}",
|
||||
number=number,
|
||||
start_date=today,
|
||||
end_date=today + timedelta(days=14),
|
||||
)
|
||||
assert sprint.number == number
|
||||
|
||||
def test_sprint_number_zero_fails(self, valid_uuid):
|
||||
"""Test that sprint number 0 raises ValidationError."""
|
||||
today = date.today()
|
||||
|
||||
with pytest.raises(ValidationError) as exc_info:
|
||||
SprintCreate(
|
||||
project_id=valid_uuid,
|
||||
name="Sprint Zero",
|
||||
number=0,
|
||||
start_date=today,
|
||||
end_date=today + timedelta(days=14),
|
||||
)
|
||||
|
||||
errors = exc_info.value.errors()
|
||||
assert any("number" in str(e).lower() for e in errors)
|
||||
|
||||
def test_sprint_number_negative_fails(self, valid_uuid):
|
||||
"""Test that negative sprint number raises ValidationError."""
|
||||
today = date.today()
|
||||
|
||||
with pytest.raises(ValidationError) as exc_info:
|
||||
SprintCreate(
|
||||
project_id=valid_uuid,
|
||||
name="Negative Sprint",
|
||||
number=-1,
|
||||
start_date=today,
|
||||
end_date=today + timedelta(days=14),
|
||||
)
|
||||
|
||||
errors = exc_info.value.errors()
|
||||
assert any("number" in str(e).lower() for e in errors)
|
||||
|
||||
|
||||
class TestSprintDateValidation:
|
||||
"""Tests for Sprint date validation."""
|
||||
|
||||
def test_valid_date_range(self, valid_uuid):
|
||||
"""Test valid date range (end > start)."""
|
||||
today = date.today()
|
||||
|
||||
sprint = SprintCreate(
|
||||
project_id=valid_uuid,
|
||||
name="Sprint 1",
|
||||
number=1,
|
||||
start_date=today,
|
||||
end_date=today + timedelta(days=14),
|
||||
)
|
||||
|
||||
assert sprint.end_date > sprint.start_date
|
||||
|
||||
def test_same_day_sprint(self, valid_uuid):
|
||||
"""Test that same day sprint is valid."""
|
||||
today = date.today()
|
||||
|
||||
sprint = SprintCreate(
|
||||
project_id=valid_uuid,
|
||||
name="One Day Sprint",
|
||||
number=1,
|
||||
start_date=today,
|
||||
end_date=today, # Same day is allowed
|
||||
)
|
||||
|
||||
assert sprint.start_date == sprint.end_date
|
||||
|
||||
def test_end_before_start_fails(self, valid_uuid):
|
||||
"""Test that end date before start date raises ValidationError."""
|
||||
today = date.today()
|
||||
|
||||
with pytest.raises(ValidationError) as exc_info:
|
||||
SprintCreate(
|
||||
project_id=valid_uuid,
|
||||
name="Invalid Sprint",
|
||||
number=1,
|
||||
start_date=today,
|
||||
end_date=today - timedelta(days=1), # Before start
|
||||
)
|
||||
|
||||
errors = exc_info.value.errors()
|
||||
assert len(errors) > 0
|
||||
|
||||
|
||||
class TestSprintPointsValidation:
|
||||
"""Tests for Sprint points validation."""
|
||||
|
||||
def test_valid_planned_points(self, valid_uuid):
|
||||
"""Test valid planned_points values."""
|
||||
today = date.today()
|
||||
|
||||
for points in [0, 1, 21, 100]:
|
||||
sprint = SprintCreate(
|
||||
project_id=valid_uuid,
|
||||
name=f"Sprint {points}",
|
||||
number=1,
|
||||
start_date=today,
|
||||
end_date=today + timedelta(days=14),
|
||||
planned_points=points,
|
||||
)
|
||||
assert sprint.planned_points == points
|
||||
|
||||
def test_planned_points_negative_fails(self, valid_uuid):
|
||||
"""Test that negative planned_points raises ValidationError."""
|
||||
today = date.today()
|
||||
|
||||
with pytest.raises(ValidationError) as exc_info:
|
||||
SprintCreate(
|
||||
project_id=valid_uuid,
|
||||
name="Negative Points Sprint",
|
||||
number=1,
|
||||
start_date=today,
|
||||
end_date=today + timedelta(days=14),
|
||||
planned_points=-1,
|
||||
)
|
||||
|
||||
errors = exc_info.value.errors()
|
||||
assert any("planned_points" in str(e).lower() for e in errors)
|
||||
|
||||
def test_valid_completed_points(self, valid_uuid):
|
||||
"""Test valid completed_points values."""
|
||||
today = date.today()
|
||||
|
||||
for points in [0, 5, 21]:
|
||||
sprint = SprintCreate(
|
||||
project_id=valid_uuid,
|
||||
name=f"Sprint {points}",
|
||||
number=1,
|
||||
start_date=today,
|
||||
end_date=today + timedelta(days=14),
|
||||
completed_points=points,
|
||||
)
|
||||
assert sprint.completed_points == points
|
||||
|
||||
def test_completed_points_negative_fails(self, valid_uuid):
|
||||
"""Test that negative completed_points raises ValidationError."""
|
||||
today = date.today()
|
||||
|
||||
with pytest.raises(ValidationError) as exc_info:
|
||||
SprintCreate(
|
||||
project_id=valid_uuid,
|
||||
name="Negative Completed Sprint",
|
||||
number=1,
|
||||
start_date=today,
|
||||
end_date=today + timedelta(days=14),
|
||||
completed_points=-1,
|
||||
)
|
||||
|
||||
errors = exc_info.value.errors()
|
||||
assert any("completed_points" in str(e).lower() for e in errors)
|
||||
|
||||
|
||||
class TestSprintUpdateValidation:
|
||||
"""Tests for SprintUpdate schema validation."""
|
||||
|
||||
def test_sprint_update_partial(self):
|
||||
"""Test updating only some fields."""
|
||||
update = SprintUpdate(
|
||||
name="Updated Name",
|
||||
)
|
||||
|
||||
assert update.name == "Updated Name"
|
||||
assert update.goal is None
|
||||
assert update.start_date is None
|
||||
assert update.end_date is None
|
||||
|
||||
def test_sprint_update_all_fields(self):
|
||||
"""Test updating all fields."""
|
||||
today = date.today()
|
||||
|
||||
update = SprintUpdate(
|
||||
name="Updated Name",
|
||||
goal="Updated goal",
|
||||
start_date=today,
|
||||
end_date=today + timedelta(days=21),
|
||||
status=SprintStatus.ACTIVE,
|
||||
planned_points=34,
|
||||
completed_points=20,
|
||||
)
|
||||
|
||||
assert update.name == "Updated Name"
|
||||
assert update.goal == "Updated goal"
|
||||
assert update.status == SprintStatus.ACTIVE
|
||||
assert update.planned_points == 34
|
||||
|
||||
def test_sprint_update_empty_name_fails(self):
|
||||
"""Test that empty name in update raises ValidationError."""
|
||||
with pytest.raises(ValidationError) as exc_info:
|
||||
SprintUpdate(name="")
|
||||
|
||||
errors = exc_info.value.errors()
|
||||
assert any("name" in str(e) for e in errors)
|
||||
|
||||
def test_sprint_update_name_stripped(self):
|
||||
"""Test that name is stripped in updates."""
|
||||
update = SprintUpdate(name=" Updated ")
|
||||
|
||||
assert update.name == "Updated"
|
||||
|
||||
|
||||
class TestSprintStatusEnum:
|
||||
"""Tests for SprintStatus enum validation."""
|
||||
|
||||
def test_valid_sprint_statuses(self, valid_uuid):
|
||||
"""Test all valid sprint statuses."""
|
||||
today = date.today()
|
||||
|
||||
for status in SprintStatus:
|
||||
sprint = SprintCreate(
|
||||
project_id=valid_uuid,
|
||||
name=f"Sprint {status.value}",
|
||||
number=1,
|
||||
start_date=today,
|
||||
end_date=today + timedelta(days=14),
|
||||
status=status,
|
||||
)
|
||||
assert sprint.status == status
|
||||
|
||||
def test_invalid_sprint_status(self, valid_uuid):
|
||||
"""Test that invalid sprint status raises ValidationError."""
|
||||
today = date.today()
|
||||
|
||||
with pytest.raises(ValidationError):
|
||||
SprintCreate(
|
||||
project_id=valid_uuid,
|
||||
name="Invalid Status Sprint",
|
||||
number=1,
|
||||
start_date=today,
|
||||
end_date=today + timedelta(days=14),
|
||||
status="invalid", # type: ignore
|
||||
)
|
||||
1035
backend/tests/services/test_event_bus.py
Normal file
1035
backend/tests/services/test_event_bus.py
Normal file
File diff suppressed because it is too large
Load Diff
11
backend/tests/tasks/__init__.py
Normal file
11
backend/tests/tasks/__init__.py
Normal file
@@ -0,0 +1,11 @@
|
||||
# tests/tasks/__init__.py
|
||||
"""
|
||||
Tests for Celery background tasks.
|
||||
|
||||
This module tests the Celery configuration and all task modules:
|
||||
- agent: Agent execution tasks
|
||||
- git: Git operation tasks
|
||||
- sync: Issue synchronization tasks
|
||||
- workflow: Workflow state management tasks
|
||||
- cost: Cost tracking and aggregation tasks
|
||||
"""
|
||||
358
backend/tests/tasks/test_agent_tasks.py
Normal file
358
backend/tests/tasks/test_agent_tasks.py
Normal file
@@ -0,0 +1,358 @@
|
||||
# tests/tasks/test_agent_tasks.py
|
||||
"""
|
||||
Tests for agent execution tasks.
|
||||
|
||||
These tests verify:
|
||||
- Task signatures are correctly defined
|
||||
- Tasks are bound (have access to self)
|
||||
- Tasks return expected structure
|
||||
- Tasks handle various input scenarios
|
||||
|
||||
Note: These tests mock actual execution since they would require
|
||||
LLM calls and database access in production.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import patch, MagicMock
|
||||
import uuid
|
||||
|
||||
|
||||
class TestRunAgentStepTask:
|
||||
"""Tests for the run_agent_step task."""
|
||||
|
||||
def test_run_agent_step_task_exists(self):
|
||||
"""Test that run_agent_step task is registered."""
|
||||
from app.celery_app import celery_app
|
||||
import app.tasks.agent # noqa: F401
|
||||
|
||||
assert "app.tasks.agent.run_agent_step" in celery_app.tasks
|
||||
|
||||
def test_run_agent_step_is_bound_task(self):
|
||||
"""Test that run_agent_step is a bound task (has access to self)."""
|
||||
from app.tasks.agent import run_agent_step
|
||||
|
||||
# Bound tasks have __bound__=True, which means they receive 'self' as first arg
|
||||
assert run_agent_step.__bound__ is True
|
||||
|
||||
def test_run_agent_step_has_correct_name(self):
|
||||
"""Test that run_agent_step has the correct task name."""
|
||||
from app.tasks.agent import run_agent_step
|
||||
|
||||
assert run_agent_step.name == "app.tasks.agent.run_agent_step"
|
||||
|
||||
def test_run_agent_step_returns_expected_structure(self):
|
||||
"""Test that run_agent_step returns the expected result structure."""
|
||||
from app.tasks.agent import run_agent_step
|
||||
|
||||
agent_instance_id = str(uuid.uuid4())
|
||||
context = {"messages": [], "tools": []}
|
||||
|
||||
# Call the task directly (synchronously for testing)
|
||||
result = run_agent_step(agent_instance_id, context)
|
||||
|
||||
assert isinstance(result, dict)
|
||||
assert "status" in result
|
||||
assert "agent_instance_id" in result
|
||||
assert result["agent_instance_id"] == agent_instance_id
|
||||
|
||||
def test_run_agent_step_with_empty_context(self):
|
||||
"""Test that run_agent_step handles empty context."""
|
||||
from app.tasks.agent import run_agent_step
|
||||
|
||||
agent_instance_id = str(uuid.uuid4())
|
||||
context = {}
|
||||
|
||||
result = run_agent_step(agent_instance_id, context)
|
||||
|
||||
assert result["status"] == "pending"
|
||||
assert result["agent_instance_id"] == agent_instance_id
|
||||
|
||||
def test_run_agent_step_with_complex_context(self):
|
||||
"""Test that run_agent_step handles complex context data."""
|
||||
from app.tasks.agent import run_agent_step
|
||||
|
||||
agent_instance_id = str(uuid.uuid4())
|
||||
context = {
|
||||
"messages": [
|
||||
{"role": "user", "content": "Create a new feature"},
|
||||
{"role": "assistant", "content": "I will create the feature."},
|
||||
],
|
||||
"tools": ["create_file", "edit_file", "run_tests"],
|
||||
"state": {"current_step": 3, "max_steps": 10},
|
||||
"metadata": {"project_id": str(uuid.uuid4())},
|
||||
}
|
||||
|
||||
result = run_agent_step(agent_instance_id, context)
|
||||
|
||||
assert result["status"] == "pending"
|
||||
assert result["agent_instance_id"] == agent_instance_id
|
||||
|
||||
|
||||
class TestSpawnAgentTask:
|
||||
"""Tests for the spawn_agent task."""
|
||||
|
||||
def test_spawn_agent_task_exists(self):
|
||||
"""Test that spawn_agent task is registered."""
|
||||
from app.celery_app import celery_app
|
||||
import app.tasks.agent # noqa: F401
|
||||
|
||||
assert "app.tasks.agent.spawn_agent" in celery_app.tasks
|
||||
|
||||
def test_spawn_agent_is_bound_task(self):
|
||||
"""Test that spawn_agent is a bound task."""
|
||||
from app.tasks.agent import spawn_agent
|
||||
|
||||
assert spawn_agent.__bound__ is True
|
||||
|
||||
def test_spawn_agent_has_correct_name(self):
|
||||
"""Test that spawn_agent has the correct task name."""
|
||||
from app.tasks.agent import spawn_agent
|
||||
|
||||
assert spawn_agent.name == "app.tasks.agent.spawn_agent"
|
||||
|
||||
def test_spawn_agent_returns_expected_structure(self):
|
||||
"""Test that spawn_agent returns the expected result structure."""
|
||||
from app.tasks.agent import spawn_agent
|
||||
|
||||
agent_type_id = str(uuid.uuid4())
|
||||
project_id = str(uuid.uuid4())
|
||||
initial_context = {"goal": "Implement user story"}
|
||||
|
||||
result = spawn_agent(agent_type_id, project_id, initial_context)
|
||||
|
||||
assert isinstance(result, dict)
|
||||
assert "status" in result
|
||||
assert "agent_type_id" in result
|
||||
assert "project_id" in result
|
||||
assert result["status"] == "spawned"
|
||||
assert result["agent_type_id"] == agent_type_id
|
||||
assert result["project_id"] == project_id
|
||||
|
||||
def test_spawn_agent_with_empty_initial_context(self):
|
||||
"""Test that spawn_agent handles empty initial context."""
|
||||
from app.tasks.agent import spawn_agent
|
||||
|
||||
agent_type_id = str(uuid.uuid4())
|
||||
project_id = str(uuid.uuid4())
|
||||
initial_context = {}
|
||||
|
||||
result = spawn_agent(agent_type_id, project_id, initial_context)
|
||||
|
||||
assert result["status"] == "spawned"
|
||||
|
||||
def test_spawn_agent_with_detailed_initial_context(self):
|
||||
"""Test that spawn_agent handles detailed initial context."""
|
||||
from app.tasks.agent import spawn_agent
|
||||
|
||||
agent_type_id = str(uuid.uuid4())
|
||||
project_id = str(uuid.uuid4())
|
||||
initial_context = {
|
||||
"goal": "Implement authentication",
|
||||
"constraints": ["Must use JWT", "Must support MFA"],
|
||||
"assigned_issues": [str(uuid.uuid4()), str(uuid.uuid4())],
|
||||
"autonomy_level": "MILESTONE",
|
||||
}
|
||||
|
||||
result = spawn_agent(agent_type_id, project_id, initial_context)
|
||||
|
||||
assert result["status"] == "spawned"
|
||||
assert result["agent_type_id"] == agent_type_id
|
||||
assert result["project_id"] == project_id
|
||||
|
||||
|
||||
class TestTerminateAgentTask:
|
||||
"""Tests for the terminate_agent task."""
|
||||
|
||||
def test_terminate_agent_task_exists(self):
|
||||
"""Test that terminate_agent task is registered."""
|
||||
from app.celery_app import celery_app
|
||||
import app.tasks.agent # noqa: F401
|
||||
|
||||
assert "app.tasks.agent.terminate_agent" in celery_app.tasks
|
||||
|
||||
def test_terminate_agent_is_bound_task(self):
|
||||
"""Test that terminate_agent is a bound task."""
|
||||
from app.tasks.agent import terminate_agent
|
||||
|
||||
assert terminate_agent.__bound__ is True
|
||||
|
||||
def test_terminate_agent_has_correct_name(self):
|
||||
"""Test that terminate_agent has the correct task name."""
|
||||
from app.tasks.agent import terminate_agent
|
||||
|
||||
assert terminate_agent.name == "app.tasks.agent.terminate_agent"
|
||||
|
||||
def test_terminate_agent_returns_expected_structure(self):
|
||||
"""Test that terminate_agent returns the expected result structure."""
|
||||
from app.tasks.agent import terminate_agent
|
||||
|
||||
agent_instance_id = str(uuid.uuid4())
|
||||
reason = "Task completed successfully"
|
||||
|
||||
result = terminate_agent(agent_instance_id, reason)
|
||||
|
||||
assert isinstance(result, dict)
|
||||
assert "status" in result
|
||||
assert "agent_instance_id" in result
|
||||
assert result["status"] == "terminated"
|
||||
assert result["agent_instance_id"] == agent_instance_id
|
||||
|
||||
def test_terminate_agent_with_error_reason(self):
|
||||
"""Test that terminate_agent handles error termination reasons."""
|
||||
from app.tasks.agent import terminate_agent
|
||||
|
||||
agent_instance_id = str(uuid.uuid4())
|
||||
reason = "Error: Budget limit exceeded"
|
||||
|
||||
result = terminate_agent(agent_instance_id, reason)
|
||||
|
||||
assert result["status"] == "terminated"
|
||||
assert result["agent_instance_id"] == agent_instance_id
|
||||
|
||||
def test_terminate_agent_with_empty_reason(self):
|
||||
"""Test that terminate_agent handles empty reason string."""
|
||||
from app.tasks.agent import terminate_agent
|
||||
|
||||
agent_instance_id = str(uuid.uuid4())
|
||||
reason = ""
|
||||
|
||||
result = terminate_agent(agent_instance_id, reason)
|
||||
|
||||
assert result["status"] == "terminated"
|
||||
|
||||
|
||||
class TestAgentTaskRouting:
|
||||
"""Tests for agent task queue routing."""
|
||||
|
||||
def test_agent_tasks_should_route_to_agent_queue(self):
|
||||
"""Test that agent tasks are configured to route to 'agent' queue."""
|
||||
from app.celery_app import celery_app
|
||||
|
||||
routes = celery_app.conf.task_routes
|
||||
agent_route = routes.get("app.tasks.agent.*")
|
||||
|
||||
assert agent_route is not None
|
||||
assert agent_route["queue"] == "agent"
|
||||
|
||||
def test_run_agent_step_routing(self):
|
||||
"""Test that run_agent_step task routes to agent queue."""
|
||||
from app.tasks.agent import run_agent_step
|
||||
from app.celery_app import celery_app
|
||||
|
||||
# Get the routing configuration for this specific task
|
||||
task_name = run_agent_step.name
|
||||
routes = celery_app.conf.task_routes
|
||||
|
||||
# The task name matches the pattern "app.tasks.agent.*"
|
||||
assert task_name.startswith("app.tasks.agent.")
|
||||
assert "app.tasks.agent.*" in routes
|
||||
assert routes["app.tasks.agent.*"]["queue"] == "agent"
|
||||
|
||||
|
||||
class TestAgentTaskSignatures:
|
||||
"""Tests for agent task signature creation (for async invocation)."""
|
||||
|
||||
def test_run_agent_step_signature_creation(self):
|
||||
"""Test that run_agent_step signature can be created."""
|
||||
from app.tasks.agent import run_agent_step
|
||||
|
||||
agent_instance_id = str(uuid.uuid4())
|
||||
context = {"messages": []}
|
||||
|
||||
# Create a signature (delayed task)
|
||||
sig = run_agent_step.s(agent_instance_id, context)
|
||||
|
||||
assert sig is not None
|
||||
assert sig.args == (agent_instance_id, context)
|
||||
|
||||
def test_spawn_agent_signature_creation(self):
|
||||
"""Test that spawn_agent signature can be created."""
|
||||
from app.tasks.agent import spawn_agent
|
||||
|
||||
agent_type_id = str(uuid.uuid4())
|
||||
project_id = str(uuid.uuid4())
|
||||
initial_context = {}
|
||||
|
||||
sig = spawn_agent.s(agent_type_id, project_id, initial_context)
|
||||
|
||||
assert sig is not None
|
||||
assert sig.args == (agent_type_id, project_id, initial_context)
|
||||
|
||||
def test_terminate_agent_signature_creation(self):
|
||||
"""Test that terminate_agent signature can be created."""
|
||||
from app.tasks.agent import terminate_agent
|
||||
|
||||
agent_instance_id = str(uuid.uuid4())
|
||||
reason = "User requested termination"
|
||||
|
||||
sig = terminate_agent.s(agent_instance_id, reason)
|
||||
|
||||
assert sig is not None
|
||||
assert sig.args == (agent_instance_id, reason)
|
||||
|
||||
def test_agent_task_chain_creation(self):
|
||||
"""Test that agent tasks can be chained together."""
|
||||
from celery import chain
|
||||
from app.tasks.agent import spawn_agent, run_agent_step, terminate_agent
|
||||
|
||||
# Create a chain of tasks (this doesn't execute, just builds the chain)
|
||||
agent_type_id = str(uuid.uuid4())
|
||||
project_id = str(uuid.uuid4())
|
||||
agent_instance_id = str(uuid.uuid4())
|
||||
|
||||
# Note: In real usage, the chain would pass results between tasks
|
||||
workflow = chain(
|
||||
spawn_agent.s(agent_type_id, project_id, {}),
|
||||
# Further tasks would use the result from spawn_agent
|
||||
)
|
||||
|
||||
assert workflow is not None
|
||||
|
||||
|
||||
class TestAgentTaskLogging:
|
||||
"""Tests for agent task logging behavior."""
|
||||
|
||||
def test_run_agent_step_logs_execution(self):
|
||||
"""Test that run_agent_step logs when executed."""
|
||||
from app.tasks.agent import run_agent_step
|
||||
import logging
|
||||
|
||||
agent_instance_id = str(uuid.uuid4())
|
||||
context = {}
|
||||
|
||||
with patch("app.tasks.agent.logger") as mock_logger:
|
||||
run_agent_step(agent_instance_id, context)
|
||||
|
||||
mock_logger.info.assert_called_once()
|
||||
call_args = mock_logger.info.call_args[0][0]
|
||||
assert agent_instance_id in call_args
|
||||
|
||||
def test_spawn_agent_logs_execution(self):
|
||||
"""Test that spawn_agent logs when executed."""
|
||||
from app.tasks.agent import spawn_agent
|
||||
|
||||
agent_type_id = str(uuid.uuid4())
|
||||
project_id = str(uuid.uuid4())
|
||||
|
||||
with patch("app.tasks.agent.logger") as mock_logger:
|
||||
spawn_agent(agent_type_id, project_id, {})
|
||||
|
||||
mock_logger.info.assert_called_once()
|
||||
call_args = mock_logger.info.call_args[0][0]
|
||||
assert agent_type_id in call_args
|
||||
assert project_id in call_args
|
||||
|
||||
def test_terminate_agent_logs_execution(self):
|
||||
"""Test that terminate_agent logs when executed."""
|
||||
from app.tasks.agent import terminate_agent
|
||||
|
||||
agent_instance_id = str(uuid.uuid4())
|
||||
reason = "Test termination"
|
||||
|
||||
with patch("app.tasks.agent.logger") as mock_logger:
|
||||
terminate_agent(agent_instance_id, reason)
|
||||
|
||||
mock_logger.info.assert_called_once()
|
||||
call_args = mock_logger.info.call_args[0][0]
|
||||
assert agent_instance_id in call_args
|
||||
assert reason in call_args
|
||||
321
backend/tests/tasks/test_celery_config.py
Normal file
321
backend/tests/tasks/test_celery_config.py
Normal file
@@ -0,0 +1,321 @@
|
||||
# tests/tasks/test_celery_config.py
|
||||
"""
|
||||
Tests for Celery application configuration.
|
||||
|
||||
These tests verify:
|
||||
- Celery app is properly configured
|
||||
- Queue routing is correctly set up per ADR-003
|
||||
- Task discovery works for all task modules
|
||||
- Beat schedule is configured for periodic tasks
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import patch, MagicMock
|
||||
|
||||
|
||||
class TestCeleryAppConfiguration:
|
||||
"""Tests for the Celery application instance configuration."""
|
||||
|
||||
def test_celery_app_is_created_with_correct_name(self):
|
||||
"""Test that the Celery app is created with 'syndarix' as the name."""
|
||||
from app.celery_app import celery_app
|
||||
|
||||
assert celery_app.main == "syndarix"
|
||||
|
||||
def test_celery_app_uses_redis_broker(self):
|
||||
"""Test that Celery is configured to use Redis as the broker."""
|
||||
from app.celery_app import celery_app
|
||||
from app.core.config import settings
|
||||
|
||||
# The broker URL should match the settings
|
||||
assert celery_app.conf.broker_url == settings.celery_broker_url
|
||||
|
||||
def test_celery_app_uses_redis_backend(self):
|
||||
"""Test that Celery is configured to use Redis as the result backend."""
|
||||
from app.celery_app import celery_app
|
||||
from app.core.config import settings
|
||||
|
||||
assert celery_app.conf.result_backend == settings.celery_result_backend
|
||||
|
||||
def test_celery_uses_json_serialization(self):
|
||||
"""Test that Celery is configured to use JSON for serialization."""
|
||||
from app.celery_app import celery_app
|
||||
|
||||
assert celery_app.conf.task_serializer == "json"
|
||||
assert celery_app.conf.result_serializer == "json"
|
||||
assert "json" in celery_app.conf.accept_content
|
||||
|
||||
def test_celery_uses_utc_timezone(self):
|
||||
"""Test that Celery is configured to use UTC timezone."""
|
||||
from app.celery_app import celery_app
|
||||
|
||||
assert celery_app.conf.timezone == "UTC"
|
||||
assert celery_app.conf.enable_utc is True
|
||||
|
||||
def test_celery_has_late_ack_enabled(self):
|
||||
"""Test that late acknowledgment is enabled for task reliability."""
|
||||
from app.celery_app import celery_app
|
||||
|
||||
# Per ADR-003: Late ack for reliability
|
||||
assert celery_app.conf.task_acks_late is True
|
||||
assert celery_app.conf.task_reject_on_worker_lost is True
|
||||
|
||||
def test_celery_prefetch_multiplier_is_one(self):
|
||||
"""Test that worker prefetch is set to 1 for fair task distribution."""
|
||||
from app.celery_app import celery_app
|
||||
|
||||
assert celery_app.conf.worker_prefetch_multiplier == 1
|
||||
|
||||
def test_celery_result_expiration(self):
|
||||
"""Test that results expire after 24 hours."""
|
||||
from app.celery_app import celery_app
|
||||
|
||||
# 86400 seconds = 24 hours
|
||||
assert celery_app.conf.result_expires == 86400
|
||||
|
||||
def test_celery_has_time_limits_configured(self):
|
||||
"""Test that task time limits are configured per ADR-003."""
|
||||
from app.celery_app import celery_app
|
||||
|
||||
# 5 minutes soft limit, 10 minutes hard limit
|
||||
assert celery_app.conf.task_soft_time_limit == 300
|
||||
assert celery_app.conf.task_time_limit == 600
|
||||
|
||||
def test_celery_broker_connection_retry_enabled(self):
|
||||
"""Test that broker connection retry is enabled on startup."""
|
||||
from app.celery_app import celery_app
|
||||
|
||||
assert celery_app.conf.broker_connection_retry_on_startup is True
|
||||
|
||||
|
||||
class TestQueueRoutingConfiguration:
|
||||
"""Tests for Celery queue routing configuration per ADR-003."""
|
||||
|
||||
def test_default_queue_is_configured(self):
|
||||
"""Test that 'default' is set as the default queue."""
|
||||
from app.celery_app import celery_app
|
||||
|
||||
assert celery_app.conf.task_default_queue == "default"
|
||||
|
||||
def test_task_routes_are_configured(self):
|
||||
"""Test that task routes are properly configured."""
|
||||
from app.celery_app import celery_app
|
||||
|
||||
routes = celery_app.conf.task_routes
|
||||
assert routes is not None
|
||||
assert isinstance(routes, dict)
|
||||
|
||||
def test_agent_tasks_routed_to_agent_queue(self):
|
||||
"""Test that agent tasks are routed to the 'agent' queue."""
|
||||
from app.celery_app import celery_app
|
||||
|
||||
routes = celery_app.conf.task_routes
|
||||
assert "app.tasks.agent.*" in routes
|
||||
assert routes["app.tasks.agent.*"]["queue"] == "agent"
|
||||
|
||||
def test_git_tasks_routed_to_git_queue(self):
|
||||
"""Test that git tasks are routed to the 'git' queue."""
|
||||
from app.celery_app import celery_app
|
||||
|
||||
routes = celery_app.conf.task_routes
|
||||
assert "app.tasks.git.*" in routes
|
||||
assert routes["app.tasks.git.*"]["queue"] == "git"
|
||||
|
||||
def test_sync_tasks_routed_to_sync_queue(self):
|
||||
"""Test that sync tasks are routed to the 'sync' queue."""
|
||||
from app.celery_app import celery_app
|
||||
|
||||
routes = celery_app.conf.task_routes
|
||||
assert "app.tasks.sync.*" in routes
|
||||
assert routes["app.tasks.sync.*"]["queue"] == "sync"
|
||||
|
||||
def test_default_tasks_routed_to_default_queue(self):
|
||||
"""Test that unmatched tasks are routed to the 'default' queue."""
|
||||
from app.celery_app import celery_app
|
||||
|
||||
routes = celery_app.conf.task_routes
|
||||
assert "app.tasks.*" in routes
|
||||
assert routes["app.tasks.*"]["queue"] == "default"
|
||||
|
||||
def test_all_queues_are_defined(self):
|
||||
"""Test that all expected queues are defined in task_queues."""
|
||||
from app.celery_app import celery_app
|
||||
|
||||
queues = celery_app.conf.task_queues
|
||||
expected_queues = {"agent", "git", "sync", "default"}
|
||||
|
||||
assert queues is not None
|
||||
assert set(queues.keys()) == expected_queues
|
||||
|
||||
def test_queue_exchanges_are_configured(self):
|
||||
"""Test that each queue has its own exchange configured."""
|
||||
from app.celery_app import celery_app
|
||||
|
||||
queues = celery_app.conf.task_queues
|
||||
|
||||
for queue_name in ["agent", "git", "sync", "default"]:
|
||||
assert queue_name in queues
|
||||
assert queues[queue_name]["exchange"] == queue_name
|
||||
assert queues[queue_name]["routing_key"] == queue_name
|
||||
|
||||
|
||||
class TestTaskDiscovery:
|
||||
"""Tests for Celery task auto-discovery."""
|
||||
|
||||
def test_task_imports_are_configured(self):
|
||||
"""Test that task imports are configured for auto-discovery."""
|
||||
from app.celery_app import celery_app
|
||||
|
||||
imports = celery_app.conf.imports
|
||||
assert imports is not None
|
||||
assert "app.tasks" in imports
|
||||
|
||||
def test_agent_tasks_are_discoverable(self):
|
||||
"""Test that agent tasks can be discovered and accessed."""
|
||||
from app.celery_app import celery_app
|
||||
|
||||
# Force task registration by importing
|
||||
import app.tasks.agent # noqa: F401
|
||||
|
||||
# Check that agent tasks are registered
|
||||
registered_tasks = celery_app.tasks
|
||||
|
||||
assert "app.tasks.agent.run_agent_step" in registered_tasks
|
||||
assert "app.tasks.agent.spawn_agent" in registered_tasks
|
||||
assert "app.tasks.agent.terminate_agent" in registered_tasks
|
||||
|
||||
def test_git_tasks_are_discoverable(self):
|
||||
"""Test that git tasks can be discovered and accessed."""
|
||||
from app.celery_app import celery_app
|
||||
|
||||
# Force task registration by importing
|
||||
import app.tasks.git # noqa: F401
|
||||
|
||||
registered_tasks = celery_app.tasks
|
||||
|
||||
assert "app.tasks.git.clone_repository" in registered_tasks
|
||||
assert "app.tasks.git.commit_changes" in registered_tasks
|
||||
assert "app.tasks.git.create_branch" in registered_tasks
|
||||
assert "app.tasks.git.create_pull_request" in registered_tasks
|
||||
assert "app.tasks.git.push_changes" in registered_tasks
|
||||
|
||||
def test_sync_tasks_are_discoverable(self):
|
||||
"""Test that sync tasks can be discovered and accessed."""
|
||||
from app.celery_app import celery_app
|
||||
|
||||
# Force task registration by importing
|
||||
import app.tasks.sync # noqa: F401
|
||||
|
||||
registered_tasks = celery_app.tasks
|
||||
|
||||
assert "app.tasks.sync.sync_issues_incremental" in registered_tasks
|
||||
assert "app.tasks.sync.sync_issues_full" in registered_tasks
|
||||
assert "app.tasks.sync.process_webhook_event" in registered_tasks
|
||||
assert "app.tasks.sync.sync_project_issues" in registered_tasks
|
||||
assert "app.tasks.sync.push_issue_to_external" in registered_tasks
|
||||
|
||||
def test_workflow_tasks_are_discoverable(self):
|
||||
"""Test that workflow tasks can be discovered and accessed."""
|
||||
from app.celery_app import celery_app
|
||||
|
||||
# Force task registration by importing
|
||||
import app.tasks.workflow # noqa: F401
|
||||
|
||||
registered_tasks = celery_app.tasks
|
||||
|
||||
assert "app.tasks.workflow.recover_stale_workflows" in registered_tasks
|
||||
assert "app.tasks.workflow.execute_workflow_step" in registered_tasks
|
||||
assert "app.tasks.workflow.handle_approval_response" in registered_tasks
|
||||
assert "app.tasks.workflow.start_sprint_workflow" in registered_tasks
|
||||
assert "app.tasks.workflow.start_story_workflow" in registered_tasks
|
||||
|
||||
def test_cost_tasks_are_discoverable(self):
|
||||
"""Test that cost tasks can be discovered and accessed."""
|
||||
from app.celery_app import celery_app
|
||||
|
||||
# Force task registration by importing
|
||||
import app.tasks.cost # noqa: F401
|
||||
|
||||
registered_tasks = celery_app.tasks
|
||||
|
||||
assert "app.tasks.cost.aggregate_daily_costs" in registered_tasks
|
||||
assert "app.tasks.cost.check_budget_thresholds" in registered_tasks
|
||||
assert "app.tasks.cost.record_llm_usage" in registered_tasks
|
||||
assert "app.tasks.cost.generate_cost_report" in registered_tasks
|
||||
assert "app.tasks.cost.reset_daily_budget_counters" in registered_tasks
|
||||
|
||||
|
||||
class TestBeatSchedule:
|
||||
"""Tests for Celery Beat scheduled tasks configuration."""
|
||||
|
||||
def test_beat_schedule_is_configured(self):
|
||||
"""Test that beat_schedule is configured."""
|
||||
from app.celery_app import celery_app
|
||||
|
||||
assert celery_app.conf.beat_schedule is not None
|
||||
assert isinstance(celery_app.conf.beat_schedule, dict)
|
||||
|
||||
def test_incremental_sync_is_scheduled(self):
|
||||
"""Test that incremental issue sync is scheduled per ADR-011."""
|
||||
from app.celery_app import celery_app
|
||||
|
||||
schedule = celery_app.conf.beat_schedule
|
||||
assert "sync-issues-incremental" in schedule
|
||||
|
||||
task_config = schedule["sync-issues-incremental"]
|
||||
assert task_config["task"] == "app.tasks.sync.sync_issues_incremental"
|
||||
assert task_config["schedule"] == 60.0 # Every 60 seconds
|
||||
|
||||
def test_full_sync_is_scheduled(self):
|
||||
"""Test that full issue sync is scheduled per ADR-011."""
|
||||
from app.celery_app import celery_app
|
||||
|
||||
schedule = celery_app.conf.beat_schedule
|
||||
assert "sync-issues-full" in schedule
|
||||
|
||||
task_config = schedule["sync-issues-full"]
|
||||
assert task_config["task"] == "app.tasks.sync.sync_issues_full"
|
||||
assert task_config["schedule"] == 900.0 # Every 15 minutes
|
||||
|
||||
def test_stale_workflow_recovery_is_scheduled(self):
|
||||
"""Test that stale workflow recovery is scheduled per ADR-007."""
|
||||
from app.celery_app import celery_app
|
||||
|
||||
schedule = celery_app.conf.beat_schedule
|
||||
assert "recover-stale-workflows" in schedule
|
||||
|
||||
task_config = schedule["recover-stale-workflows"]
|
||||
assert task_config["task"] == "app.tasks.workflow.recover_stale_workflows"
|
||||
assert task_config["schedule"] == 300.0 # Every 5 minutes
|
||||
|
||||
def test_daily_cost_aggregation_is_scheduled(self):
|
||||
"""Test that daily cost aggregation is scheduled per ADR-012."""
|
||||
from app.celery_app import celery_app
|
||||
|
||||
schedule = celery_app.conf.beat_schedule
|
||||
assert "aggregate-daily-costs" in schedule
|
||||
|
||||
task_config = schedule["aggregate-daily-costs"]
|
||||
assert task_config["task"] == "app.tasks.cost.aggregate_daily_costs"
|
||||
assert task_config["schedule"] == 3600.0 # Every hour
|
||||
|
||||
|
||||
class TestTaskModuleExports:
|
||||
"""Tests for the task module __init__.py exports."""
|
||||
|
||||
def test_tasks_package_exports_all_modules(self):
|
||||
"""Test that the tasks package exports all task modules."""
|
||||
from app import tasks
|
||||
|
||||
assert hasattr(tasks, "agent")
|
||||
assert hasattr(tasks, "git")
|
||||
assert hasattr(tasks, "sync")
|
||||
assert hasattr(tasks, "workflow")
|
||||
assert hasattr(tasks, "cost")
|
||||
|
||||
def test_tasks_all_attribute_is_correct(self):
|
||||
"""Test that __all__ contains all expected module names."""
|
||||
from app import tasks
|
||||
|
||||
expected_modules = ["agent", "git", "sync", "workflow", "cost"]
|
||||
assert set(tasks.__all__) == set(expected_modules)
|
||||
379
backend/tests/tasks/test_cost_tasks.py
Normal file
379
backend/tests/tasks/test_cost_tasks.py
Normal file
@@ -0,0 +1,379 @@
|
||||
# tests/tasks/test_cost_tasks.py
|
||||
"""
|
||||
Tests for cost tracking and budget management tasks.
|
||||
|
||||
These tests verify:
|
||||
- Task signatures are correctly defined
|
||||
- Tasks are bound (have access to self)
|
||||
- Tasks return expected structure
|
||||
- Tasks follow ADR-012 (multi-layered cost tracking)
|
||||
|
||||
Note: These tests mock actual execution since they would require
|
||||
database access and Redis operations in production.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import patch
|
||||
import uuid
|
||||
|
||||
|
||||
class TestAggregateDailyCostsTask:
|
||||
"""Tests for the aggregate_daily_costs task."""
|
||||
|
||||
def test_aggregate_daily_costs_task_exists(self):
|
||||
"""Test that aggregate_daily_costs task is registered."""
|
||||
from app.celery_app import celery_app
|
||||
import app.tasks.cost # noqa: F401
|
||||
|
||||
assert "app.tasks.cost.aggregate_daily_costs" in celery_app.tasks
|
||||
|
||||
def test_aggregate_daily_costs_is_bound_task(self):
|
||||
"""Test that aggregate_daily_costs is a bound task."""
|
||||
from app.tasks.cost import aggregate_daily_costs
|
||||
|
||||
assert aggregate_daily_costs.__bound__ is True
|
||||
|
||||
def test_aggregate_daily_costs_has_correct_name(self):
|
||||
"""Test that aggregate_daily_costs has the correct task name."""
|
||||
from app.tasks.cost import aggregate_daily_costs
|
||||
|
||||
assert aggregate_daily_costs.name == "app.tasks.cost.aggregate_daily_costs"
|
||||
|
||||
def test_aggregate_daily_costs_returns_expected_structure(self):
|
||||
"""Test that aggregate_daily_costs returns expected result."""
|
||||
from app.tasks.cost import aggregate_daily_costs
|
||||
|
||||
result = aggregate_daily_costs()
|
||||
|
||||
assert isinstance(result, dict)
|
||||
assert "status" in result
|
||||
assert result["status"] == "pending"
|
||||
|
||||
|
||||
class TestCheckBudgetThresholdsTask:
|
||||
"""Tests for the check_budget_thresholds task."""
|
||||
|
||||
def test_check_budget_thresholds_task_exists(self):
|
||||
"""Test that check_budget_thresholds task is registered."""
|
||||
from app.celery_app import celery_app
|
||||
import app.tasks.cost # noqa: F401
|
||||
|
||||
assert "app.tasks.cost.check_budget_thresholds" in celery_app.tasks
|
||||
|
||||
def test_check_budget_thresholds_is_bound_task(self):
|
||||
"""Test that check_budget_thresholds is a bound task."""
|
||||
from app.tasks.cost import check_budget_thresholds
|
||||
|
||||
assert check_budget_thresholds.__bound__ is True
|
||||
|
||||
def test_check_budget_thresholds_returns_expected_structure(self):
|
||||
"""Test that check_budget_thresholds returns expected result."""
|
||||
from app.tasks.cost import check_budget_thresholds
|
||||
|
||||
project_id = str(uuid.uuid4())
|
||||
|
||||
result = check_budget_thresholds(project_id)
|
||||
|
||||
assert isinstance(result, dict)
|
||||
assert "status" in result
|
||||
assert "project_id" in result
|
||||
assert result["project_id"] == project_id
|
||||
|
||||
|
||||
class TestRecordLlmUsageTask:
|
||||
"""Tests for the record_llm_usage task."""
|
||||
|
||||
def test_record_llm_usage_task_exists(self):
|
||||
"""Test that record_llm_usage task is registered."""
|
||||
from app.celery_app import celery_app
|
||||
import app.tasks.cost # noqa: F401
|
||||
|
||||
assert "app.tasks.cost.record_llm_usage" in celery_app.tasks
|
||||
|
||||
def test_record_llm_usage_is_bound_task(self):
|
||||
"""Test that record_llm_usage is a bound task."""
|
||||
from app.tasks.cost import record_llm_usage
|
||||
|
||||
assert record_llm_usage.__bound__ is True
|
||||
|
||||
def test_record_llm_usage_returns_expected_structure(self):
|
||||
"""Test that record_llm_usage returns expected result."""
|
||||
from app.tasks.cost import record_llm_usage
|
||||
|
||||
agent_id = str(uuid.uuid4())
|
||||
project_id = str(uuid.uuid4())
|
||||
model = "claude-opus-4-5-20251101"
|
||||
prompt_tokens = 1500
|
||||
completion_tokens = 500
|
||||
cost_usd = 0.0825
|
||||
|
||||
result = record_llm_usage(
|
||||
agent_id, project_id, model, prompt_tokens, completion_tokens, cost_usd
|
||||
)
|
||||
|
||||
assert isinstance(result, dict)
|
||||
assert "status" in result
|
||||
assert "agent_id" in result
|
||||
assert "project_id" in result
|
||||
assert "cost_usd" in result
|
||||
assert result["agent_id"] == agent_id
|
||||
assert result["project_id"] == project_id
|
||||
assert result["cost_usd"] == cost_usd
|
||||
|
||||
def test_record_llm_usage_with_different_models(self):
|
||||
"""Test that record_llm_usage handles different model types."""
|
||||
from app.tasks.cost import record_llm_usage
|
||||
|
||||
agent_id = str(uuid.uuid4())
|
||||
project_id = str(uuid.uuid4())
|
||||
|
||||
models = [
|
||||
("claude-opus-4-5-20251101", 0.015),
|
||||
("claude-sonnet-4-20250514", 0.003),
|
||||
("gpt-4-turbo", 0.01),
|
||||
("gemini-1.5-pro", 0.007),
|
||||
]
|
||||
|
||||
for model, cost in models:
|
||||
result = record_llm_usage(
|
||||
agent_id, project_id, model, 1000, 500, cost
|
||||
)
|
||||
assert result["status"] == "pending"
|
||||
|
||||
def test_record_llm_usage_with_zero_tokens(self):
|
||||
"""Test that record_llm_usage handles zero token counts."""
|
||||
from app.tasks.cost import record_llm_usage
|
||||
|
||||
agent_id = str(uuid.uuid4())
|
||||
project_id = str(uuid.uuid4())
|
||||
|
||||
result = record_llm_usage(
|
||||
agent_id, project_id, "claude-opus-4-5-20251101", 0, 0, 0.0
|
||||
)
|
||||
|
||||
assert result["status"] == "pending"
|
||||
|
||||
|
||||
class TestGenerateCostReportTask:
|
||||
"""Tests for the generate_cost_report task."""
|
||||
|
||||
def test_generate_cost_report_task_exists(self):
|
||||
"""Test that generate_cost_report task is registered."""
|
||||
from app.celery_app import celery_app
|
||||
import app.tasks.cost # noqa: F401
|
||||
|
||||
assert "app.tasks.cost.generate_cost_report" in celery_app.tasks
|
||||
|
||||
def test_generate_cost_report_is_bound_task(self):
|
||||
"""Test that generate_cost_report is a bound task."""
|
||||
from app.tasks.cost import generate_cost_report
|
||||
|
||||
assert generate_cost_report.__bound__ is True
|
||||
|
||||
def test_generate_cost_report_returns_expected_structure(self):
|
||||
"""Test that generate_cost_report returns expected result."""
|
||||
from app.tasks.cost import generate_cost_report
|
||||
|
||||
project_id = str(uuid.uuid4())
|
||||
start_date = "2025-01-01"
|
||||
end_date = "2025-01-31"
|
||||
|
||||
result = generate_cost_report(project_id, start_date, end_date)
|
||||
|
||||
assert isinstance(result, dict)
|
||||
assert "status" in result
|
||||
assert "project_id" in result
|
||||
assert "start_date" in result
|
||||
assert "end_date" in result
|
||||
assert result["project_id"] == project_id
|
||||
assert result["start_date"] == start_date
|
||||
assert result["end_date"] == end_date
|
||||
|
||||
def test_generate_cost_report_with_various_date_ranges(self):
|
||||
"""Test that generate_cost_report handles various date ranges."""
|
||||
from app.tasks.cost import generate_cost_report
|
||||
|
||||
project_id = str(uuid.uuid4())
|
||||
|
||||
date_ranges = [
|
||||
("2025-01-01", "2025-01-01"), # Single day
|
||||
("2025-01-01", "2025-01-07"), # Week
|
||||
("2025-01-01", "2025-12-31"), # Full year
|
||||
]
|
||||
|
||||
for start, end in date_ranges:
|
||||
result = generate_cost_report(project_id, start, end)
|
||||
assert result["status"] == "pending"
|
||||
|
||||
|
||||
class TestResetDailyBudgetCountersTask:
|
||||
"""Tests for the reset_daily_budget_counters task."""
|
||||
|
||||
def test_reset_daily_budget_counters_task_exists(self):
|
||||
"""Test that reset_daily_budget_counters task is registered."""
|
||||
from app.celery_app import celery_app
|
||||
import app.tasks.cost # noqa: F401
|
||||
|
||||
assert "app.tasks.cost.reset_daily_budget_counters" in celery_app.tasks
|
||||
|
||||
def test_reset_daily_budget_counters_is_bound_task(self):
|
||||
"""Test that reset_daily_budget_counters is a bound task."""
|
||||
from app.tasks.cost import reset_daily_budget_counters
|
||||
|
||||
assert reset_daily_budget_counters.__bound__ is True
|
||||
|
||||
def test_reset_daily_budget_counters_returns_expected_structure(self):
|
||||
"""Test that reset_daily_budget_counters returns expected result."""
|
||||
from app.tasks.cost import reset_daily_budget_counters
|
||||
|
||||
result = reset_daily_budget_counters()
|
||||
|
||||
assert isinstance(result, dict)
|
||||
assert "status" in result
|
||||
assert result["status"] == "pending"
|
||||
|
||||
|
||||
class TestCostTaskRouting:
|
||||
"""Tests for cost task queue routing."""
|
||||
|
||||
def test_cost_tasks_route_to_default_queue(self):
|
||||
"""Test that cost tasks route to 'default' queue.
|
||||
|
||||
Per the routing configuration, cost tasks match 'app.tasks.*'
|
||||
which routes to the default queue.
|
||||
"""
|
||||
from app.celery_app import celery_app
|
||||
|
||||
routes = celery_app.conf.task_routes
|
||||
|
||||
# Cost tasks match the generic 'app.tasks.*' pattern
|
||||
assert "app.tasks.*" in routes
|
||||
assert routes["app.tasks.*"]["queue"] == "default"
|
||||
|
||||
def test_all_cost_tasks_match_routing_pattern(self):
|
||||
"""Test that all cost task names match the routing pattern."""
|
||||
task_names = [
|
||||
"app.tasks.cost.aggregate_daily_costs",
|
||||
"app.tasks.cost.check_budget_thresholds",
|
||||
"app.tasks.cost.record_llm_usage",
|
||||
"app.tasks.cost.generate_cost_report",
|
||||
"app.tasks.cost.reset_daily_budget_counters",
|
||||
]
|
||||
|
||||
for name in task_names:
|
||||
assert name.startswith("app.tasks.")
|
||||
|
||||
|
||||
class TestCostTaskLogging:
|
||||
"""Tests for cost task logging behavior."""
|
||||
|
||||
def test_aggregate_daily_costs_logs_execution(self):
|
||||
"""Test that aggregate_daily_costs logs when executed."""
|
||||
from app.tasks.cost import aggregate_daily_costs
|
||||
|
||||
with patch("app.tasks.cost.logger") as mock_logger:
|
||||
aggregate_daily_costs()
|
||||
|
||||
mock_logger.info.assert_called_once()
|
||||
call_args = mock_logger.info.call_args[0][0]
|
||||
assert "cost" in call_args.lower() or "aggregat" in call_args.lower()
|
||||
|
||||
def test_check_budget_thresholds_logs_execution(self):
|
||||
"""Test that check_budget_thresholds logs when executed."""
|
||||
from app.tasks.cost import check_budget_thresholds
|
||||
|
||||
project_id = str(uuid.uuid4())
|
||||
|
||||
with patch("app.tasks.cost.logger") as mock_logger:
|
||||
check_budget_thresholds(project_id)
|
||||
|
||||
mock_logger.info.assert_called_once()
|
||||
call_args = mock_logger.info.call_args[0][0]
|
||||
assert project_id in call_args
|
||||
|
||||
def test_record_llm_usage_logs_execution(self):
|
||||
"""Test that record_llm_usage logs when executed."""
|
||||
from app.tasks.cost import record_llm_usage
|
||||
|
||||
agent_id = str(uuid.uuid4())
|
||||
project_id = str(uuid.uuid4())
|
||||
model = "claude-opus-4-5-20251101"
|
||||
|
||||
with patch("app.tasks.cost.logger") as mock_logger:
|
||||
record_llm_usage(agent_id, project_id, model, 100, 50, 0.01)
|
||||
|
||||
# Uses debug level, not info
|
||||
mock_logger.debug.assert_called_once()
|
||||
call_args = mock_logger.debug.call_args[0][0]
|
||||
assert model in call_args
|
||||
|
||||
def test_generate_cost_report_logs_execution(self):
|
||||
"""Test that generate_cost_report logs when executed."""
|
||||
from app.tasks.cost import generate_cost_report
|
||||
|
||||
project_id = str(uuid.uuid4())
|
||||
|
||||
with patch("app.tasks.cost.logger") as mock_logger:
|
||||
generate_cost_report(project_id, "2025-01-01", "2025-01-31")
|
||||
|
||||
mock_logger.info.assert_called_once()
|
||||
call_args = mock_logger.info.call_args[0][0]
|
||||
assert project_id in call_args
|
||||
|
||||
def test_reset_daily_budget_counters_logs_execution(self):
|
||||
"""Test that reset_daily_budget_counters logs when executed."""
|
||||
from app.tasks.cost import reset_daily_budget_counters
|
||||
|
||||
with patch("app.tasks.cost.logger") as mock_logger:
|
||||
reset_daily_budget_counters()
|
||||
|
||||
mock_logger.info.assert_called_once()
|
||||
call_args = mock_logger.info.call_args[0][0]
|
||||
assert "reset" in call_args.lower() or "counter" in call_args.lower()
|
||||
|
||||
|
||||
class TestCostTaskSignatures:
|
||||
"""Tests for cost task signature creation."""
|
||||
|
||||
def test_record_llm_usage_signature_creation(self):
|
||||
"""Test that record_llm_usage signature can be created."""
|
||||
from app.tasks.cost import record_llm_usage
|
||||
|
||||
agent_id = str(uuid.uuid4())
|
||||
project_id = str(uuid.uuid4())
|
||||
|
||||
sig = record_llm_usage.s(
|
||||
agent_id, project_id, "claude-opus-4-5-20251101", 100, 50, 0.01
|
||||
)
|
||||
|
||||
assert sig is not None
|
||||
assert len(sig.args) == 6
|
||||
|
||||
def test_check_budget_thresholds_signature_creation(self):
|
||||
"""Test that check_budget_thresholds signature can be created."""
|
||||
from app.tasks.cost import check_budget_thresholds
|
||||
|
||||
project_id = str(uuid.uuid4())
|
||||
|
||||
sig = check_budget_thresholds.s(project_id)
|
||||
|
||||
assert sig is not None
|
||||
assert sig.args == (project_id,)
|
||||
|
||||
def test_cost_task_chain_creation(self):
|
||||
"""Test that cost tasks can be chained together."""
|
||||
from celery import chain
|
||||
from app.tasks.cost import record_llm_usage, check_budget_thresholds
|
||||
|
||||
agent_id = str(uuid.uuid4())
|
||||
project_id = str(uuid.uuid4())
|
||||
|
||||
# Build a chain: record usage, then check thresholds
|
||||
workflow = chain(
|
||||
record_llm_usage.s(
|
||||
agent_id, project_id, "claude-opus-4-5-20251101", 1000, 500, 0.05
|
||||
),
|
||||
check_budget_thresholds.s(project_id),
|
||||
)
|
||||
|
||||
assert workflow is not None
|
||||
301
backend/tests/tasks/test_git_tasks.py
Normal file
301
backend/tests/tasks/test_git_tasks.py
Normal file
@@ -0,0 +1,301 @@
|
||||
# tests/tasks/test_git_tasks.py
|
||||
"""
|
||||
Tests for git operation tasks.
|
||||
|
||||
These tests verify:
|
||||
- Task signatures are correctly defined
|
||||
- Tasks are bound (have access to self)
|
||||
- Tasks return expected structure
|
||||
- Tasks are routed to the 'git' queue
|
||||
|
||||
Note: These tests mock actual execution since they would require
|
||||
Git operations and external APIs in production.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import patch
|
||||
import uuid
|
||||
|
||||
|
||||
class TestCloneRepositoryTask:
|
||||
"""Tests for the clone_repository task."""
|
||||
|
||||
def test_clone_repository_task_exists(self):
|
||||
"""Test that clone_repository task is registered."""
|
||||
from app.celery_app import celery_app
|
||||
import app.tasks.git # noqa: F401
|
||||
|
||||
assert "app.tasks.git.clone_repository" in celery_app.tasks
|
||||
|
||||
def test_clone_repository_is_bound_task(self):
|
||||
"""Test that clone_repository is a bound task."""
|
||||
from app.tasks.git import clone_repository
|
||||
|
||||
assert clone_repository.__bound__ is True
|
||||
|
||||
def test_clone_repository_has_correct_name(self):
|
||||
"""Test that clone_repository has the correct task name."""
|
||||
from app.tasks.git import clone_repository
|
||||
|
||||
assert clone_repository.name == "app.tasks.git.clone_repository"
|
||||
|
||||
def test_clone_repository_returns_expected_structure(self):
|
||||
"""Test that clone_repository returns the expected result structure."""
|
||||
from app.tasks.git import clone_repository
|
||||
|
||||
project_id = str(uuid.uuid4())
|
||||
repo_url = "https://gitea.example.com/org/repo.git"
|
||||
branch = "main"
|
||||
|
||||
result = clone_repository(project_id, repo_url, branch)
|
||||
|
||||
assert isinstance(result, dict)
|
||||
assert "status" in result
|
||||
assert "project_id" in result
|
||||
assert result["project_id"] == project_id
|
||||
|
||||
def test_clone_repository_with_default_branch(self):
|
||||
"""Test that clone_repository uses default branch when not specified."""
|
||||
from app.tasks.git import clone_repository
|
||||
|
||||
project_id = str(uuid.uuid4())
|
||||
repo_url = "https://github.com/org/repo.git"
|
||||
|
||||
# Call without specifying branch (should default to 'main')
|
||||
result = clone_repository(project_id, repo_url)
|
||||
|
||||
assert result["status"] == "pending"
|
||||
|
||||
|
||||
class TestCommitChangesTask:
|
||||
"""Tests for the commit_changes task."""
|
||||
|
||||
def test_commit_changes_task_exists(self):
|
||||
"""Test that commit_changes task is registered."""
|
||||
from app.celery_app import celery_app
|
||||
import app.tasks.git # noqa: F401
|
||||
|
||||
assert "app.tasks.git.commit_changes" in celery_app.tasks
|
||||
|
||||
def test_commit_changes_is_bound_task(self):
|
||||
"""Test that commit_changes is a bound task."""
|
||||
from app.tasks.git import commit_changes
|
||||
|
||||
assert commit_changes.__bound__ is True
|
||||
|
||||
def test_commit_changes_returns_expected_structure(self):
|
||||
"""Test that commit_changes returns the expected result structure."""
|
||||
from app.tasks.git import commit_changes
|
||||
|
||||
project_id = str(uuid.uuid4())
|
||||
message = "feat: Add new feature"
|
||||
files = ["src/feature.py", "tests/test_feature.py"]
|
||||
|
||||
result = commit_changes(project_id, message, files)
|
||||
|
||||
assert isinstance(result, dict)
|
||||
assert "status" in result
|
||||
assert "project_id" in result
|
||||
|
||||
def test_commit_changes_without_files(self):
|
||||
"""Test that commit_changes handles None files (commit all staged)."""
|
||||
from app.tasks.git import commit_changes
|
||||
|
||||
project_id = str(uuid.uuid4())
|
||||
message = "chore: Update dependencies"
|
||||
|
||||
result = commit_changes(project_id, message, None)
|
||||
|
||||
assert result["status"] == "pending"
|
||||
|
||||
|
||||
class TestCreateBranchTask:
|
||||
"""Tests for the create_branch task."""
|
||||
|
||||
def test_create_branch_task_exists(self):
|
||||
"""Test that create_branch task is registered."""
|
||||
from app.celery_app import celery_app
|
||||
import app.tasks.git # noqa: F401
|
||||
|
||||
assert "app.tasks.git.create_branch" in celery_app.tasks
|
||||
|
||||
def test_create_branch_is_bound_task(self):
|
||||
"""Test that create_branch is a bound task."""
|
||||
from app.tasks.git import create_branch
|
||||
|
||||
assert create_branch.__bound__ is True
|
||||
|
||||
def test_create_branch_returns_expected_structure(self):
|
||||
"""Test that create_branch returns the expected result structure."""
|
||||
from app.tasks.git import create_branch
|
||||
|
||||
project_id = str(uuid.uuid4())
|
||||
branch_name = "feature/new-feature"
|
||||
from_ref = "develop"
|
||||
|
||||
result = create_branch(project_id, branch_name, from_ref)
|
||||
|
||||
assert isinstance(result, dict)
|
||||
assert "status" in result
|
||||
assert "project_id" in result
|
||||
|
||||
def test_create_branch_with_default_from_ref(self):
|
||||
"""Test that create_branch uses default from_ref when not specified."""
|
||||
from app.tasks.git import create_branch
|
||||
|
||||
project_id = str(uuid.uuid4())
|
||||
branch_name = "feature/123-add-login"
|
||||
|
||||
result = create_branch(project_id, branch_name)
|
||||
|
||||
assert result["status"] == "pending"
|
||||
|
||||
|
||||
class TestCreatePullRequestTask:
|
||||
"""Tests for the create_pull_request task."""
|
||||
|
||||
def test_create_pull_request_task_exists(self):
|
||||
"""Test that create_pull_request task is registered."""
|
||||
from app.celery_app import celery_app
|
||||
import app.tasks.git # noqa: F401
|
||||
|
||||
assert "app.tasks.git.create_pull_request" in celery_app.tasks
|
||||
|
||||
def test_create_pull_request_is_bound_task(self):
|
||||
"""Test that create_pull_request is a bound task."""
|
||||
from app.tasks.git import create_pull_request
|
||||
|
||||
assert create_pull_request.__bound__ is True
|
||||
|
||||
def test_create_pull_request_returns_expected_structure(self):
|
||||
"""Test that create_pull_request returns expected result structure."""
|
||||
from app.tasks.git import create_pull_request
|
||||
|
||||
project_id = str(uuid.uuid4())
|
||||
title = "feat: Add authentication"
|
||||
body = "## Summary\n- Added JWT auth\n- Added login endpoint"
|
||||
head_branch = "feature/auth"
|
||||
base_branch = "main"
|
||||
|
||||
result = create_pull_request(project_id, title, body, head_branch, base_branch)
|
||||
|
||||
assert isinstance(result, dict)
|
||||
assert "status" in result
|
||||
assert "project_id" in result
|
||||
|
||||
def test_create_pull_request_with_default_base(self):
|
||||
"""Test that create_pull_request uses default base branch."""
|
||||
from app.tasks.git import create_pull_request
|
||||
|
||||
project_id = str(uuid.uuid4())
|
||||
|
||||
result = create_pull_request(
|
||||
project_id, "Fix bug", "Bug fix description", "fix/bug-123"
|
||||
)
|
||||
|
||||
assert result["status"] == "pending"
|
||||
|
||||
|
||||
class TestPushChangesTask:
|
||||
"""Tests for the push_changes task."""
|
||||
|
||||
def test_push_changes_task_exists(self):
|
||||
"""Test that push_changes task is registered."""
|
||||
from app.celery_app import celery_app
|
||||
import app.tasks.git # noqa: F401
|
||||
|
||||
assert "app.tasks.git.push_changes" in celery_app.tasks
|
||||
|
||||
def test_push_changes_is_bound_task(self):
|
||||
"""Test that push_changes is a bound task."""
|
||||
from app.tasks.git import push_changes
|
||||
|
||||
assert push_changes.__bound__ is True
|
||||
|
||||
def test_push_changes_returns_expected_structure(self):
|
||||
"""Test that push_changes returns the expected result structure."""
|
||||
from app.tasks.git import push_changes
|
||||
|
||||
project_id = str(uuid.uuid4())
|
||||
branch = "feature/new-feature"
|
||||
force = False
|
||||
|
||||
result = push_changes(project_id, branch, force)
|
||||
|
||||
assert isinstance(result, dict)
|
||||
assert "status" in result
|
||||
assert "project_id" in result
|
||||
|
||||
def test_push_changes_with_force_option(self):
|
||||
"""Test that push_changes handles force push option."""
|
||||
from app.tasks.git import push_changes
|
||||
|
||||
project_id = str(uuid.uuid4())
|
||||
branch = "feature/rebased-branch"
|
||||
force = True
|
||||
|
||||
result = push_changes(project_id, branch, force)
|
||||
|
||||
assert result["status"] == "pending"
|
||||
|
||||
|
||||
class TestGitTaskRouting:
|
||||
"""Tests for git task queue routing."""
|
||||
|
||||
def test_git_tasks_should_route_to_git_queue(self):
|
||||
"""Test that git tasks are configured to route to 'git' queue."""
|
||||
from app.celery_app import celery_app
|
||||
|
||||
routes = celery_app.conf.task_routes
|
||||
git_route = routes.get("app.tasks.git.*")
|
||||
|
||||
assert git_route is not None
|
||||
assert git_route["queue"] == "git"
|
||||
|
||||
def test_all_git_tasks_match_routing_pattern(self):
|
||||
"""Test that all git task names match the routing pattern."""
|
||||
from app.tasks import git
|
||||
|
||||
task_names = [
|
||||
"app.tasks.git.clone_repository",
|
||||
"app.tasks.git.commit_changes",
|
||||
"app.tasks.git.create_branch",
|
||||
"app.tasks.git.create_pull_request",
|
||||
"app.tasks.git.push_changes",
|
||||
]
|
||||
|
||||
for name in task_names:
|
||||
assert name.startswith("app.tasks.git.")
|
||||
|
||||
|
||||
class TestGitTaskLogging:
|
||||
"""Tests for git task logging behavior."""
|
||||
|
||||
def test_clone_repository_logs_execution(self):
|
||||
"""Test that clone_repository logs when executed."""
|
||||
from app.tasks.git import clone_repository
|
||||
|
||||
project_id = str(uuid.uuid4())
|
||||
repo_url = "https://github.com/org/repo.git"
|
||||
|
||||
with patch("app.tasks.git.logger") as mock_logger:
|
||||
clone_repository(project_id, repo_url)
|
||||
|
||||
mock_logger.info.assert_called_once()
|
||||
call_args = mock_logger.info.call_args[0][0]
|
||||
assert repo_url in call_args
|
||||
assert project_id in call_args
|
||||
|
||||
def test_commit_changes_logs_execution(self):
|
||||
"""Test that commit_changes logs when executed."""
|
||||
from app.tasks.git import commit_changes
|
||||
|
||||
project_id = str(uuid.uuid4())
|
||||
message = "test commit"
|
||||
|
||||
with patch("app.tasks.git.logger") as mock_logger:
|
||||
commit_changes(project_id, message)
|
||||
|
||||
mock_logger.info.assert_called_once()
|
||||
call_args = mock_logger.info.call_args[0][0]
|
||||
assert message in call_args
|
||||
309
backend/tests/tasks/test_sync_tasks.py
Normal file
309
backend/tests/tasks/test_sync_tasks.py
Normal file
@@ -0,0 +1,309 @@
|
||||
# tests/tasks/test_sync_tasks.py
|
||||
"""
|
||||
Tests for issue synchronization tasks.
|
||||
|
||||
These tests verify:
|
||||
- Task signatures are correctly defined
|
||||
- Tasks are bound (have access to self)
|
||||
- Tasks return expected structure
|
||||
- Tasks are routed to the 'sync' queue per ADR-011
|
||||
|
||||
Note: These tests mock actual execution since they would require
|
||||
external API calls in production.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import patch
|
||||
import uuid
|
||||
|
||||
|
||||
class TestSyncIssuesIncrementalTask:
|
||||
"""Tests for the sync_issues_incremental task."""
|
||||
|
||||
def test_sync_issues_incremental_task_exists(self):
|
||||
"""Test that sync_issues_incremental task is registered."""
|
||||
from app.celery_app import celery_app
|
||||
import app.tasks.sync # noqa: F401
|
||||
|
||||
assert "app.tasks.sync.sync_issues_incremental" in celery_app.tasks
|
||||
|
||||
def test_sync_issues_incremental_is_bound_task(self):
|
||||
"""Test that sync_issues_incremental is a bound task."""
|
||||
from app.tasks.sync import sync_issues_incremental
|
||||
|
||||
assert sync_issues_incremental.__bound__ is True
|
||||
|
||||
def test_sync_issues_incremental_has_correct_name(self):
|
||||
"""Test that sync_issues_incremental has the correct task name."""
|
||||
from app.tasks.sync import sync_issues_incremental
|
||||
|
||||
assert sync_issues_incremental.name == "app.tasks.sync.sync_issues_incremental"
|
||||
|
||||
def test_sync_issues_incremental_returns_expected_structure(self):
|
||||
"""Test that sync_issues_incremental returns expected result."""
|
||||
from app.tasks.sync import sync_issues_incremental
|
||||
|
||||
result = sync_issues_incremental()
|
||||
|
||||
assert isinstance(result, dict)
|
||||
assert "status" in result
|
||||
assert "type" in result
|
||||
assert result["type"] == "incremental"
|
||||
|
||||
|
||||
class TestSyncIssuesFullTask:
|
||||
"""Tests for the sync_issues_full task."""
|
||||
|
||||
def test_sync_issues_full_task_exists(self):
|
||||
"""Test that sync_issues_full task is registered."""
|
||||
from app.celery_app import celery_app
|
||||
import app.tasks.sync # noqa: F401
|
||||
|
||||
assert "app.tasks.sync.sync_issues_full" in celery_app.tasks
|
||||
|
||||
def test_sync_issues_full_is_bound_task(self):
|
||||
"""Test that sync_issues_full is a bound task."""
|
||||
from app.tasks.sync import sync_issues_full
|
||||
|
||||
assert sync_issues_full.__bound__ is True
|
||||
|
||||
def test_sync_issues_full_has_correct_name(self):
|
||||
"""Test that sync_issues_full has the correct task name."""
|
||||
from app.tasks.sync import sync_issues_full
|
||||
|
||||
assert sync_issues_full.name == "app.tasks.sync.sync_issues_full"
|
||||
|
||||
def test_sync_issues_full_returns_expected_structure(self):
|
||||
"""Test that sync_issues_full returns expected result."""
|
||||
from app.tasks.sync import sync_issues_full
|
||||
|
||||
result = sync_issues_full()
|
||||
|
||||
assert isinstance(result, dict)
|
||||
assert "status" in result
|
||||
assert "type" in result
|
||||
assert result["type"] == "full"
|
||||
|
||||
|
||||
class TestProcessWebhookEventTask:
|
||||
"""Tests for the process_webhook_event task."""
|
||||
|
||||
def test_process_webhook_event_task_exists(self):
|
||||
"""Test that process_webhook_event task is registered."""
|
||||
from app.celery_app import celery_app
|
||||
import app.tasks.sync # noqa: F401
|
||||
|
||||
assert "app.tasks.sync.process_webhook_event" in celery_app.tasks
|
||||
|
||||
def test_process_webhook_event_is_bound_task(self):
|
||||
"""Test that process_webhook_event is a bound task."""
|
||||
from app.tasks.sync import process_webhook_event
|
||||
|
||||
assert process_webhook_event.__bound__ is True
|
||||
|
||||
def test_process_webhook_event_returns_expected_structure(self):
|
||||
"""Test that process_webhook_event returns expected result."""
|
||||
from app.tasks.sync import process_webhook_event
|
||||
|
||||
provider = "gitea"
|
||||
event_type = "issue.created"
|
||||
payload = {
|
||||
"action": "opened",
|
||||
"issue": {"number": 123, "title": "New issue"},
|
||||
}
|
||||
|
||||
result = process_webhook_event(provider, event_type, payload)
|
||||
|
||||
assert isinstance(result, dict)
|
||||
assert "status" in result
|
||||
assert "provider" in result
|
||||
assert "event_type" in result
|
||||
assert result["provider"] == provider
|
||||
assert result["event_type"] == event_type
|
||||
|
||||
def test_process_webhook_event_handles_github_provider(self):
|
||||
"""Test that process_webhook_event handles GitHub webhooks."""
|
||||
from app.tasks.sync import process_webhook_event
|
||||
|
||||
result = process_webhook_event(
|
||||
"github", "issues", {"action": "opened", "issue": {"number": 1}}
|
||||
)
|
||||
|
||||
assert result["provider"] == "github"
|
||||
|
||||
def test_process_webhook_event_handles_gitlab_provider(self):
|
||||
"""Test that process_webhook_event handles GitLab webhooks."""
|
||||
from app.tasks.sync import process_webhook_event
|
||||
|
||||
result = process_webhook_event(
|
||||
"gitlab",
|
||||
"issue.created",
|
||||
{"object_kind": "issue", "object_attributes": {"iid": 1}},
|
||||
)
|
||||
|
||||
assert result["provider"] == "gitlab"
|
||||
|
||||
|
||||
class TestSyncProjectIssuesTask:
|
||||
"""Tests for the sync_project_issues task."""
|
||||
|
||||
def test_sync_project_issues_task_exists(self):
|
||||
"""Test that sync_project_issues task is registered."""
|
||||
from app.celery_app import celery_app
|
||||
import app.tasks.sync # noqa: F401
|
||||
|
||||
assert "app.tasks.sync.sync_project_issues" in celery_app.tasks
|
||||
|
||||
def test_sync_project_issues_is_bound_task(self):
|
||||
"""Test that sync_project_issues is a bound task."""
|
||||
from app.tasks.sync import sync_project_issues
|
||||
|
||||
assert sync_project_issues.__bound__ is True
|
||||
|
||||
def test_sync_project_issues_returns_expected_structure(self):
|
||||
"""Test that sync_project_issues returns expected result."""
|
||||
from app.tasks.sync import sync_project_issues
|
||||
|
||||
project_id = str(uuid.uuid4())
|
||||
full = False
|
||||
|
||||
result = sync_project_issues(project_id, full)
|
||||
|
||||
assert isinstance(result, dict)
|
||||
assert "status" in result
|
||||
assert "project_id" in result
|
||||
assert result["project_id"] == project_id
|
||||
|
||||
def test_sync_project_issues_with_full_sync(self):
|
||||
"""Test that sync_project_issues handles full sync flag."""
|
||||
from app.tasks.sync import sync_project_issues
|
||||
|
||||
project_id = str(uuid.uuid4())
|
||||
|
||||
result = sync_project_issues(project_id, full=True)
|
||||
|
||||
assert result["status"] == "pending"
|
||||
|
||||
|
||||
class TestPushIssueToExternalTask:
|
||||
"""Tests for the push_issue_to_external task."""
|
||||
|
||||
def test_push_issue_to_external_task_exists(self):
|
||||
"""Test that push_issue_to_external task is registered."""
|
||||
from app.celery_app import celery_app
|
||||
import app.tasks.sync # noqa: F401
|
||||
|
||||
assert "app.tasks.sync.push_issue_to_external" in celery_app.tasks
|
||||
|
||||
def test_push_issue_to_external_is_bound_task(self):
|
||||
"""Test that push_issue_to_external is a bound task."""
|
||||
from app.tasks.sync import push_issue_to_external
|
||||
|
||||
assert push_issue_to_external.__bound__ is True
|
||||
|
||||
def test_push_issue_to_external_returns_expected_structure(self):
|
||||
"""Test that push_issue_to_external returns expected result."""
|
||||
from app.tasks.sync import push_issue_to_external
|
||||
|
||||
project_id = str(uuid.uuid4())
|
||||
issue_id = str(uuid.uuid4())
|
||||
operation = "create"
|
||||
|
||||
result = push_issue_to_external(project_id, issue_id, operation)
|
||||
|
||||
assert isinstance(result, dict)
|
||||
assert "status" in result
|
||||
assert "issue_id" in result
|
||||
assert "operation" in result
|
||||
assert result["issue_id"] == issue_id
|
||||
assert result["operation"] == operation
|
||||
|
||||
def test_push_issue_to_external_update_operation(self):
|
||||
"""Test that push_issue_to_external handles update operation."""
|
||||
from app.tasks.sync import push_issue_to_external
|
||||
|
||||
project_id = str(uuid.uuid4())
|
||||
issue_id = str(uuid.uuid4())
|
||||
|
||||
result = push_issue_to_external(project_id, issue_id, "update")
|
||||
|
||||
assert result["operation"] == "update"
|
||||
|
||||
def test_push_issue_to_external_close_operation(self):
|
||||
"""Test that push_issue_to_external handles close operation."""
|
||||
from app.tasks.sync import push_issue_to_external
|
||||
|
||||
project_id = str(uuid.uuid4())
|
||||
issue_id = str(uuid.uuid4())
|
||||
|
||||
result = push_issue_to_external(project_id, issue_id, "close")
|
||||
|
||||
assert result["operation"] == "close"
|
||||
|
||||
|
||||
class TestSyncTaskRouting:
|
||||
"""Tests for sync task queue routing."""
|
||||
|
||||
def test_sync_tasks_should_route_to_sync_queue(self):
|
||||
"""Test that sync tasks are configured to route to 'sync' queue."""
|
||||
from app.celery_app import celery_app
|
||||
|
||||
routes = celery_app.conf.task_routes
|
||||
sync_route = routes.get("app.tasks.sync.*")
|
||||
|
||||
assert sync_route is not None
|
||||
assert sync_route["queue"] == "sync"
|
||||
|
||||
def test_all_sync_tasks_match_routing_pattern(self):
|
||||
"""Test that all sync task names match the routing pattern."""
|
||||
task_names = [
|
||||
"app.tasks.sync.sync_issues_incremental",
|
||||
"app.tasks.sync.sync_issues_full",
|
||||
"app.tasks.sync.process_webhook_event",
|
||||
"app.tasks.sync.sync_project_issues",
|
||||
"app.tasks.sync.push_issue_to_external",
|
||||
]
|
||||
|
||||
for name in task_names:
|
||||
assert name.startswith("app.tasks.sync.")
|
||||
|
||||
|
||||
class TestSyncTaskLogging:
|
||||
"""Tests for sync task logging behavior."""
|
||||
|
||||
def test_sync_issues_incremental_logs_execution(self):
|
||||
"""Test that sync_issues_incremental logs when executed."""
|
||||
from app.tasks.sync import sync_issues_incremental
|
||||
|
||||
with patch("app.tasks.sync.logger") as mock_logger:
|
||||
sync_issues_incremental()
|
||||
|
||||
mock_logger.info.assert_called_once()
|
||||
call_args = mock_logger.info.call_args[0][0]
|
||||
assert "incremental" in call_args.lower()
|
||||
|
||||
def test_sync_issues_full_logs_execution(self):
|
||||
"""Test that sync_issues_full logs when executed."""
|
||||
from app.tasks.sync import sync_issues_full
|
||||
|
||||
with patch("app.tasks.sync.logger") as mock_logger:
|
||||
sync_issues_full()
|
||||
|
||||
mock_logger.info.assert_called_once()
|
||||
call_args = mock_logger.info.call_args[0][0]
|
||||
assert "full" in call_args.lower() or "reconciliation" in call_args.lower()
|
||||
|
||||
def test_process_webhook_event_logs_execution(self):
|
||||
"""Test that process_webhook_event logs when executed."""
|
||||
from app.tasks.sync import process_webhook_event
|
||||
|
||||
provider = "gitea"
|
||||
event_type = "issue.updated"
|
||||
|
||||
with patch("app.tasks.sync.logger") as mock_logger:
|
||||
process_webhook_event(provider, event_type, {})
|
||||
|
||||
mock_logger.info.assert_called_once()
|
||||
call_args = mock_logger.info.call_args[0][0]
|
||||
assert provider in call_args
|
||||
assert event_type in call_args
|
||||
350
backend/tests/tasks/test_workflow_tasks.py
Normal file
350
backend/tests/tasks/test_workflow_tasks.py
Normal file
@@ -0,0 +1,350 @@
|
||||
# tests/tasks/test_workflow_tasks.py
|
||||
"""
|
||||
Tests for workflow state management tasks.
|
||||
|
||||
These tests verify:
|
||||
- Task signatures are correctly defined
|
||||
- Tasks are bound (have access to self)
|
||||
- Tasks return expected structure
|
||||
- Tasks follow ADR-007 (transitions) and ADR-010 (PostgreSQL durability)
|
||||
|
||||
Note: These tests mock actual execution since they would require
|
||||
database access and state machine operations in production.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import patch
|
||||
import uuid
|
||||
|
||||
|
||||
class TestRecoverStaleWorkflowsTask:
|
||||
"""Tests for the recover_stale_workflows task."""
|
||||
|
||||
def test_recover_stale_workflows_task_exists(self):
|
||||
"""Test that recover_stale_workflows task is registered."""
|
||||
from app.celery_app import celery_app
|
||||
import app.tasks.workflow # noqa: F401
|
||||
|
||||
assert "app.tasks.workflow.recover_stale_workflows" in celery_app.tasks
|
||||
|
||||
def test_recover_stale_workflows_is_bound_task(self):
|
||||
"""Test that recover_stale_workflows is a bound task."""
|
||||
from app.tasks.workflow import recover_stale_workflows
|
||||
|
||||
assert recover_stale_workflows.__bound__ is True
|
||||
|
||||
def test_recover_stale_workflows_has_correct_name(self):
|
||||
"""Test that recover_stale_workflows has the correct task name."""
|
||||
from app.tasks.workflow import recover_stale_workflows
|
||||
|
||||
assert (
|
||||
recover_stale_workflows.name == "app.tasks.workflow.recover_stale_workflows"
|
||||
)
|
||||
|
||||
def test_recover_stale_workflows_returns_expected_structure(self):
|
||||
"""Test that recover_stale_workflows returns expected result."""
|
||||
from app.tasks.workflow import recover_stale_workflows
|
||||
|
||||
result = recover_stale_workflows()
|
||||
|
||||
assert isinstance(result, dict)
|
||||
assert "status" in result
|
||||
assert "recovered" in result
|
||||
assert result["status"] == "pending"
|
||||
assert result["recovered"] == 0
|
||||
|
||||
|
||||
class TestExecuteWorkflowStepTask:
|
||||
"""Tests for the execute_workflow_step task."""
|
||||
|
||||
def test_execute_workflow_step_task_exists(self):
|
||||
"""Test that execute_workflow_step task is registered."""
|
||||
from app.celery_app import celery_app
|
||||
import app.tasks.workflow # noqa: F401
|
||||
|
||||
assert "app.tasks.workflow.execute_workflow_step" in celery_app.tasks
|
||||
|
||||
def test_execute_workflow_step_is_bound_task(self):
|
||||
"""Test that execute_workflow_step is a bound task."""
|
||||
from app.tasks.workflow import execute_workflow_step
|
||||
|
||||
assert execute_workflow_step.__bound__ is True
|
||||
|
||||
def test_execute_workflow_step_returns_expected_structure(self):
|
||||
"""Test that execute_workflow_step returns expected result."""
|
||||
from app.tasks.workflow import execute_workflow_step
|
||||
|
||||
workflow_id = str(uuid.uuid4())
|
||||
transition = "start_planning"
|
||||
|
||||
result = execute_workflow_step(workflow_id, transition)
|
||||
|
||||
assert isinstance(result, dict)
|
||||
assert "status" in result
|
||||
assert "workflow_id" in result
|
||||
assert "transition" in result
|
||||
assert result["workflow_id"] == workflow_id
|
||||
assert result["transition"] == transition
|
||||
|
||||
def test_execute_workflow_step_with_various_transitions(self):
|
||||
"""Test that execute_workflow_step handles various transition types."""
|
||||
from app.tasks.workflow import execute_workflow_step
|
||||
|
||||
workflow_id = str(uuid.uuid4())
|
||||
transitions = [
|
||||
"start",
|
||||
"complete_planning",
|
||||
"begin_implementation",
|
||||
"request_approval",
|
||||
"approve",
|
||||
"reject",
|
||||
"complete",
|
||||
]
|
||||
|
||||
for transition in transitions:
|
||||
result = execute_workflow_step(workflow_id, transition)
|
||||
assert result["transition"] == transition
|
||||
|
||||
|
||||
class TestHandleApprovalResponseTask:
|
||||
"""Tests for the handle_approval_response task."""
|
||||
|
||||
def test_handle_approval_response_task_exists(self):
|
||||
"""Test that handle_approval_response task is registered."""
|
||||
from app.celery_app import celery_app
|
||||
import app.tasks.workflow # noqa: F401
|
||||
|
||||
assert "app.tasks.workflow.handle_approval_response" in celery_app.tasks
|
||||
|
||||
def test_handle_approval_response_is_bound_task(self):
|
||||
"""Test that handle_approval_response is a bound task."""
|
||||
from app.tasks.workflow import handle_approval_response
|
||||
|
||||
assert handle_approval_response.__bound__ is True
|
||||
|
||||
def test_handle_approval_response_returns_expected_structure(self):
|
||||
"""Test that handle_approval_response returns expected result."""
|
||||
from app.tasks.workflow import handle_approval_response
|
||||
|
||||
workflow_id = str(uuid.uuid4())
|
||||
approved = True
|
||||
comment = "LGTM! Proceeding with deployment."
|
||||
|
||||
result = handle_approval_response(workflow_id, approved, comment)
|
||||
|
||||
assert isinstance(result, dict)
|
||||
assert "status" in result
|
||||
assert "workflow_id" in result
|
||||
assert "approved" in result
|
||||
assert result["workflow_id"] == workflow_id
|
||||
assert result["approved"] == approved
|
||||
|
||||
def test_handle_approval_response_with_rejection(self):
|
||||
"""Test that handle_approval_response handles rejection."""
|
||||
from app.tasks.workflow import handle_approval_response
|
||||
|
||||
workflow_id = str(uuid.uuid4())
|
||||
|
||||
result = handle_approval_response(
|
||||
workflow_id, approved=False, comment="Needs more test coverage"
|
||||
)
|
||||
|
||||
assert result["approved"] is False
|
||||
|
||||
def test_handle_approval_response_without_comment(self):
|
||||
"""Test that handle_approval_response handles missing comment."""
|
||||
from app.tasks.workflow import handle_approval_response
|
||||
|
||||
workflow_id = str(uuid.uuid4())
|
||||
|
||||
result = handle_approval_response(workflow_id, approved=True)
|
||||
|
||||
assert result["status"] == "pending"
|
||||
|
||||
|
||||
class TestStartSprintWorkflowTask:
|
||||
"""Tests for the start_sprint_workflow task."""
|
||||
|
||||
def test_start_sprint_workflow_task_exists(self):
|
||||
"""Test that start_sprint_workflow task is registered."""
|
||||
from app.celery_app import celery_app
|
||||
import app.tasks.workflow # noqa: F401
|
||||
|
||||
assert "app.tasks.workflow.start_sprint_workflow" in celery_app.tasks
|
||||
|
||||
def test_start_sprint_workflow_is_bound_task(self):
|
||||
"""Test that start_sprint_workflow is a bound task."""
|
||||
from app.tasks.workflow import start_sprint_workflow
|
||||
|
||||
assert start_sprint_workflow.__bound__ is True
|
||||
|
||||
def test_start_sprint_workflow_returns_expected_structure(self):
|
||||
"""Test that start_sprint_workflow returns expected result."""
|
||||
from app.tasks.workflow import start_sprint_workflow
|
||||
|
||||
project_id = str(uuid.uuid4())
|
||||
sprint_id = str(uuid.uuid4())
|
||||
|
||||
result = start_sprint_workflow(project_id, sprint_id)
|
||||
|
||||
assert isinstance(result, dict)
|
||||
assert "status" in result
|
||||
assert "sprint_id" in result
|
||||
assert result["sprint_id"] == sprint_id
|
||||
|
||||
|
||||
class TestStartStoryWorkflowTask:
|
||||
"""Tests for the start_story_workflow task."""
|
||||
|
||||
def test_start_story_workflow_task_exists(self):
|
||||
"""Test that start_story_workflow task is registered."""
|
||||
from app.celery_app import celery_app
|
||||
import app.tasks.workflow # noqa: F401
|
||||
|
||||
assert "app.tasks.workflow.start_story_workflow" in celery_app.tasks
|
||||
|
||||
def test_start_story_workflow_is_bound_task(self):
|
||||
"""Test that start_story_workflow is a bound task."""
|
||||
from app.tasks.workflow import start_story_workflow
|
||||
|
||||
assert start_story_workflow.__bound__ is True
|
||||
|
||||
def test_start_story_workflow_returns_expected_structure(self):
|
||||
"""Test that start_story_workflow returns expected result."""
|
||||
from app.tasks.workflow import start_story_workflow
|
||||
|
||||
project_id = str(uuid.uuid4())
|
||||
story_id = str(uuid.uuid4())
|
||||
|
||||
result = start_story_workflow(project_id, story_id)
|
||||
|
||||
assert isinstance(result, dict)
|
||||
assert "status" in result
|
||||
assert "story_id" in result
|
||||
assert result["story_id"] == story_id
|
||||
|
||||
|
||||
class TestWorkflowTaskRouting:
|
||||
"""Tests for workflow task queue routing."""
|
||||
|
||||
def test_workflow_tasks_route_to_default_queue(self):
|
||||
"""Test that workflow tasks route to 'default' queue.
|
||||
|
||||
Per the routing configuration, workflow tasks match 'app.tasks.*'
|
||||
which routes to the default queue.
|
||||
"""
|
||||
from app.celery_app import celery_app
|
||||
|
||||
routes = celery_app.conf.task_routes
|
||||
|
||||
# Workflow tasks match the generic 'app.tasks.*' pattern
|
||||
# since there's no specific 'app.tasks.workflow.*' route
|
||||
assert "app.tasks.*" in routes
|
||||
assert routes["app.tasks.*"]["queue"] == "default"
|
||||
|
||||
def test_all_workflow_tasks_match_routing_pattern(self):
|
||||
"""Test that all workflow task names match the routing pattern."""
|
||||
task_names = [
|
||||
"app.tasks.workflow.recover_stale_workflows",
|
||||
"app.tasks.workflow.execute_workflow_step",
|
||||
"app.tasks.workflow.handle_approval_response",
|
||||
"app.tasks.workflow.start_sprint_workflow",
|
||||
"app.tasks.workflow.start_story_workflow",
|
||||
]
|
||||
|
||||
for name in task_names:
|
||||
assert name.startswith("app.tasks.")
|
||||
|
||||
|
||||
class TestWorkflowTaskLogging:
|
||||
"""Tests for workflow task logging behavior."""
|
||||
|
||||
def test_recover_stale_workflows_logs_execution(self):
|
||||
"""Test that recover_stale_workflows logs when executed."""
|
||||
from app.tasks.workflow import recover_stale_workflows
|
||||
|
||||
with patch("app.tasks.workflow.logger") as mock_logger:
|
||||
recover_stale_workflows()
|
||||
|
||||
mock_logger.info.assert_called_once()
|
||||
call_args = mock_logger.info.call_args[0][0]
|
||||
assert "stale" in call_args.lower() or "recover" in call_args.lower()
|
||||
|
||||
def test_execute_workflow_step_logs_execution(self):
|
||||
"""Test that execute_workflow_step logs when executed."""
|
||||
from app.tasks.workflow import execute_workflow_step
|
||||
|
||||
workflow_id = str(uuid.uuid4())
|
||||
transition = "start_planning"
|
||||
|
||||
with patch("app.tasks.workflow.logger") as mock_logger:
|
||||
execute_workflow_step(workflow_id, transition)
|
||||
|
||||
mock_logger.info.assert_called_once()
|
||||
call_args = mock_logger.info.call_args[0][0]
|
||||
assert transition in call_args
|
||||
assert workflow_id in call_args
|
||||
|
||||
def test_handle_approval_response_logs_execution(self):
|
||||
"""Test that handle_approval_response logs when executed."""
|
||||
from app.tasks.workflow import handle_approval_response
|
||||
|
||||
workflow_id = str(uuid.uuid4())
|
||||
|
||||
with patch("app.tasks.workflow.logger") as mock_logger:
|
||||
handle_approval_response(workflow_id, approved=True)
|
||||
|
||||
mock_logger.info.assert_called_once()
|
||||
call_args = mock_logger.info.call_args[0][0]
|
||||
assert workflow_id in call_args
|
||||
|
||||
def test_start_sprint_workflow_logs_execution(self):
|
||||
"""Test that start_sprint_workflow logs when executed."""
|
||||
from app.tasks.workflow import start_sprint_workflow
|
||||
|
||||
project_id = str(uuid.uuid4())
|
||||
sprint_id = str(uuid.uuid4())
|
||||
|
||||
with patch("app.tasks.workflow.logger") as mock_logger:
|
||||
start_sprint_workflow(project_id, sprint_id)
|
||||
|
||||
mock_logger.info.assert_called_once()
|
||||
call_args = mock_logger.info.call_args[0][0]
|
||||
assert sprint_id in call_args
|
||||
|
||||
|
||||
class TestWorkflowTaskSignatures:
|
||||
"""Tests for workflow task signature creation."""
|
||||
|
||||
def test_execute_workflow_step_signature_creation(self):
|
||||
"""Test that execute_workflow_step signature can be created."""
|
||||
from app.tasks.workflow import execute_workflow_step
|
||||
|
||||
workflow_id = str(uuid.uuid4())
|
||||
transition = "approve"
|
||||
|
||||
sig = execute_workflow_step.s(workflow_id, transition)
|
||||
|
||||
assert sig is not None
|
||||
assert sig.args == (workflow_id, transition)
|
||||
|
||||
def test_workflow_chain_creation(self):
|
||||
"""Test that workflow tasks can be chained together."""
|
||||
from celery import chain
|
||||
from app.tasks.workflow import (
|
||||
start_sprint_workflow,
|
||||
execute_workflow_step,
|
||||
handle_approval_response,
|
||||
)
|
||||
|
||||
project_id = str(uuid.uuid4())
|
||||
sprint_id = str(uuid.uuid4())
|
||||
workflow_id = str(uuid.uuid4())
|
||||
|
||||
# Build a chain (doesn't execute, just creates the workflow)
|
||||
workflow = chain(
|
||||
start_sprint_workflow.s(project_id, sprint_id),
|
||||
# In reality, these would use results from previous tasks
|
||||
)
|
||||
|
||||
assert workflow is not None
|
||||
159
backend/uv.lock
generated
159
backend/uv.lock
generated
@@ -28,6 +28,18 @@ wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/a5/32/7df1d81ec2e50fb661944a35183d87e62d3f6c6d9f8aff64a4f245226d55/alembic-1.17.1-py3-none-any.whl", hash = "sha256:cbc2386e60f89608bb63f30d2d6cc66c7aaed1fe105bd862828600e5ad167023", size = 247848, upload-time = "2025-10-29T00:23:18.79Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "amqp"
|
||||
version = "5.3.1"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "vine" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/79/fc/ec94a357dfc6683d8c86f8b4cfa5416a4c36b28052ec8260c77aca96a443/amqp-5.3.1.tar.gz", hash = "sha256:cddc00c725449522023bad949f70fff7b48f0b1ade74d170a6f10ab044739432", size = 129013, upload-time = "2024-11-12T19:55:44.051Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/26/99/fc813cd978842c26c82534010ea849eee9ab3a13ea2b74e95cb9c99e747b/amqp-5.3.1-py3-none-any.whl", hash = "sha256:43b3319e1b4e7d1251833a93d672b4af1e40f3d632d479b98661a95f117880a2", size = 50944, upload-time = "2024-11-12T19:55:41.782Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "annotated-doc"
|
||||
version = "0.0.3"
|
||||
@@ -160,6 +172,40 @@ wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/76/b9/d51d34e6cd6d887adddb28a8680a1d34235cc45b9d6e238ce39b98199ca0/bcrypt-4.2.1-cp39-abi3-win_amd64.whl", hash = "sha256:e84e0e6f8e40a242b11bce56c313edc2be121cec3e0ec2d76fce01f6af33c07c", size = 153078, upload-time = "2024-11-19T20:08:01.436Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "billiard"
|
||||
version = "4.2.4"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/58/23/b12ac0bcdfb7360d664f40a00b1bda139cbbbced012c34e375506dbd0143/billiard-4.2.4.tar.gz", hash = "sha256:55f542c371209e03cd5862299b74e52e4fbcba8250ba611ad94276b369b6a85f", size = 156537, upload-time = "2025-11-30T13:28:48.52Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/cb/87/8bab77b323f16d67be364031220069f79159117dd5e43eeb4be2fef1ac9b/billiard-4.2.4-py3-none-any.whl", hash = "sha256:525b42bdec68d2b983347ac312f892db930858495db601b5836ac24e6477cde5", size = 87070, upload-time = "2025-11-30T13:28:47.016Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "celery"
|
||||
version = "5.6.1"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "billiard" },
|
||||
{ name = "click" },
|
||||
{ name = "click-didyoumean" },
|
||||
{ name = "click-plugins" },
|
||||
{ name = "click-repl" },
|
||||
{ name = "kombu" },
|
||||
{ name = "python-dateutil" },
|
||||
{ name = "tzlocal" },
|
||||
{ name = "vine" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/e2/1b/b9bbe49b1f799d0ee3de91c66e6b61d095139721f5a2ae25585f49d7c7a9/celery-5.6.1.tar.gz", hash = "sha256:bdc9e02b1480dd137f2df392358c3e94bb623d4f47ae1bc0a7dc5821c90089c7", size = 1716388, upload-time = "2025-12-29T21:48:50.805Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/87/b1/7b7d1e0bc2a3f7ee01576008e3c943f3f23a56809b63f4140ddc96f201c1/celery-5.6.1-py3-none-any.whl", hash = "sha256:ee87aa14d344c655fe83bfc44b2c93bbb7cba39ae11e58b88279523506159d44", size = 445358, upload-time = "2025-12-29T21:48:48.894Z" },
|
||||
]
|
||||
|
||||
[package.optional-dependencies]
|
||||
redis = [
|
||||
{ name = "kombu", extra = ["redis"] },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "certifi"
|
||||
version = "2025.10.5"
|
||||
@@ -295,6 +341,43 @@ wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/db/d3/9dcc0f5797f070ec8edf30fbadfb200e71d9db6b84d211e3b2085a7589a0/click-8.3.0-py3-none-any.whl", hash = "sha256:9b9f285302c6e3064f4330c05f05b81945b2a39544279343e6e7c5f27a9baddc", size = 107295, upload-time = "2025-09-18T17:32:22.42Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "click-didyoumean"
|
||||
version = "0.3.1"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "click" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/30/ce/217289b77c590ea1e7c24242d9ddd6e249e52c795ff10fac2c50062c48cb/click_didyoumean-0.3.1.tar.gz", hash = "sha256:4f82fdff0dbe64ef8ab2279bd6aa3f6a99c3b28c05aa09cbfc07c9d7fbb5a463", size = 3089, upload-time = "2024-03-24T08:22:07.499Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/1b/5b/974430b5ffdb7a4f1941d13d83c64a0395114503cc357c6b9ae4ce5047ed/click_didyoumean-0.3.1-py3-none-any.whl", hash = "sha256:5c4bb6007cfea5f2fd6583a2fb6701a22a41eb98957e63d0fac41c10e7c3117c", size = 3631, upload-time = "2024-03-24T08:22:06.356Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "click-plugins"
|
||||
version = "1.1.1.2"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "click" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/c3/a4/34847b59150da33690a36da3681d6bbc2ec14ee9a846bc30a6746e5984e4/click_plugins-1.1.1.2.tar.gz", hash = "sha256:d7af3984a99d243c131aa1a828331e7630f4a88a9741fd05c927b204bcf92261", size = 8343, upload-time = "2025-06-25T00:47:37.555Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/3d/9a/2abecb28ae875e39c8cad711eb1186d8d14eab564705325e77e4e6ab9ae5/click_plugins-1.1.1.2-py2.py3-none-any.whl", hash = "sha256:008d65743833ffc1f5417bf0e78e8d2c23aab04d9745ba817bd3e71b0feb6aa6", size = 11051, upload-time = "2025-06-25T00:47:36.731Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "click-repl"
|
||||
version = "0.3.0"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "click" },
|
||||
{ name = "prompt-toolkit" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/cb/a2/57f4ac79838cfae6912f997b4d1a64a858fb0c86d7fcaae6f7b58d267fca/click-repl-0.3.0.tar.gz", hash = "sha256:17849c23dba3d667247dc4defe1757fff98694e90fe37474f3feebb69ced26a9", size = 10449, upload-time = "2023-06-15T12:43:51.141Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/52/40/9d857001228658f0d59e97ebd4c346fe73e138c6de1bce61dc568a57c7f8/click_repl-0.3.0-py3-none-any.whl", hash = "sha256:fb7e06deb8da8de86180a33a9da97ac316751c094c6899382da7feeeeb51b812", size = 10289, upload-time = "2023-06-15T12:43:48.626Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "colorama"
|
||||
version = "0.4.6"
|
||||
@@ -493,6 +576,7 @@ dependencies = [
|
||||
{ name = "asyncpg" },
|
||||
{ name = "authlib" },
|
||||
{ name = "bcrypt" },
|
||||
{ name = "celery", extra = ["redis"] },
|
||||
{ name = "cryptography" },
|
||||
{ name = "email-validator" },
|
||||
{ name = "fastapi" },
|
||||
@@ -509,6 +593,7 @@ dependencies = [
|
||||
{ name = "pytz" },
|
||||
{ name = "slowapi" },
|
||||
{ name = "sqlalchemy" },
|
||||
{ name = "sse-starlette" },
|
||||
{ name = "starlette" },
|
||||
{ name = "starlette-csrf" },
|
||||
{ name = "tenacity" },
|
||||
@@ -540,6 +625,7 @@ requires-dist = [
|
||||
{ name = "asyncpg", specifier = ">=0.29.0" },
|
||||
{ name = "authlib", specifier = ">=1.3.0" },
|
||||
{ name = "bcrypt", specifier = "==4.2.1" },
|
||||
{ name = "celery", extras = ["redis"], specifier = ">=5.4.0" },
|
||||
{ name = "cryptography", specifier = "==44.0.1" },
|
||||
{ name = "email-validator", specifier = ">=2.1.0.post1" },
|
||||
{ name = "fastapi", specifier = ">=0.115.8" },
|
||||
@@ -565,6 +651,7 @@ requires-dist = [
|
||||
{ name = "schemathesis", marker = "extra == 'e2e'", specifier = ">=3.30.0" },
|
||||
{ name = "slowapi", specifier = ">=0.1.9" },
|
||||
{ name = "sqlalchemy", specifier = ">=2.0.29" },
|
||||
{ name = "sse-starlette", specifier = ">=3.1.1" },
|
||||
{ name = "starlette", specifier = ">=0.40.0" },
|
||||
{ name = "starlette-csrf", specifier = ">=1.4.5" },
|
||||
{ name = "tenacity", specifier = ">=8.2.3" },
|
||||
@@ -855,6 +942,26 @@ wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/2a/93/2d896b5fd3d79b4cadd8882c06650e66d003f465c9d12c488d92853dff78/junit_xml-1.9-py2.py3-none-any.whl", hash = "sha256:ec5ca1a55aefdd76d28fcc0b135251d156c7106fa979686a4b48d62b761b4732", size = 7130, upload-time = "2020-02-22T20:41:37.661Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "kombu"
|
||||
version = "5.6.2"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "amqp" },
|
||||
{ name = "packaging" },
|
||||
{ name = "tzdata" },
|
||||
{ name = "vine" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/b6/a5/607e533ed6c83ae1a696969b8e1c137dfebd5759a2e9682e26ff1b97740b/kombu-5.6.2.tar.gz", hash = "sha256:8060497058066c6f5aed7c26d7cd0d3b574990b09de842a8c5aaed0b92cc5a55", size = 472594, upload-time = "2025-12-29T20:30:07.779Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/fb/0f/834427d8c03ff1d7e867d3db3d176470c64871753252b21b4f4897d1fa45/kombu-5.6.2-py3-none-any.whl", hash = "sha256:efcfc559da324d41d61ca311b0c64965ea35b4c55cc04ee36e55386145dace93", size = 214219, upload-time = "2025-12-29T20:30:05.74Z" },
|
||||
]
|
||||
|
||||
[package.optional-dependencies]
|
||||
redis = [
|
||||
{ name = "redis" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "limits"
|
||||
version = "5.6.0"
|
||||
@@ -1111,6 +1218,18 @@ wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/54/20/4d324d65cc6d9205fabedc306948156824eb9f0ee1633355a8f7ec5c66bf/pluggy-1.6.0-py3-none-any.whl", hash = "sha256:e920276dd6813095e9377c0bc5566d94c932c33b27a3e3945d8389c374dd4746", size = 20538, upload-time = "2025-05-15T12:30:06.134Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "prompt-toolkit"
|
||||
version = "3.0.52"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "wcwidth" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/a1/96/06e01a7b38dce6fe1db213e061a4602dd6032a8a97ef6c1a862537732421/prompt_toolkit-3.0.52.tar.gz", hash = "sha256:28cde192929c8e7321de85de1ddbe736f1375148b02f2e17edd840042b1be855", size = 434198, upload-time = "2025-08-27T15:24:02.057Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/84/03/0d3ce49e2505ae70cf43bc5bb3033955d2fc9f932163e84dc0779cc47f48/prompt_toolkit-3.0.52-py3-none-any.whl", hash = "sha256:9aac639a3bbd33284347de5ad8d68ecc044b91a762dc39b7c21095fcd6a19955", size = 391431, upload-time = "2025-08-27T15:23:59.498Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "psutil"
|
||||
version = "5.9.8"
|
||||
@@ -1486,6 +1605,15 @@ wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/f1/12/de94a39c2ef588c7e6455cfbe7343d3b2dc9d6b6b2f40c4c6565744c873d/pyyaml-6.0.3-cp314-cp314t-win_arm64.whl", hash = "sha256:ebc55a14a21cb14062aa4162f906cd962b28e2e9ea38f9b4391244cd8de4ae0b", size = 149341, upload-time = "2025-09-25T21:32:56.828Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "redis"
|
||||
version = "6.4.0"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/0d/d6/e8b92798a5bd67d659d51a18170e91c16ac3b59738d91894651ee255ed49/redis-6.4.0.tar.gz", hash = "sha256:b01bc7282b8444e28ec36b261df5375183bb47a07eb9c603f284e89cbc5ef010", size = 4647399, upload-time = "2025-08-07T08:10:11.441Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/e8/02/89e2ed7e85db6c93dfa9e8f691c5087df4e3551ab39081a4d7c6d1f90e05/redis-6.4.0-py3-none-any.whl", hash = "sha256:f0544fa9604264e9464cdf4814e7d4830f74b165d52f2a330a760a88dd248b7f", size = 279847, upload-time = "2025-08-07T08:10:09.84Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "referencing"
|
||||
version = "0.37.0"
|
||||
@@ -1766,6 +1894,19 @@ wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/9c/5e/6a29fa884d9fb7ddadf6b69490a9d45fded3b38541713010dad16b77d015/sqlalchemy-2.0.44-py3-none-any.whl", hash = "sha256:19de7ca1246fbef9f9d1bff8f1ab25641569df226364a0e40457dc5457c54b05", size = 1928718, upload-time = "2025-10-10T15:29:45.32Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "sse-starlette"
|
||||
version = "3.1.1"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "anyio" },
|
||||
{ name = "starlette" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/62/08/8f554b0e5bad3e4e880521a1686d96c05198471eed860b0eb89b57ea3636/sse_starlette-3.1.1.tar.gz", hash = "sha256:bffa531420c1793ab224f63648c059bcadc412bf9fdb1301ac8de1cf9a67b7fb", size = 24306, upload-time = "2025-12-26T15:22:53.836Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/e3/31/4c281581a0f8de137b710a07f65518b34bcf333b201cfa06cfda9af05f8a/sse_starlette-3.1.1-py3-none-any.whl", hash = "sha256:bb38f71ae74cfd86b529907a9fda5632195dfa6ae120f214ea4c890c7ee9d436", size = 12442, upload-time = "2025-12-26T15:22:52.911Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "starlette"
|
||||
version = "0.49.3"
|
||||
@@ -1955,6 +2096,24 @@ wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/ee/d9/d88e73ca598f4f6ff671fb5fde8a32925c2e08a637303a1d12883c7305fa/uvicorn-0.38.0-py3-none-any.whl", hash = "sha256:48c0afd214ceb59340075b4a052ea1ee91c16fbc2a9b1469cca0e54566977b02", size = 68109, upload-time = "2025-10-18T13:46:42.958Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "vine"
|
||||
version = "5.1.0"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/bd/e4/d07b5f29d283596b9727dd5275ccbceb63c44a1a82aa9e4bfd20426762ac/vine-5.1.0.tar.gz", hash = "sha256:8b62e981d35c41049211cf62a0a1242d8c1ee9bd15bb196ce38aefd6799e61e0", size = 48980, upload-time = "2023-11-05T08:46:53.857Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/03/ff/7c0c86c43b3cbb927e0ccc0255cb4057ceba4799cd44ae95174ce8e8b5b2/vine-5.1.0-py3-none-any.whl", hash = "sha256:40fdf3c48b2cfe1c38a49e9ae2da6fda88e4794c810050a728bd7413811fb1dc", size = 9636, upload-time = "2023-11-05T08:46:51.205Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "wcwidth"
|
||||
version = "0.2.14"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/24/30/6b0809f4510673dc723187aeaf24c7f5459922d01e2f794277a3dfb90345/wcwidth-0.2.14.tar.gz", hash = "sha256:4d478375d31bc5395a3c55c40ccdf3354688364cd61c4f6adacaa9215d0b3605", size = 102293, upload-time = "2025-09-22T16:29:53.023Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/af/b5/123f13c975e9f27ab9c0770f514345bd406d0e8d3b7a0723af9d43f710af/wcwidth-0.2.14-py2.py3-none-any.whl", hash = "sha256:a7bb560c8aee30f9957e5f9895805edd20602f2d7f720186dfd906e82b4982e1", size = 37286, upload-time = "2025-09-22T16:29:51.641Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "webcolors"
|
||||
version = "25.10.0"
|
||||
|
||||
Reference in New Issue
Block a user