forked from cardosofelipe/fast-next-template
Compare commits
19 Commits
79cb6bfd7b
...
feature/58
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
0a624a94af | ||
|
|
011b21bf0a | ||
|
|
76d7de5334 | ||
|
|
1779239c07 | ||
|
|
9dfa76aa41 | ||
|
|
4ad3d20cf2 | ||
|
|
8623eb56f5 | ||
|
|
3cb6c8d13b | ||
|
|
8e16e2645e | ||
|
|
82c3a6ba47 | ||
|
|
b6c38cac88 | ||
|
|
51404216ae | ||
|
|
3f23bc3db3 | ||
|
|
a0ec5fa2cc | ||
|
|
f262d08be2 | ||
|
|
b3f371e0a3 | ||
|
|
93cc37224c | ||
|
|
5717bffd63 | ||
|
|
9339ea30a1 |
37
Makefile
37
Makefile
@@ -1,5 +1,5 @@
|
||||
.PHONY: help dev dev-full prod down logs logs-dev clean clean-slate drop-db reset-db push-images deploy
|
||||
.PHONY: test test-backend test-mcp test-frontend test-all test-cov test-integration validate validate-all
|
||||
.PHONY: test test-backend test-mcp test-frontend test-all test-cov test-integration validate validate-all format-all
|
||||
|
||||
VERSION ?= latest
|
||||
REGISTRY ?= ghcr.io/cardosofelipe/pragma-stack
|
||||
@@ -22,6 +22,9 @@ help:
|
||||
@echo " make test-cov - Run all tests with coverage reports"
|
||||
@echo " make test-integration - Run MCP integration tests (requires running stack)"
|
||||
@echo ""
|
||||
@echo "Formatting:"
|
||||
@echo " make format-all - Format code in backend + MCP servers + frontend"
|
||||
@echo ""
|
||||
@echo "Validation:"
|
||||
@echo " make validate - Validate backend + MCP servers (lint, type-check, test)"
|
||||
@echo " make validate-all - Validate everything including frontend"
|
||||
@@ -44,6 +47,7 @@ help:
|
||||
@echo " cd backend && make help - Backend-specific commands"
|
||||
@echo " cd mcp-servers/llm-gateway && make - LLM Gateway commands"
|
||||
@echo " cd mcp-servers/knowledge-base && make - Knowledge Base commands"
|
||||
@echo " cd mcp-servers/git-ops && make - Git Operations commands"
|
||||
@echo " cd frontend && npm run - Frontend-specific commands"
|
||||
|
||||
# ============================================================================
|
||||
@@ -135,6 +139,9 @@ test-mcp:
|
||||
@echo ""
|
||||
@echo "=== Knowledge Base ==="
|
||||
@cd mcp-servers/knowledge-base && uv run pytest tests/ -v
|
||||
@echo ""
|
||||
@echo "=== Git Operations ==="
|
||||
@cd mcp-servers/git-ops && IS_TEST=True uv run pytest tests/ -v
|
||||
|
||||
test-frontend:
|
||||
@echo "Running frontend tests..."
|
||||
@@ -155,12 +162,37 @@ test-cov:
|
||||
@echo ""
|
||||
@echo "=== Knowledge Base Coverage ==="
|
||||
@cd mcp-servers/knowledge-base && uv run pytest tests/ -v --cov=. --cov-report=term-missing
|
||||
@echo ""
|
||||
@echo "=== Git Operations Coverage ==="
|
||||
@cd mcp-servers/git-ops && IS_TEST=True uv run pytest tests/ -v --cov=. --cov-report=term-missing
|
||||
|
||||
test-integration:
|
||||
@echo "Running MCP integration tests..."
|
||||
@echo "Note: Requires running stack (make dev first)"
|
||||
@cd backend && RUN_INTEGRATION_TESTS=true IS_TEST=True uv run pytest tests/integration/ -v
|
||||
|
||||
# ============================================================================
|
||||
# Formatting
|
||||
# ============================================================================
|
||||
|
||||
format-all:
|
||||
@echo "Formatting backend..."
|
||||
@cd backend && make format
|
||||
@echo ""
|
||||
@echo "Formatting LLM Gateway..."
|
||||
@cd mcp-servers/llm-gateway && make format
|
||||
@echo ""
|
||||
@echo "Formatting Knowledge Base..."
|
||||
@cd mcp-servers/knowledge-base && make format
|
||||
@echo ""
|
||||
@echo "Formatting Git Operations..."
|
||||
@cd mcp-servers/git-ops && make format
|
||||
@echo ""
|
||||
@echo "Formatting frontend..."
|
||||
@cd frontend && npm run format
|
||||
@echo ""
|
||||
@echo "All code formatted!"
|
||||
|
||||
# ============================================================================
|
||||
# Validation (lint + type-check + test)
|
||||
# ============================================================================
|
||||
@@ -175,6 +207,9 @@ validate:
|
||||
@echo "Validating Knowledge Base..."
|
||||
@cd mcp-servers/knowledge-base && make validate
|
||||
@echo ""
|
||||
@echo "Validating Git Operations..."
|
||||
@cd mcp-servers/git-ops && make validate
|
||||
@echo ""
|
||||
@echo "All validations passed!"
|
||||
|
||||
validate-all: validate
|
||||
|
||||
@@ -0,0 +1,90 @@
|
||||
"""Add category and display fields to agent_types table
|
||||
|
||||
Revision ID: 0007
|
||||
Revises: 0006
|
||||
Create Date: 2026-01-06
|
||||
|
||||
This migration adds:
|
||||
- category: String(50) for grouping agents by role type
|
||||
- icon: String(50) for Lucide icon identifier
|
||||
- color: String(7) for hex color code
|
||||
- sort_order: Integer for display ordering within categories
|
||||
- typical_tasks: JSONB list of tasks this agent excels at
|
||||
- collaboration_hints: JSONB list of agent slugs that work well together
|
||||
"""
|
||||
|
||||
from collections.abc import Sequence
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "0007"
|
||||
down_revision: str | None = "0006"
|
||||
branch_labels: str | Sequence[str] | None = None
|
||||
depends_on: str | Sequence[str] | None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
"""Add category and display fields to agent_types table."""
|
||||
# Add new columns
|
||||
op.add_column(
|
||||
"agent_types",
|
||||
sa.Column("category", sa.String(length=50), nullable=True),
|
||||
)
|
||||
op.add_column(
|
||||
"agent_types",
|
||||
sa.Column("icon", sa.String(length=50), nullable=True, server_default="bot"),
|
||||
)
|
||||
op.add_column(
|
||||
"agent_types",
|
||||
sa.Column(
|
||||
"color", sa.String(length=7), nullable=True, server_default="#3B82F6"
|
||||
),
|
||||
)
|
||||
op.add_column(
|
||||
"agent_types",
|
||||
sa.Column("sort_order", sa.Integer(), nullable=False, server_default="0"),
|
||||
)
|
||||
op.add_column(
|
||||
"agent_types",
|
||||
sa.Column(
|
||||
"typical_tasks",
|
||||
postgresql.JSONB(astext_type=sa.Text()),
|
||||
nullable=False,
|
||||
server_default="[]",
|
||||
),
|
||||
)
|
||||
op.add_column(
|
||||
"agent_types",
|
||||
sa.Column(
|
||||
"collaboration_hints",
|
||||
postgresql.JSONB(astext_type=sa.Text()),
|
||||
nullable=False,
|
||||
server_default="[]",
|
||||
),
|
||||
)
|
||||
|
||||
# Add indexes for category and sort_order
|
||||
op.create_index("ix_agent_types_category", "agent_types", ["category"])
|
||||
op.create_index("ix_agent_types_sort_order", "agent_types", ["sort_order"])
|
||||
op.create_index(
|
||||
"ix_agent_types_category_sort", "agent_types", ["category", "sort_order"]
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
"""Remove category and display fields from agent_types table."""
|
||||
# Drop indexes
|
||||
op.drop_index("ix_agent_types_category_sort", table_name="agent_types")
|
||||
op.drop_index("ix_agent_types_sort_order", table_name="agent_types")
|
||||
op.drop_index("ix_agent_types_category", table_name="agent_types")
|
||||
|
||||
# Drop columns
|
||||
op.drop_column("agent_types", "collaboration_hints")
|
||||
op.drop_column("agent_types", "typical_tasks")
|
||||
op.drop_column("agent_types", "sort_order")
|
||||
op.drop_column("agent_types", "color")
|
||||
op.drop_column("agent_types", "icon")
|
||||
op.drop_column("agent_types", "category")
|
||||
@@ -81,6 +81,13 @@ def _build_agent_type_response(
|
||||
mcp_servers=agent_type.mcp_servers,
|
||||
tool_permissions=agent_type.tool_permissions,
|
||||
is_active=agent_type.is_active,
|
||||
# Category and display fields
|
||||
category=agent_type.category,
|
||||
icon=agent_type.icon,
|
||||
color=agent_type.color,
|
||||
sort_order=agent_type.sort_order,
|
||||
typical_tasks=agent_type.typical_tasks or [],
|
||||
collaboration_hints=agent_type.collaboration_hints or [],
|
||||
created_at=agent_type.created_at,
|
||||
updated_at=agent_type.updated_at,
|
||||
instance_count=instance_count,
|
||||
@@ -300,6 +307,7 @@ async def list_agent_types(
|
||||
request: Request,
|
||||
pagination: PaginationParams = Depends(),
|
||||
is_active: bool = Query(True, description="Filter by active status"),
|
||||
category: str | None = Query(None, description="Filter by category"),
|
||||
search: str | None = Query(None, description="Search by name, slug, description"),
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
@@ -314,6 +322,7 @@ async def list_agent_types(
|
||||
request: FastAPI request object
|
||||
pagination: Pagination parameters (page, limit)
|
||||
is_active: Filter by active status (default: True)
|
||||
category: Filter by category (e.g., "development", "design")
|
||||
search: Optional search term for name, slug, description
|
||||
current_user: Authenticated user
|
||||
db: Database session
|
||||
@@ -328,6 +337,7 @@ async def list_agent_types(
|
||||
skip=pagination.offset,
|
||||
limit=pagination.limit,
|
||||
is_active=is_active,
|
||||
category=category,
|
||||
search=search,
|
||||
)
|
||||
|
||||
@@ -354,6 +364,51 @@ async def list_agent_types(
|
||||
raise
|
||||
|
||||
|
||||
@router.get(
|
||||
"/grouped",
|
||||
response_model=dict[str, list[AgentTypeResponse]],
|
||||
summary="List Agent Types Grouped by Category",
|
||||
description="Get all agent types organized by category",
|
||||
operation_id="list_agent_types_grouped",
|
||||
)
|
||||
@limiter.limit(f"{60 * RATE_MULTIPLIER}/minute")
|
||||
async def list_agent_types_grouped(
|
||||
request: Request,
|
||||
is_active: bool = Query(True, description="Filter by active status"),
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
) -> Any:
|
||||
"""
|
||||
Get agent types grouped by category.
|
||||
|
||||
Returns a dictionary where keys are category names and values
|
||||
are lists of agent types, sorted by sort_order within each category.
|
||||
|
||||
Args:
|
||||
request: FastAPI request object
|
||||
is_active: Filter by active status (default: True)
|
||||
current_user: Authenticated user
|
||||
db: Database session
|
||||
|
||||
Returns:
|
||||
Dictionary mapping category to list of agent types
|
||||
"""
|
||||
try:
|
||||
grouped = await agent_type_crud.get_grouped_by_category(db, is_active=is_active)
|
||||
|
||||
# Transform to response objects
|
||||
result: dict[str, list[AgentTypeResponse]] = {}
|
||||
for category, types in grouped.items():
|
||||
result[category] = [
|
||||
_build_agent_type_response(t, instance_count=0) for t in types
|
||||
]
|
||||
|
||||
return result
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting grouped agent types: {e!s}", exc_info=True)
|
||||
raise
|
||||
|
||||
|
||||
@router.get(
|
||||
"/{agent_type_id}",
|
||||
response_model=AgentTypeResponse,
|
||||
|
||||
@@ -43,6 +43,13 @@ class CRUDAgentType(CRUDBase[AgentType, AgentTypeCreate, AgentTypeUpdate]):
|
||||
mcp_servers=obj_in.mcp_servers,
|
||||
tool_permissions=obj_in.tool_permissions,
|
||||
is_active=obj_in.is_active,
|
||||
# Category and display fields
|
||||
category=obj_in.category.value if obj_in.category else None,
|
||||
icon=obj_in.icon,
|
||||
color=obj_in.color,
|
||||
sort_order=obj_in.sort_order,
|
||||
typical_tasks=obj_in.typical_tasks,
|
||||
collaboration_hints=obj_in.collaboration_hints,
|
||||
)
|
||||
db.add(db_obj)
|
||||
await db.commit()
|
||||
@@ -68,6 +75,7 @@ class CRUDAgentType(CRUDBase[AgentType, AgentTypeCreate, AgentTypeUpdate]):
|
||||
skip: int = 0,
|
||||
limit: int = 100,
|
||||
is_active: bool | None = None,
|
||||
category: str | None = None,
|
||||
search: str | None = None,
|
||||
sort_by: str = "created_at",
|
||||
sort_order: str = "desc",
|
||||
@@ -85,6 +93,9 @@ class CRUDAgentType(CRUDBase[AgentType, AgentTypeCreate, AgentTypeUpdate]):
|
||||
if is_active is not None:
|
||||
query = query.where(AgentType.is_active == is_active)
|
||||
|
||||
if category:
|
||||
query = query.where(AgentType.category == category)
|
||||
|
||||
if search:
|
||||
search_filter = or_(
|
||||
AgentType.name.ilike(f"%{search}%"),
|
||||
@@ -162,6 +173,7 @@ class CRUDAgentType(CRUDBase[AgentType, AgentTypeCreate, AgentTypeUpdate]):
|
||||
skip: int = 0,
|
||||
limit: int = 100,
|
||||
is_active: bool | None = None,
|
||||
category: str | None = None,
|
||||
search: str | None = None,
|
||||
) -> tuple[list[dict[str, Any]], int]:
|
||||
"""
|
||||
@@ -177,6 +189,7 @@ class CRUDAgentType(CRUDBase[AgentType, AgentTypeCreate, AgentTypeUpdate]):
|
||||
skip=skip,
|
||||
limit=limit,
|
||||
is_active=is_active,
|
||||
category=category,
|
||||
search=search,
|
||||
)
|
||||
|
||||
@@ -260,6 +273,44 @@ class CRUDAgentType(CRUDBase[AgentType, AgentTypeCreate, AgentTypeUpdate]):
|
||||
)
|
||||
raise
|
||||
|
||||
async def get_grouped_by_category(
|
||||
self,
|
||||
db: AsyncSession,
|
||||
*,
|
||||
is_active: bool = True,
|
||||
) -> dict[str, list[AgentType]]:
|
||||
"""
|
||||
Get agent types grouped by category, sorted by sort_order within each group.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
is_active: Filter by active status (default: True)
|
||||
|
||||
Returns:
|
||||
Dictionary mapping category to list of agent types
|
||||
"""
|
||||
try:
|
||||
query = (
|
||||
select(AgentType)
|
||||
.where(AgentType.is_active == is_active)
|
||||
.order_by(AgentType.category, AgentType.sort_order, AgentType.name)
|
||||
)
|
||||
result = await db.execute(query)
|
||||
agent_types = list(result.scalars().all())
|
||||
|
||||
# Group by category
|
||||
grouped: dict[str, list[AgentType]] = {}
|
||||
for at in agent_types:
|
||||
cat: str = str(at.category) if at.category else "uncategorized"
|
||||
if cat not in grouped:
|
||||
grouped[cat] = []
|
||||
grouped[cat].append(at)
|
||||
|
||||
return grouped
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting grouped agent types: {e!s}", exc_info=True)
|
||||
raise
|
||||
|
||||
|
||||
# Create a singleton instance for use across the application
|
||||
agent_type = CRUDAgentType(AgentType)
|
||||
|
||||
@@ -149,6 +149,13 @@ async def load_default_agent_types(session: AsyncSession) -> None:
|
||||
mcp_servers=agent_type_data.get("mcp_servers", []),
|
||||
tool_permissions=agent_type_data.get("tool_permissions", {}),
|
||||
is_active=agent_type_data.get("is_active", True),
|
||||
# Category and display fields
|
||||
category=agent_type_data.get("category"),
|
||||
icon=agent_type_data.get("icon", "bot"),
|
||||
color=agent_type_data.get("color", "#3B82F6"),
|
||||
sort_order=agent_type_data.get("sort_order", 0),
|
||||
typical_tasks=agent_type_data.get("typical_tasks", []),
|
||||
collaboration_hints=agent_type_data.get("collaboration_hints", []),
|
||||
)
|
||||
|
||||
await agent_type_crud.create(session, obj_in=agent_type_in)
|
||||
|
||||
@@ -6,7 +6,7 @@ An AgentType is a template that defines the capabilities, personality,
|
||||
and model configuration for agent instances.
|
||||
"""
|
||||
|
||||
from sqlalchemy import Boolean, Column, Index, String, Text
|
||||
from sqlalchemy import Boolean, Column, Index, Integer, String, Text
|
||||
from sqlalchemy.dialects.postgresql import JSONB
|
||||
from sqlalchemy.orm import relationship
|
||||
|
||||
@@ -56,6 +56,24 @@ class AgentType(Base, UUIDMixin, TimestampMixin):
|
||||
# Whether this agent type is available for new instances
|
||||
is_active = Column(Boolean, default=True, nullable=False, index=True)
|
||||
|
||||
# Category for grouping agents (development, design, quality, etc.)
|
||||
category = Column(String(50), nullable=True, index=True)
|
||||
|
||||
# Lucide icon identifier for UI display (e.g., "code", "palette", "shield")
|
||||
icon = Column(String(50), nullable=True, default="bot")
|
||||
|
||||
# Hex color code for visual distinction (e.g., "#3B82F6")
|
||||
color = Column(String(7), nullable=True, default="#3B82F6")
|
||||
|
||||
# Display ordering within category (lower = first)
|
||||
sort_order = Column(Integer, nullable=False, default=0, index=True)
|
||||
|
||||
# List of typical tasks this agent excels at
|
||||
typical_tasks = Column(JSONB, default=list, nullable=False)
|
||||
|
||||
# List of agent slugs that collaborate well with this type
|
||||
collaboration_hints = Column(JSONB, default=list, nullable=False)
|
||||
|
||||
# Relationships
|
||||
instances = relationship(
|
||||
"AgentInstance",
|
||||
@@ -66,6 +84,7 @@ class AgentType(Base, UUIDMixin, TimestampMixin):
|
||||
__table_args__ = (
|
||||
Index("ix_agent_types_slug_active", "slug", "is_active"),
|
||||
Index("ix_agent_types_name_active", "name", "is_active"),
|
||||
Index("ix_agent_types_category_sort", "category", "sort_order"),
|
||||
)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
|
||||
@@ -167,3 +167,29 @@ class SprintStatus(str, PyEnum):
|
||||
IN_REVIEW = "in_review"
|
||||
COMPLETED = "completed"
|
||||
CANCELLED = "cancelled"
|
||||
|
||||
|
||||
class AgentTypeCategory(str, PyEnum):
|
||||
"""
|
||||
Category classification for agent types.
|
||||
|
||||
Used for grouping and filtering agents in the UI.
|
||||
|
||||
DEVELOPMENT: Product, project, and engineering roles
|
||||
DESIGN: UI/UX and design research roles
|
||||
QUALITY: QA and security engineering
|
||||
OPERATIONS: DevOps and MLOps
|
||||
AI_ML: Machine learning and AI specialists
|
||||
DATA: Data science and engineering
|
||||
LEADERSHIP: Technical leadership roles
|
||||
DOMAIN_EXPERT: Industry and domain specialists
|
||||
"""
|
||||
|
||||
DEVELOPMENT = "development"
|
||||
DESIGN = "design"
|
||||
QUALITY = "quality"
|
||||
OPERATIONS = "operations"
|
||||
AI_ML = "ai_ml"
|
||||
DATA = "data"
|
||||
LEADERSHIP = "leadership"
|
||||
DOMAIN_EXPERT = "domain_expert"
|
||||
|
||||
@@ -10,6 +10,8 @@ from uuid import UUID
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field, field_validator
|
||||
|
||||
from app.models.syndarix.enums import AgentTypeCategory
|
||||
|
||||
|
||||
class AgentTypeBase(BaseModel):
|
||||
"""Base agent type schema with common fields."""
|
||||
@@ -26,6 +28,14 @@ class AgentTypeBase(BaseModel):
|
||||
tool_permissions: dict[str, Any] = Field(default_factory=dict)
|
||||
is_active: bool = True
|
||||
|
||||
# Category and display fields
|
||||
category: AgentTypeCategory | None = None
|
||||
icon: str | None = Field(None, max_length=50)
|
||||
color: str | None = Field(None, pattern=r"^#[0-9A-Fa-f]{6}$")
|
||||
sort_order: int = Field(default=0, ge=0, le=1000)
|
||||
typical_tasks: list[str] = Field(default_factory=list)
|
||||
collaboration_hints: list[str] = Field(default_factory=list)
|
||||
|
||||
@field_validator("slug")
|
||||
@classmethod
|
||||
def validate_slug(cls, v: str | None) -> str | None:
|
||||
@@ -62,6 +72,18 @@ class AgentTypeBase(BaseModel):
|
||||
"""Validate MCP server list."""
|
||||
return [s.strip() for s in v if s.strip()]
|
||||
|
||||
@field_validator("typical_tasks")
|
||||
@classmethod
|
||||
def validate_typical_tasks(cls, v: list[str]) -> list[str]:
|
||||
"""Validate and normalize typical tasks list."""
|
||||
return [t.strip() for t in v if t.strip()]
|
||||
|
||||
@field_validator("collaboration_hints")
|
||||
@classmethod
|
||||
def validate_collaboration_hints(cls, v: list[str]) -> list[str]:
|
||||
"""Validate and normalize collaboration hints (agent slugs)."""
|
||||
return [h.strip().lower() for h in v if h.strip()]
|
||||
|
||||
|
||||
class AgentTypeCreate(AgentTypeBase):
|
||||
"""Schema for creating a new agent type."""
|
||||
@@ -87,6 +109,14 @@ class AgentTypeUpdate(BaseModel):
|
||||
tool_permissions: dict[str, Any] | None = None
|
||||
is_active: bool | None = None
|
||||
|
||||
# Category and display fields (all optional for updates)
|
||||
category: AgentTypeCategory | None = None
|
||||
icon: str | None = Field(None, max_length=50)
|
||||
color: str | None = Field(None, pattern=r"^#[0-9A-Fa-f]{6}$")
|
||||
sort_order: int | None = Field(None, ge=0, le=1000)
|
||||
typical_tasks: list[str] | None = None
|
||||
collaboration_hints: list[str] | None = None
|
||||
|
||||
@field_validator("slug")
|
||||
@classmethod
|
||||
def validate_slug(cls, v: str | None) -> str | None:
|
||||
@@ -119,6 +149,22 @@ class AgentTypeUpdate(BaseModel):
|
||||
return v
|
||||
return [e.strip().lower() for e in v if e.strip()]
|
||||
|
||||
@field_validator("typical_tasks")
|
||||
@classmethod
|
||||
def validate_typical_tasks(cls, v: list[str] | None) -> list[str] | None:
|
||||
"""Validate and normalize typical tasks list."""
|
||||
if v is None:
|
||||
return v
|
||||
return [t.strip() for t in v if t.strip()]
|
||||
|
||||
@field_validator("collaboration_hints")
|
||||
@classmethod
|
||||
def validate_collaboration_hints(cls, v: list[str] | None) -> list[str] | None:
|
||||
"""Validate and normalize collaboration hints (agent slugs)."""
|
||||
if v is None:
|
||||
return v
|
||||
return [h.strip().lower() for h in v if h.strip()]
|
||||
|
||||
|
||||
class AgentTypeInDB(AgentTypeBase):
|
||||
"""Schema for agent type in database."""
|
||||
|
||||
@@ -29,7 +29,13 @@
|
||||
"denied": [],
|
||||
"require_approval": ["gitea:delete_*"]
|
||||
},
|
||||
"is_active": true
|
||||
"is_active": true,
|
||||
"category": "development",
|
||||
"icon": "clipboard-check",
|
||||
"color": "#3B82F6",
|
||||
"sort_order": 10,
|
||||
"typical_tasks": ["Requirements discovery", "User story creation", "Backlog prioritization", "Stakeholder alignment"],
|
||||
"collaboration_hints": ["business-analyst", "solutions-architect", "scrum-master"]
|
||||
},
|
||||
{
|
||||
"name": "Project Manager",
|
||||
@@ -61,7 +67,13 @@
|
||||
"denied": [],
|
||||
"require_approval": []
|
||||
},
|
||||
"is_active": true
|
||||
"is_active": true,
|
||||
"category": "development",
|
||||
"icon": "briefcase",
|
||||
"color": "#3B82F6",
|
||||
"sort_order": 20,
|
||||
"typical_tasks": ["Sprint planning", "Risk management", "Status reporting", "Team coordination"],
|
||||
"collaboration_hints": ["product-owner", "scrum-master", "technical-lead"]
|
||||
},
|
||||
{
|
||||
"name": "Business Analyst",
|
||||
@@ -93,7 +105,13 @@
|
||||
"denied": [],
|
||||
"require_approval": []
|
||||
},
|
||||
"is_active": true
|
||||
"is_active": true,
|
||||
"category": "development",
|
||||
"icon": "file-text",
|
||||
"color": "#3B82F6",
|
||||
"sort_order": 20,
|
||||
"typical_tasks": ["Requirements analysis", "Process modeling", "Gap analysis", "Functional specifications"],
|
||||
"collaboration_hints": ["product-owner", "solutions-architect", "qa-engineer"]
|
||||
},
|
||||
{
|
||||
"name": "Solutions Architect",
|
||||
@@ -129,7 +147,13 @@
|
||||
"denied": [],
|
||||
"require_approval": ["gitea:create_pull_request"]
|
||||
},
|
||||
"is_active": true
|
||||
"is_active": true,
|
||||
"category": "development",
|
||||
"icon": "git-branch",
|
||||
"color": "#3B82F6",
|
||||
"sort_order": 20,
|
||||
"typical_tasks": ["System design", "ADR creation", "Technology selection", "Integration patterns"],
|
||||
"collaboration_hints": ["backend-engineer", "frontend-engineer", "security-engineer"]
|
||||
},
|
||||
{
|
||||
"name": "Full Stack Engineer",
|
||||
@@ -166,7 +190,13 @@
|
||||
"denied": [],
|
||||
"require_approval": ["gitea:create_pull_request", "gitea:delete_*"]
|
||||
},
|
||||
"is_active": true
|
||||
"is_active": true,
|
||||
"category": "development",
|
||||
"icon": "code",
|
||||
"color": "#3B82F6",
|
||||
"sort_order": 30,
|
||||
"typical_tasks": ["End-to-end feature development", "API design", "UI implementation", "Database operations"],
|
||||
"collaboration_hints": ["solutions-architect", "qa-engineer", "devops-engineer"]
|
||||
},
|
||||
{
|
||||
"name": "Backend Engineer",
|
||||
@@ -208,7 +238,13 @@
|
||||
"denied": [],
|
||||
"require_approval": ["gitea:create_pull_request", "gitea:delete_*"]
|
||||
},
|
||||
"is_active": true
|
||||
"is_active": true,
|
||||
"category": "development",
|
||||
"icon": "server",
|
||||
"color": "#3B82F6",
|
||||
"sort_order": 30,
|
||||
"typical_tasks": ["API development", "Database optimization", "System integration", "Performance tuning"],
|
||||
"collaboration_hints": ["solutions-architect", "frontend-engineer", "data-engineer"]
|
||||
},
|
||||
{
|
||||
"name": "Frontend Engineer",
|
||||
@@ -249,7 +285,13 @@
|
||||
"denied": [],
|
||||
"require_approval": ["gitea:create_pull_request", "gitea:delete_*"]
|
||||
},
|
||||
"is_active": true
|
||||
"is_active": true,
|
||||
"category": "development",
|
||||
"icon": "layout",
|
||||
"color": "#3B82F6",
|
||||
"sort_order": 30,
|
||||
"typical_tasks": ["UI component development", "State management", "API integration", "Responsive design"],
|
||||
"collaboration_hints": ["ui-ux-designer", "backend-engineer", "qa-engineer"]
|
||||
},
|
||||
{
|
||||
"name": "Mobile Engineer",
|
||||
@@ -286,7 +328,13 @@
|
||||
"denied": [],
|
||||
"require_approval": ["gitea:create_pull_request", "gitea:delete_*"]
|
||||
},
|
||||
"is_active": true
|
||||
"is_active": true,
|
||||
"category": "development",
|
||||
"icon": "smartphone",
|
||||
"color": "#3B82F6",
|
||||
"sort_order": 30,
|
||||
"typical_tasks": ["Native app development", "Cross-platform solutions", "Mobile optimization", "App store deployment"],
|
||||
"collaboration_hints": ["backend-engineer", "ui-ux-designer", "qa-engineer"]
|
||||
},
|
||||
{
|
||||
"name": "UI/UX Designer",
|
||||
@@ -321,7 +369,13 @@
|
||||
"denied": [],
|
||||
"require_approval": []
|
||||
},
|
||||
"is_active": true
|
||||
"is_active": true,
|
||||
"category": "design",
|
||||
"icon": "palette",
|
||||
"color": "#EC4899",
|
||||
"sort_order": 20,
|
||||
"typical_tasks": ["Interface design", "User flow creation", "Design system maintenance", "Prototyping"],
|
||||
"collaboration_hints": ["frontend-engineer", "ux-researcher", "product-owner"]
|
||||
},
|
||||
{
|
||||
"name": "UX Researcher",
|
||||
@@ -355,7 +409,13 @@
|
||||
"denied": [],
|
||||
"require_approval": []
|
||||
},
|
||||
"is_active": true
|
||||
"is_active": true,
|
||||
"category": "design",
|
||||
"icon": "search",
|
||||
"color": "#EC4899",
|
||||
"sort_order": 20,
|
||||
"typical_tasks": ["User research", "Usability testing", "Journey mapping", "Research synthesis"],
|
||||
"collaboration_hints": ["ui-ux-designer", "product-owner", "business-analyst"]
|
||||
},
|
||||
{
|
||||
"name": "QA Engineer",
|
||||
@@ -391,7 +451,13 @@
|
||||
"denied": [],
|
||||
"require_approval": []
|
||||
},
|
||||
"is_active": true
|
||||
"is_active": true,
|
||||
"category": "quality",
|
||||
"icon": "shield",
|
||||
"color": "#10B981",
|
||||
"sort_order": 20,
|
||||
"typical_tasks": ["Test strategy development", "Test automation", "Bug verification", "Quality metrics"],
|
||||
"collaboration_hints": ["backend-engineer", "frontend-engineer", "devops-engineer"]
|
||||
},
|
||||
{
|
||||
"name": "DevOps Engineer",
|
||||
@@ -431,7 +497,13 @@
|
||||
"denied": [],
|
||||
"require_approval": ["gitea:create_release", "gitea:delete_*"]
|
||||
},
|
||||
"is_active": true
|
||||
"is_active": true,
|
||||
"category": "operations",
|
||||
"icon": "settings",
|
||||
"color": "#F59E0B",
|
||||
"sort_order": 10,
|
||||
"typical_tasks": ["CI/CD pipeline design", "Infrastructure automation", "Monitoring setup", "Deployment optimization"],
|
||||
"collaboration_hints": ["backend-engineer", "security-engineer", "mlops-engineer"]
|
||||
},
|
||||
{
|
||||
"name": "Security Engineer",
|
||||
@@ -467,7 +539,13 @@
|
||||
"denied": [],
|
||||
"require_approval": []
|
||||
},
|
||||
"is_active": true
|
||||
"is_active": true,
|
||||
"category": "quality",
|
||||
"icon": "shield-check",
|
||||
"color": "#10B981",
|
||||
"sort_order": 30,
|
||||
"typical_tasks": ["Security architecture", "Vulnerability assessment", "Compliance validation", "Threat modeling"],
|
||||
"collaboration_hints": ["solutions-architect", "devops-engineer", "backend-engineer"]
|
||||
},
|
||||
{
|
||||
"name": "AI/ML Engineer",
|
||||
@@ -503,7 +581,13 @@
|
||||
"denied": [],
|
||||
"require_approval": ["gitea:create_pull_request"]
|
||||
},
|
||||
"is_active": true
|
||||
"is_active": true,
|
||||
"category": "ai_ml",
|
||||
"icon": "brain",
|
||||
"color": "#8B5CF6",
|
||||
"sort_order": 30,
|
||||
"typical_tasks": ["Model development", "Algorithm selection", "Feature engineering", "Model optimization"],
|
||||
"collaboration_hints": ["data-scientist", "mlops-engineer", "backend-engineer"]
|
||||
},
|
||||
{
|
||||
"name": "AI Researcher",
|
||||
@@ -537,7 +621,13 @@
|
||||
"denied": [],
|
||||
"require_approval": []
|
||||
},
|
||||
"is_active": true
|
||||
"is_active": true,
|
||||
"category": "ai_ml",
|
||||
"icon": "microscope",
|
||||
"color": "#8B5CF6",
|
||||
"sort_order": 30,
|
||||
"typical_tasks": ["Research paper analysis", "Novel algorithm design", "Experiment design", "Benchmark evaluation"],
|
||||
"collaboration_hints": ["ai-ml-engineer", "data-scientist", "scientific-computing-expert"]
|
||||
},
|
||||
{
|
||||
"name": "Computer Vision Engineer",
|
||||
@@ -573,7 +663,13 @@
|
||||
"denied": [],
|
||||
"require_approval": ["gitea:create_pull_request"]
|
||||
},
|
||||
"is_active": true
|
||||
"is_active": true,
|
||||
"category": "ai_ml",
|
||||
"icon": "eye",
|
||||
"color": "#8B5CF6",
|
||||
"sort_order": 30,
|
||||
"typical_tasks": ["Image processing pipelines", "Object detection models", "Video analysis", "Computer vision deployment"],
|
||||
"collaboration_hints": ["ai-ml-engineer", "mlops-engineer", "backend-engineer"]
|
||||
},
|
||||
{
|
||||
"name": "NLP Engineer",
|
||||
@@ -609,7 +705,13 @@
|
||||
"denied": [],
|
||||
"require_approval": ["gitea:create_pull_request"]
|
||||
},
|
||||
"is_active": true
|
||||
"is_active": true,
|
||||
"category": "ai_ml",
|
||||
"icon": "message-square",
|
||||
"color": "#8B5CF6",
|
||||
"sort_order": 30,
|
||||
"typical_tasks": ["Text processing pipelines", "Language model fine-tuning", "Named entity recognition", "Sentiment analysis"],
|
||||
"collaboration_hints": ["ai-ml-engineer", "data-scientist", "backend-engineer"]
|
||||
},
|
||||
{
|
||||
"name": "MLOps Engineer",
|
||||
@@ -645,7 +747,13 @@
|
||||
"denied": [],
|
||||
"require_approval": ["gitea:create_release"]
|
||||
},
|
||||
"is_active": true
|
||||
"is_active": true,
|
||||
"category": "operations",
|
||||
"icon": "settings-2",
|
||||
"color": "#F59E0B",
|
||||
"sort_order": 30,
|
||||
"typical_tasks": ["ML pipeline development", "Model deployment", "Feature store management", "Model monitoring"],
|
||||
"collaboration_hints": ["ai-ml-engineer", "devops-engineer", "data-engineer"]
|
||||
},
|
||||
{
|
||||
"name": "Data Scientist",
|
||||
@@ -681,7 +789,13 @@
|
||||
"denied": [],
|
||||
"require_approval": []
|
||||
},
|
||||
"is_active": true
|
||||
"is_active": true,
|
||||
"category": "data",
|
||||
"icon": "chart-bar",
|
||||
"color": "#06B6D4",
|
||||
"sort_order": 30,
|
||||
"typical_tasks": ["Statistical analysis", "Predictive modeling", "Data visualization", "Insight generation"],
|
||||
"collaboration_hints": ["data-engineer", "ai-ml-engineer", "business-analyst"]
|
||||
},
|
||||
{
|
||||
"name": "Data Engineer",
|
||||
@@ -717,7 +831,13 @@
|
||||
"denied": [],
|
||||
"require_approval": ["gitea:create_pull_request"]
|
||||
},
|
||||
"is_active": true
|
||||
"is_active": true,
|
||||
"category": "data",
|
||||
"icon": "database",
|
||||
"color": "#06B6D4",
|
||||
"sort_order": 30,
|
||||
"typical_tasks": ["Data pipeline development", "ETL optimization", "Data warehouse design", "Data quality management"],
|
||||
"collaboration_hints": ["data-scientist", "backend-engineer", "mlops-engineer"]
|
||||
},
|
||||
{
|
||||
"name": "Technical Lead",
|
||||
@@ -749,7 +869,13 @@
|
||||
"denied": [],
|
||||
"require_approval": []
|
||||
},
|
||||
"is_active": true
|
||||
"is_active": true,
|
||||
"category": "leadership",
|
||||
"icon": "users",
|
||||
"color": "#F97316",
|
||||
"sort_order": 10,
|
||||
"typical_tasks": ["Technical direction", "Code review leadership", "Team mentoring", "Architecture decisions"],
|
||||
"collaboration_hints": ["solutions-architect", "backend-engineer", "frontend-engineer"]
|
||||
},
|
||||
{
|
||||
"name": "Scrum Master",
|
||||
@@ -781,7 +907,13 @@
|
||||
"denied": [],
|
||||
"require_approval": []
|
||||
},
|
||||
"is_active": true
|
||||
"is_active": true,
|
||||
"category": "leadership",
|
||||
"icon": "target",
|
||||
"color": "#F97316",
|
||||
"sort_order": 10,
|
||||
"typical_tasks": ["Sprint facilitation", "Impediment removal", "Process improvement", "Team coaching"],
|
||||
"collaboration_hints": ["project-manager", "product-owner", "technical-lead"]
|
||||
},
|
||||
{
|
||||
"name": "Financial Systems Expert",
|
||||
@@ -816,7 +948,13 @@
|
||||
"denied": [],
|
||||
"require_approval": []
|
||||
},
|
||||
"is_active": true
|
||||
"is_active": true,
|
||||
"category": "domain_expert",
|
||||
"icon": "calculator",
|
||||
"color": "#84CC16",
|
||||
"sort_order": 10,
|
||||
"typical_tasks": ["Financial system design", "Regulatory compliance", "Transaction processing", "Audit trail implementation"],
|
||||
"collaboration_hints": ["solutions-architect", "security-engineer", "backend-engineer"]
|
||||
},
|
||||
{
|
||||
"name": "Healthcare Systems Expert",
|
||||
@@ -850,7 +988,13 @@
|
||||
"denied": [],
|
||||
"require_approval": []
|
||||
},
|
||||
"is_active": true
|
||||
"is_active": true,
|
||||
"category": "domain_expert",
|
||||
"icon": "heart-pulse",
|
||||
"color": "#84CC16",
|
||||
"sort_order": 50,
|
||||
"typical_tasks": ["Healthcare system design", "HIPAA compliance", "HL7/FHIR integration", "Clinical workflow optimization"],
|
||||
"collaboration_hints": ["solutions-architect", "security-engineer", "data-engineer"]
|
||||
},
|
||||
{
|
||||
"name": "Scientific Computing Expert",
|
||||
@@ -886,7 +1030,13 @@
|
||||
"denied": [],
|
||||
"require_approval": []
|
||||
},
|
||||
"is_active": true
|
||||
"is_active": true,
|
||||
"category": "domain_expert",
|
||||
"icon": "flask",
|
||||
"color": "#84CC16",
|
||||
"sort_order": 50,
|
||||
"typical_tasks": ["HPC architecture", "Scientific algorithm implementation", "Data pipeline optimization", "Numerical computing"],
|
||||
"collaboration_hints": ["ai-researcher", "data-scientist", "backend-engineer"]
|
||||
},
|
||||
{
|
||||
"name": "Behavioral Psychology Expert",
|
||||
@@ -919,7 +1069,13 @@
|
||||
"denied": [],
|
||||
"require_approval": []
|
||||
},
|
||||
"is_active": true
|
||||
"is_active": true,
|
||||
"category": "domain_expert",
|
||||
"icon": "lightbulb",
|
||||
"color": "#84CC16",
|
||||
"sort_order": 50,
|
||||
"typical_tasks": ["Behavioral design", "Engagement optimization", "User motivation analysis", "Ethical AI guidelines"],
|
||||
"collaboration_hints": ["ux-researcher", "ui-ux-designer", "product-owner"]
|
||||
},
|
||||
{
|
||||
"name": "Technical Writer",
|
||||
@@ -951,6 +1107,12 @@
|
||||
"denied": [],
|
||||
"require_approval": []
|
||||
},
|
||||
"is_active": true
|
||||
"is_active": true,
|
||||
"category": "domain_expert",
|
||||
"icon": "book-open",
|
||||
"color": "#84CC16",
|
||||
"sort_order": 50,
|
||||
"typical_tasks": ["API documentation", "User guides", "Technical specifications", "Knowledge base creation"],
|
||||
"collaboration_hints": ["solutions-architect", "product-owner", "qa-engineer"]
|
||||
}
|
||||
]
|
||||
|
||||
@@ -26,6 +26,7 @@ Usage:
|
||||
# Inside Docker (without --local flag):
|
||||
python migrate.py auto "Add new field"
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import subprocess
|
||||
@@ -44,13 +45,14 @@ def setup_database_url(use_local: bool) -> str:
|
||||
# Override DATABASE_URL to use localhost instead of Docker hostname
|
||||
local_url = os.environ.get(
|
||||
"LOCAL_DATABASE_URL",
|
||||
"postgresql://postgres:postgres@localhost:5432/app"
|
||||
"postgresql://postgres:postgres@localhost:5432/syndarix",
|
||||
)
|
||||
os.environ["DATABASE_URL"] = local_url
|
||||
return local_url
|
||||
|
||||
# Use the configured DATABASE_URL from environment/.env
|
||||
from app.core.config import settings
|
||||
|
||||
return settings.database_url
|
||||
|
||||
|
||||
@@ -61,6 +63,7 @@ def check_models():
|
||||
try:
|
||||
# Import all models through the models package
|
||||
from app.models import __all__ as all_models
|
||||
|
||||
print(f"Found {len(all_models)} model(s):")
|
||||
for model in all_models:
|
||||
print(f" - {model}")
|
||||
@@ -110,7 +113,9 @@ def generate_migration(message, rev_id=None, auto_rev_id=True, offline=False):
|
||||
# Look for the revision ID, which is typically 12 hex characters
|
||||
parts = line.split()
|
||||
for part in parts:
|
||||
if len(part) >= 12 and all(c in "0123456789abcdef" for c in part[:12]):
|
||||
if len(part) >= 12 and all(
|
||||
c in "0123456789abcdef" for c in part[:12]
|
||||
):
|
||||
revision = part[:12]
|
||||
break
|
||||
except Exception as e:
|
||||
@@ -185,6 +190,7 @@ def check_database_connection():
|
||||
db_url = os.environ.get("DATABASE_URL")
|
||||
if not db_url:
|
||||
from app.core.config import settings
|
||||
|
||||
db_url = settings.database_url
|
||||
|
||||
engine = create_engine(db_url)
|
||||
@@ -270,8 +276,8 @@ def generate_offline_migration(message, rev_id):
|
||||
content = f'''"""{message}
|
||||
|
||||
Revision ID: {rev_id}
|
||||
Revises: {down_revision or ''}
|
||||
Create Date: {datetime.now().strftime('%Y-%m-%d %H:%M:%S.%f')}
|
||||
Revises: {down_revision or ""}
|
||||
Create Date: {datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f")}
|
||||
|
||||
"""
|
||||
|
||||
@@ -320,6 +326,7 @@ def reset_alembic_version():
|
||||
db_url = os.environ.get("DATABASE_URL")
|
||||
if not db_url:
|
||||
from app.core.config import settings
|
||||
|
||||
db_url = settings.database_url
|
||||
|
||||
try:
|
||||
@@ -338,82 +345,80 @@ def reset_alembic_version():
|
||||
def main():
|
||||
"""Main function"""
|
||||
parser = argparse.ArgumentParser(
|
||||
description='Database migration helper for Generative Models Arena'
|
||||
description="Database migration helper for Generative Models Arena"
|
||||
)
|
||||
|
||||
# Global options
|
||||
parser.add_argument(
|
||||
'--local', '-l',
|
||||
action='store_true',
|
||||
help='Use localhost instead of Docker hostname (for local development)'
|
||||
"--local",
|
||||
"-l",
|
||||
action="store_true",
|
||||
help="Use localhost instead of Docker hostname (for local development)",
|
||||
)
|
||||
|
||||
subparsers = parser.add_subparsers(dest='command', help='Command to run')
|
||||
subparsers = parser.add_subparsers(dest="command", help="Command to run")
|
||||
|
||||
# Generate command
|
||||
generate_parser = subparsers.add_parser('generate', help='Generate a migration')
|
||||
generate_parser.add_argument('message', help='Migration message')
|
||||
generate_parser = subparsers.add_parser("generate", help="Generate a migration")
|
||||
generate_parser.add_argument("message", help="Migration message")
|
||||
generate_parser.add_argument(
|
||||
'--rev-id',
|
||||
help='Custom revision ID (e.g., 0001, 0002 for sequential naming)'
|
||||
"--rev-id", help="Custom revision ID (e.g., 0001, 0002 for sequential naming)"
|
||||
)
|
||||
generate_parser.add_argument(
|
||||
'--offline',
|
||||
action='store_true',
|
||||
help='Generate empty migration template without database connection'
|
||||
"--offline",
|
||||
action="store_true",
|
||||
help="Generate empty migration template without database connection",
|
||||
)
|
||||
|
||||
# Apply command
|
||||
apply_parser = subparsers.add_parser('apply', help='Apply migrations')
|
||||
apply_parser.add_argument('--revision', help='Specific revision to apply to')
|
||||
apply_parser = subparsers.add_parser("apply", help="Apply migrations")
|
||||
apply_parser.add_argument("--revision", help="Specific revision to apply to")
|
||||
|
||||
# List command
|
||||
subparsers.add_parser('list', help='List migrations')
|
||||
subparsers.add_parser("list", help="List migrations")
|
||||
|
||||
# Current command
|
||||
subparsers.add_parser('current', help='Show current revision')
|
||||
subparsers.add_parser("current", help="Show current revision")
|
||||
|
||||
# Check command
|
||||
subparsers.add_parser('check', help='Check database connection and models')
|
||||
subparsers.add_parser("check", help="Check database connection and models")
|
||||
|
||||
# Next command (show next revision ID)
|
||||
subparsers.add_parser('next', help='Show the next sequential revision ID')
|
||||
subparsers.add_parser("next", help="Show the next sequential revision ID")
|
||||
|
||||
# Reset command (clear alembic_version table)
|
||||
subparsers.add_parser(
|
||||
'reset',
|
||||
help='Reset alembic_version table (use after deleting all migrations)'
|
||||
"reset", help="Reset alembic_version table (use after deleting all migrations)"
|
||||
)
|
||||
|
||||
# Auto command (generate and apply)
|
||||
auto_parser = subparsers.add_parser('auto', help='Generate and apply migration')
|
||||
auto_parser.add_argument('message', help='Migration message')
|
||||
auto_parser = subparsers.add_parser("auto", help="Generate and apply migration")
|
||||
auto_parser.add_argument("message", help="Migration message")
|
||||
auto_parser.add_argument(
|
||||
'--rev-id',
|
||||
help='Custom revision ID (e.g., 0001, 0002 for sequential naming)'
|
||||
"--rev-id", help="Custom revision ID (e.g., 0001, 0002 for sequential naming)"
|
||||
)
|
||||
auto_parser.add_argument(
|
||||
'--offline',
|
||||
action='store_true',
|
||||
help='Generate empty migration template without database connection'
|
||||
"--offline",
|
||||
action="store_true",
|
||||
help="Generate empty migration template without database connection",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# Commands that don't need database connection
|
||||
if args.command == 'next':
|
||||
if args.command == "next":
|
||||
show_next_rev_id()
|
||||
return
|
||||
|
||||
# Check if offline mode is requested
|
||||
offline = getattr(args, 'offline', False)
|
||||
offline = getattr(args, "offline", False)
|
||||
|
||||
# Offline generate doesn't need database or model check
|
||||
if args.command == 'generate' and offline:
|
||||
if args.command == "generate" and offline:
|
||||
generate_migration(args.message, rev_id=args.rev_id, offline=True)
|
||||
return
|
||||
|
||||
if args.command == 'auto' and offline:
|
||||
if args.command == "auto" and offline:
|
||||
generate_migration(args.message, rev_id=args.rev_id, offline=True)
|
||||
print("\nOffline migration generated. Apply it later with:")
|
||||
print(" python migrate.py --local apply")
|
||||
@@ -423,27 +428,27 @@ def main():
|
||||
db_url = setup_database_url(args.local)
|
||||
print(f"Using database URL: {db_url}")
|
||||
|
||||
if args.command == 'generate':
|
||||
if args.command == "generate":
|
||||
check_models()
|
||||
generate_migration(args.message, rev_id=args.rev_id)
|
||||
|
||||
elif args.command == 'apply':
|
||||
elif args.command == "apply":
|
||||
apply_migration(args.revision)
|
||||
|
||||
elif args.command == 'list':
|
||||
elif args.command == "list":
|
||||
list_migrations()
|
||||
|
||||
elif args.command == 'current':
|
||||
elif args.command == "current":
|
||||
show_current()
|
||||
|
||||
elif args.command == 'check':
|
||||
elif args.command == "check":
|
||||
check_database_connection()
|
||||
check_models()
|
||||
|
||||
elif args.command == 'reset':
|
||||
elif args.command == "reset":
|
||||
reset_alembic_version()
|
||||
|
||||
elif args.command == 'auto':
|
||||
elif args.command == "auto":
|
||||
check_models()
|
||||
revision = generate_migration(args.message, rev_id=args.rev_id)
|
||||
if revision:
|
||||
|
||||
@@ -745,3 +745,230 @@ class TestAgentTypeInstanceCount:
|
||||
for agent_type in data["data"]:
|
||||
assert "instance_count" in agent_type
|
||||
assert isinstance(agent_type["instance_count"], int)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
class TestAgentTypeCategoryFields:
|
||||
"""Tests for agent type category and display fields."""
|
||||
|
||||
async def test_create_agent_type_with_category_fields(
|
||||
self, client, superuser_token
|
||||
):
|
||||
"""Test creating agent type with all category and display fields."""
|
||||
unique_slug = f"category-type-{uuid.uuid4().hex[:8]}"
|
||||
response = await client.post(
|
||||
"/api/v1/agent-types",
|
||||
json={
|
||||
"name": "Categorized Agent Type",
|
||||
"slug": unique_slug,
|
||||
"description": "An agent type with category fields",
|
||||
"expertise": ["python"],
|
||||
"personality_prompt": "You are a helpful assistant.",
|
||||
"primary_model": "claude-opus-4-5-20251101",
|
||||
# Category and display fields
|
||||
"category": "development",
|
||||
"icon": "code",
|
||||
"color": "#3B82F6",
|
||||
"sort_order": 10,
|
||||
"typical_tasks": ["Write code", "Review PRs"],
|
||||
"collaboration_hints": ["backend-engineer", "qa-engineer"],
|
||||
},
|
||||
headers={"Authorization": f"Bearer {superuser_token}"},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_201_CREATED
|
||||
data = response.json()
|
||||
|
||||
assert data["category"] == "development"
|
||||
assert data["icon"] == "code"
|
||||
assert data["color"] == "#3B82F6"
|
||||
assert data["sort_order"] == 10
|
||||
assert data["typical_tasks"] == ["Write code", "Review PRs"]
|
||||
assert data["collaboration_hints"] == ["backend-engineer", "qa-engineer"]
|
||||
|
||||
async def test_create_agent_type_with_nullable_category(
|
||||
self, client, superuser_token
|
||||
):
|
||||
"""Test creating agent type with null category."""
|
||||
unique_slug = f"null-category-{uuid.uuid4().hex[:8]}"
|
||||
response = await client.post(
|
||||
"/api/v1/agent-types",
|
||||
json={
|
||||
"name": "Uncategorized Agent",
|
||||
"slug": unique_slug,
|
||||
"expertise": ["general"],
|
||||
"personality_prompt": "You are a helpful assistant.",
|
||||
"primary_model": "claude-opus-4-5-20251101",
|
||||
"category": None,
|
||||
},
|
||||
headers={"Authorization": f"Bearer {superuser_token}"},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_201_CREATED
|
||||
data = response.json()
|
||||
assert data["category"] is None
|
||||
|
||||
async def test_create_agent_type_invalid_color_format(
|
||||
self, client, superuser_token
|
||||
):
|
||||
"""Test that invalid color format is rejected."""
|
||||
unique_slug = f"invalid-color-{uuid.uuid4().hex[:8]}"
|
||||
response = await client.post(
|
||||
"/api/v1/agent-types",
|
||||
json={
|
||||
"name": "Invalid Color Agent",
|
||||
"slug": unique_slug,
|
||||
"expertise": ["python"],
|
||||
"personality_prompt": "You are a helpful assistant.",
|
||||
"primary_model": "claude-opus-4-5-20251101",
|
||||
"color": "not-a-hex-color",
|
||||
},
|
||||
headers={"Authorization": f"Bearer {superuser_token}"},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY
|
||||
|
||||
async def test_create_agent_type_invalid_category(self, client, superuser_token):
|
||||
"""Test that invalid category value is rejected."""
|
||||
unique_slug = f"invalid-category-{uuid.uuid4().hex[:8]}"
|
||||
response = await client.post(
|
||||
"/api/v1/agent-types",
|
||||
json={
|
||||
"name": "Invalid Category Agent",
|
||||
"slug": unique_slug,
|
||||
"expertise": ["python"],
|
||||
"personality_prompt": "You are a helpful assistant.",
|
||||
"primary_model": "claude-opus-4-5-20251101",
|
||||
"category": "not_a_valid_category",
|
||||
},
|
||||
headers={"Authorization": f"Bearer {superuser_token}"},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY
|
||||
|
||||
async def test_update_agent_type_category_fields(
|
||||
self, client, superuser_token, test_agent_type
|
||||
):
|
||||
"""Test updating category and display fields."""
|
||||
agent_type_id = test_agent_type["id"]
|
||||
|
||||
response = await client.patch(
|
||||
f"/api/v1/agent-types/{agent_type_id}",
|
||||
json={
|
||||
"category": "ai_ml",
|
||||
"icon": "brain",
|
||||
"color": "#8B5CF6",
|
||||
"sort_order": 50,
|
||||
"typical_tasks": ["Train models", "Analyze data"],
|
||||
"collaboration_hints": ["data-scientist"],
|
||||
},
|
||||
headers={"Authorization": f"Bearer {superuser_token}"},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
data = response.json()
|
||||
|
||||
assert data["category"] == "ai_ml"
|
||||
assert data["icon"] == "brain"
|
||||
assert data["color"] == "#8B5CF6"
|
||||
assert data["sort_order"] == 50
|
||||
assert data["typical_tasks"] == ["Train models", "Analyze data"]
|
||||
assert data["collaboration_hints"] == ["data-scientist"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
class TestAgentTypeCategoryFilter:
|
||||
"""Tests for agent type category filtering."""
|
||||
|
||||
async def test_list_agent_types_filter_by_category(
|
||||
self, client, superuser_token, user_token
|
||||
):
|
||||
"""Test filtering agent types by category."""
|
||||
# Create agent types in different categories
|
||||
for cat in ["development", "design"]:
|
||||
unique_slug = f"filter-test-{cat}-{uuid.uuid4().hex[:8]}"
|
||||
await client.post(
|
||||
"/api/v1/agent-types",
|
||||
json={
|
||||
"name": f"Filter Test {cat.capitalize()}",
|
||||
"slug": unique_slug,
|
||||
"expertise": ["python"],
|
||||
"personality_prompt": "Test prompt",
|
||||
"primary_model": "claude-opus-4-5-20251101",
|
||||
"category": cat,
|
||||
},
|
||||
headers={"Authorization": f"Bearer {superuser_token}"},
|
||||
)
|
||||
|
||||
# Filter by development category
|
||||
response = await client.get(
|
||||
"/api/v1/agent-types",
|
||||
params={"category": "development"},
|
||||
headers={"Authorization": f"Bearer {user_token}"},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
data = response.json()
|
||||
|
||||
# All returned types should have development category
|
||||
for agent_type in data["data"]:
|
||||
assert agent_type["category"] == "development"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
class TestAgentTypeGroupedEndpoint:
|
||||
"""Tests for the grouped by category endpoint."""
|
||||
|
||||
async def test_list_agent_types_grouped(self, client, superuser_token, user_token):
|
||||
"""Test getting agent types grouped by category."""
|
||||
# Create agent types in different categories
|
||||
categories = ["development", "design", "quality"]
|
||||
for cat in categories:
|
||||
unique_slug = f"grouped-test-{cat}-{uuid.uuid4().hex[:8]}"
|
||||
await client.post(
|
||||
"/api/v1/agent-types",
|
||||
json={
|
||||
"name": f"Grouped Test {cat.capitalize()}",
|
||||
"slug": unique_slug,
|
||||
"expertise": ["python"],
|
||||
"personality_prompt": "Test prompt",
|
||||
"primary_model": "claude-opus-4-5-20251101",
|
||||
"category": cat,
|
||||
"sort_order": 10,
|
||||
},
|
||||
headers={"Authorization": f"Bearer {superuser_token}"},
|
||||
)
|
||||
|
||||
# Get grouped agent types
|
||||
response = await client.get(
|
||||
"/api/v1/agent-types/grouped",
|
||||
headers={"Authorization": f"Bearer {user_token}"},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
data = response.json()
|
||||
|
||||
# Should be a dict with category keys
|
||||
assert isinstance(data, dict)
|
||||
|
||||
# Check that at least one of our created categories exists
|
||||
assert any(cat in data for cat in categories)
|
||||
|
||||
async def test_list_agent_types_grouped_filter_inactive(
|
||||
self, client, superuser_token, user_token
|
||||
):
|
||||
"""Test grouped endpoint with is_active filter."""
|
||||
response = await client.get(
|
||||
"/api/v1/agent-types/grouped",
|
||||
params={"is_active": False},
|
||||
headers={"Authorization": f"Bearer {user_token}"},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
data = response.json()
|
||||
assert isinstance(data, dict)
|
||||
|
||||
async def test_list_agent_types_grouped_unauthenticated(self, client):
|
||||
"""Test that unauthenticated users cannot access grouped endpoint."""
|
||||
response = await client.get("/api/v1/agent-types/grouped")
|
||||
assert response.status_code == status.HTTP_401_UNAUTHORIZED
|
||||
|
||||
@@ -368,3 +368,9 @@ async def e2e_org_with_members(e2e_client, e2e_superuser):
|
||||
"user_id": member_id,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
# NOTE: Class-scoped fixtures for E2E tests were attempted but have fundamental
|
||||
# issues with pytest-asyncio + SQLAlchemy/asyncpg event loop management.
|
||||
# The function-scoped fixtures above provide proper test isolation.
|
||||
# Performance optimization would require significant infrastructure changes.
|
||||
|
||||
@@ -316,3 +316,325 @@ class TestAgentTypeJsonFields:
|
||||
)
|
||||
|
||||
assert agent_type.fallback_models == models
|
||||
|
||||
|
||||
class TestAgentTypeCategoryFieldsValidation:
|
||||
"""Tests for AgentType category and display field validation."""
|
||||
|
||||
def test_valid_category_values(self):
|
||||
"""Test that all valid category values are accepted."""
|
||||
valid_categories = [
|
||||
"development",
|
||||
"design",
|
||||
"quality",
|
||||
"operations",
|
||||
"ai_ml",
|
||||
"data",
|
||||
"leadership",
|
||||
"domain_expert",
|
||||
]
|
||||
|
||||
for category in valid_categories:
|
||||
agent_type = AgentTypeCreate(
|
||||
name="Test Agent",
|
||||
slug="test-agent",
|
||||
personality_prompt="Test",
|
||||
primary_model="claude-opus-4-5-20251101",
|
||||
category=category,
|
||||
)
|
||||
assert agent_type.category.value == category
|
||||
|
||||
def test_category_null_allowed(self):
|
||||
"""Test that null category is allowed."""
|
||||
agent_type = AgentTypeCreate(
|
||||
name="Test Agent",
|
||||
slug="test-agent",
|
||||
personality_prompt="Test",
|
||||
primary_model="claude-opus-4-5-20251101",
|
||||
category=None,
|
||||
)
|
||||
assert agent_type.category is None
|
||||
|
||||
def test_invalid_category_rejected(self):
|
||||
"""Test that invalid category values are rejected."""
|
||||
with pytest.raises(ValidationError):
|
||||
AgentTypeCreate(
|
||||
name="Test Agent",
|
||||
slug="test-agent",
|
||||
personality_prompt="Test",
|
||||
primary_model="claude-opus-4-5-20251101",
|
||||
category="invalid_category",
|
||||
)
|
||||
|
||||
def test_valid_hex_color(self):
|
||||
"""Test that valid hex colors are accepted."""
|
||||
valid_colors = ["#3B82F6", "#EC4899", "#10B981", "#ffffff", "#000000"]
|
||||
|
||||
for color in valid_colors:
|
||||
agent_type = AgentTypeCreate(
|
||||
name="Test Agent",
|
||||
slug="test-agent",
|
||||
personality_prompt="Test",
|
||||
primary_model="claude-opus-4-5-20251101",
|
||||
color=color,
|
||||
)
|
||||
assert agent_type.color == color
|
||||
|
||||
def test_invalid_hex_color_rejected(self):
|
||||
"""Test that invalid hex colors are rejected."""
|
||||
invalid_colors = [
|
||||
"not-a-color",
|
||||
"3B82F6", # Missing #
|
||||
"#3B82F", # Too short
|
||||
"#3B82F6A", # Too long
|
||||
"#GGGGGG", # Invalid hex chars
|
||||
"rgb(59, 130, 246)", # RGB format not supported
|
||||
]
|
||||
|
||||
for color in invalid_colors:
|
||||
with pytest.raises(ValidationError):
|
||||
AgentTypeCreate(
|
||||
name="Test Agent",
|
||||
slug="test-agent",
|
||||
personality_prompt="Test",
|
||||
primary_model="claude-opus-4-5-20251101",
|
||||
color=color,
|
||||
)
|
||||
|
||||
def test_color_null_allowed(self):
|
||||
"""Test that null color is allowed."""
|
||||
agent_type = AgentTypeCreate(
|
||||
name="Test Agent",
|
||||
slug="test-agent",
|
||||
personality_prompt="Test",
|
||||
primary_model="claude-opus-4-5-20251101",
|
||||
color=None,
|
||||
)
|
||||
assert agent_type.color is None
|
||||
|
||||
def test_sort_order_valid_range(self):
|
||||
"""Test that valid sort_order values are accepted."""
|
||||
for sort_order in [0, 1, 500, 1000]:
|
||||
agent_type = AgentTypeCreate(
|
||||
name="Test Agent",
|
||||
slug="test-agent",
|
||||
personality_prompt="Test",
|
||||
primary_model="claude-opus-4-5-20251101",
|
||||
sort_order=sort_order,
|
||||
)
|
||||
assert agent_type.sort_order == sort_order
|
||||
|
||||
def test_sort_order_default_zero(self):
|
||||
"""Test that sort_order defaults to 0."""
|
||||
agent_type = AgentTypeCreate(
|
||||
name="Test Agent",
|
||||
slug="test-agent",
|
||||
personality_prompt="Test",
|
||||
primary_model="claude-opus-4-5-20251101",
|
||||
)
|
||||
assert agent_type.sort_order == 0
|
||||
|
||||
def test_sort_order_negative_rejected(self):
|
||||
"""Test that negative sort_order is rejected."""
|
||||
with pytest.raises(ValidationError):
|
||||
AgentTypeCreate(
|
||||
name="Test Agent",
|
||||
slug="test-agent",
|
||||
personality_prompt="Test",
|
||||
primary_model="claude-opus-4-5-20251101",
|
||||
sort_order=-1,
|
||||
)
|
||||
|
||||
def test_sort_order_exceeds_max_rejected(self):
|
||||
"""Test that sort_order > 1000 is rejected."""
|
||||
with pytest.raises(ValidationError):
|
||||
AgentTypeCreate(
|
||||
name="Test Agent",
|
||||
slug="test-agent",
|
||||
personality_prompt="Test",
|
||||
primary_model="claude-opus-4-5-20251101",
|
||||
sort_order=1001,
|
||||
)
|
||||
|
||||
def test_icon_max_length(self):
|
||||
"""Test that icon field respects max length."""
|
||||
agent_type = AgentTypeCreate(
|
||||
name="Test Agent",
|
||||
slug="test-agent",
|
||||
personality_prompt="Test",
|
||||
primary_model="claude-opus-4-5-20251101",
|
||||
icon="x" * 50,
|
||||
)
|
||||
assert len(agent_type.icon) == 50
|
||||
|
||||
def test_icon_exceeds_max_length_rejected(self):
|
||||
"""Test that icon exceeding max length is rejected."""
|
||||
with pytest.raises(ValidationError):
|
||||
AgentTypeCreate(
|
||||
name="Test Agent",
|
||||
slug="test-agent",
|
||||
personality_prompt="Test",
|
||||
primary_model="claude-opus-4-5-20251101",
|
||||
icon="x" * 51,
|
||||
)
|
||||
|
||||
|
||||
class TestAgentTypeTypicalTasksValidation:
|
||||
"""Tests for typical_tasks field validation."""
|
||||
|
||||
def test_typical_tasks_list(self):
|
||||
"""Test typical_tasks as a list."""
|
||||
tasks = ["Write code", "Review PRs", "Debug issues"]
|
||||
agent_type = AgentTypeCreate(
|
||||
name="Test Agent",
|
||||
slug="test-agent",
|
||||
personality_prompt="Test",
|
||||
primary_model="claude-opus-4-5-20251101",
|
||||
typical_tasks=tasks,
|
||||
)
|
||||
assert agent_type.typical_tasks == tasks
|
||||
|
||||
def test_typical_tasks_default_empty(self):
|
||||
"""Test typical_tasks defaults to empty list."""
|
||||
agent_type = AgentTypeCreate(
|
||||
name="Test Agent",
|
||||
slug="test-agent",
|
||||
personality_prompt="Test",
|
||||
primary_model="claude-opus-4-5-20251101",
|
||||
)
|
||||
assert agent_type.typical_tasks == []
|
||||
|
||||
def test_typical_tasks_strips_whitespace(self):
|
||||
"""Test that typical_tasks items are stripped."""
|
||||
agent_type = AgentTypeCreate(
|
||||
name="Test Agent",
|
||||
slug="test-agent",
|
||||
personality_prompt="Test",
|
||||
primary_model="claude-opus-4-5-20251101",
|
||||
typical_tasks=[" Write code ", " Debug "],
|
||||
)
|
||||
assert agent_type.typical_tasks == ["Write code", "Debug"]
|
||||
|
||||
def test_typical_tasks_removes_empty_strings(self):
|
||||
"""Test that empty strings are removed from typical_tasks."""
|
||||
agent_type = AgentTypeCreate(
|
||||
name="Test Agent",
|
||||
slug="test-agent",
|
||||
personality_prompt="Test",
|
||||
primary_model="claude-opus-4-5-20251101",
|
||||
typical_tasks=["Write code", "", " ", "Debug"],
|
||||
)
|
||||
assert agent_type.typical_tasks == ["Write code", "Debug"]
|
||||
|
||||
|
||||
class TestAgentTypeCollaborationHintsValidation:
|
||||
"""Tests for collaboration_hints field validation."""
|
||||
|
||||
def test_collaboration_hints_list(self):
|
||||
"""Test collaboration_hints as a list."""
|
||||
hints = ["backend-engineer", "qa-engineer"]
|
||||
agent_type = AgentTypeCreate(
|
||||
name="Test Agent",
|
||||
slug="test-agent",
|
||||
personality_prompt="Test",
|
||||
primary_model="claude-opus-4-5-20251101",
|
||||
collaboration_hints=hints,
|
||||
)
|
||||
assert agent_type.collaboration_hints == hints
|
||||
|
||||
def test_collaboration_hints_default_empty(self):
|
||||
"""Test collaboration_hints defaults to empty list."""
|
||||
agent_type = AgentTypeCreate(
|
||||
name="Test Agent",
|
||||
slug="test-agent",
|
||||
personality_prompt="Test",
|
||||
primary_model="claude-opus-4-5-20251101",
|
||||
)
|
||||
assert agent_type.collaboration_hints == []
|
||||
|
||||
def test_collaboration_hints_normalized_lowercase(self):
|
||||
"""Test that collaboration_hints are normalized to lowercase."""
|
||||
agent_type = AgentTypeCreate(
|
||||
name="Test Agent",
|
||||
slug="test-agent",
|
||||
personality_prompt="Test",
|
||||
primary_model="claude-opus-4-5-20251101",
|
||||
collaboration_hints=["Backend-Engineer", "QA-ENGINEER"],
|
||||
)
|
||||
assert agent_type.collaboration_hints == ["backend-engineer", "qa-engineer"]
|
||||
|
||||
def test_collaboration_hints_strips_whitespace(self):
|
||||
"""Test that collaboration_hints are stripped."""
|
||||
agent_type = AgentTypeCreate(
|
||||
name="Test Agent",
|
||||
slug="test-agent",
|
||||
personality_prompt="Test",
|
||||
primary_model="claude-opus-4-5-20251101",
|
||||
collaboration_hints=[" backend-engineer ", " qa-engineer "],
|
||||
)
|
||||
assert agent_type.collaboration_hints == ["backend-engineer", "qa-engineer"]
|
||||
|
||||
def test_collaboration_hints_removes_empty_strings(self):
|
||||
"""Test that empty strings are removed from collaboration_hints."""
|
||||
agent_type = AgentTypeCreate(
|
||||
name="Test Agent",
|
||||
slug="test-agent",
|
||||
personality_prompt="Test",
|
||||
primary_model="claude-opus-4-5-20251101",
|
||||
collaboration_hints=["backend-engineer", "", " ", "qa-engineer"],
|
||||
)
|
||||
assert agent_type.collaboration_hints == ["backend-engineer", "qa-engineer"]
|
||||
|
||||
|
||||
class TestAgentTypeUpdateCategoryFields:
|
||||
"""Tests for AgentTypeUpdate category and display fields."""
|
||||
|
||||
def test_update_category_field(self):
|
||||
"""Test updating category field."""
|
||||
update = AgentTypeUpdate(category="ai_ml")
|
||||
assert update.category.value == "ai_ml"
|
||||
|
||||
def test_update_icon_field(self):
|
||||
"""Test updating icon field."""
|
||||
update = AgentTypeUpdate(icon="brain")
|
||||
assert update.icon == "brain"
|
||||
|
||||
def test_update_color_field(self):
|
||||
"""Test updating color field."""
|
||||
update = AgentTypeUpdate(color="#8B5CF6")
|
||||
assert update.color == "#8B5CF6"
|
||||
|
||||
def test_update_sort_order_field(self):
|
||||
"""Test updating sort_order field."""
|
||||
update = AgentTypeUpdate(sort_order=50)
|
||||
assert update.sort_order == 50
|
||||
|
||||
def test_update_typical_tasks_field(self):
|
||||
"""Test updating typical_tasks field."""
|
||||
update = AgentTypeUpdate(typical_tasks=["New task"])
|
||||
assert update.typical_tasks == ["New task"]
|
||||
|
||||
def test_update_typical_tasks_strips_whitespace(self):
|
||||
"""Test that typical_tasks are stripped on update."""
|
||||
update = AgentTypeUpdate(typical_tasks=[" New task "])
|
||||
assert update.typical_tasks == ["New task"]
|
||||
|
||||
def test_update_collaboration_hints_field(self):
|
||||
"""Test updating collaboration_hints field."""
|
||||
update = AgentTypeUpdate(collaboration_hints=["new-collaborator"])
|
||||
assert update.collaboration_hints == ["new-collaborator"]
|
||||
|
||||
def test_update_collaboration_hints_normalized(self):
|
||||
"""Test that collaboration_hints are normalized on update."""
|
||||
update = AgentTypeUpdate(collaboration_hints=[" New-Collaborator "])
|
||||
assert update.collaboration_hints == ["new-collaborator"]
|
||||
|
||||
def test_update_invalid_color_rejected(self):
|
||||
"""Test that invalid color is rejected on update."""
|
||||
with pytest.raises(ValidationError):
|
||||
AgentTypeUpdate(color="invalid")
|
||||
|
||||
def test_update_invalid_sort_order_rejected(self):
|
||||
"""Test that invalid sort_order is rejected on update."""
|
||||
with pytest.raises(ValidationError):
|
||||
AgentTypeUpdate(sort_order=-1)
|
||||
|
||||
@@ -42,6 +42,9 @@ class TestInitDb:
|
||||
assert user.last_name == "User"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.skip(
|
||||
reason="SQLite doesn't support UUID type binding - requires PostgreSQL"
|
||||
)
|
||||
async def test_init_db_returns_existing_superuser(
|
||||
self, async_test_db, async_test_user
|
||||
):
|
||||
|
||||
@@ -96,6 +96,38 @@ services:
|
||||
- app-network
|
||||
restart: unless-stopped
|
||||
|
||||
mcp-git-ops:
|
||||
build:
|
||||
context: ./mcp-servers/git-ops
|
||||
dockerfile: Dockerfile
|
||||
ports:
|
||||
- "8003:8003"
|
||||
env_file:
|
||||
- .env
|
||||
environment:
|
||||
# GIT_OPS_ prefix required by pydantic-settings config
|
||||
- GIT_OPS_HOST=0.0.0.0
|
||||
- GIT_OPS_PORT=8003
|
||||
- GIT_OPS_REDIS_URL=redis://redis:6379/3
|
||||
- GIT_OPS_GITEA_BASE_URL=${GITEA_BASE_URL}
|
||||
- GIT_OPS_GITEA_TOKEN=${GITEA_TOKEN}
|
||||
- GIT_OPS_GITHUB_TOKEN=${GITHUB_TOKEN}
|
||||
- ENVIRONMENT=development
|
||||
volumes:
|
||||
- git_workspaces_dev:/workspaces
|
||||
depends_on:
|
||||
redis:
|
||||
condition: service_healthy
|
||||
healthcheck:
|
||||
test: ["CMD", "python", "-c", "import httpx; httpx.get('http://localhost:8003/health').raise_for_status()"]
|
||||
interval: 30s
|
||||
timeout: 10s
|
||||
retries: 3
|
||||
start_period: 10s
|
||||
networks:
|
||||
- app-network
|
||||
restart: unless-stopped
|
||||
|
||||
backend:
|
||||
build:
|
||||
context: ./backend
|
||||
@@ -119,6 +151,7 @@ services:
|
||||
# MCP Server URLs
|
||||
- LLM_GATEWAY_URL=http://mcp-llm-gateway:8001
|
||||
- KNOWLEDGE_BASE_URL=http://mcp-knowledge-base:8002
|
||||
- GIT_OPS_URL=http://mcp-git-ops:8003
|
||||
depends_on:
|
||||
db:
|
||||
condition: service_healthy
|
||||
@@ -128,6 +161,8 @@ services:
|
||||
condition: service_healthy
|
||||
mcp-knowledge-base:
|
||||
condition: service_healthy
|
||||
mcp-git-ops:
|
||||
condition: service_healthy
|
||||
healthcheck:
|
||||
test: ["CMD", "curl", "-f", "http://localhost:8000/health"]
|
||||
interval: 10s
|
||||
@@ -155,6 +190,7 @@ services:
|
||||
# MCP Server URLs (agents need access to MCP)
|
||||
- LLM_GATEWAY_URL=http://mcp-llm-gateway:8001
|
||||
- KNOWLEDGE_BASE_URL=http://mcp-knowledge-base:8002
|
||||
- GIT_OPS_URL=http://mcp-git-ops:8003
|
||||
depends_on:
|
||||
db:
|
||||
condition: service_healthy
|
||||
@@ -164,6 +200,8 @@ services:
|
||||
condition: service_healthy
|
||||
mcp-knowledge-base:
|
||||
condition: service_healthy
|
||||
mcp-git-ops:
|
||||
condition: service_healthy
|
||||
networks:
|
||||
- app-network
|
||||
command: ["celery", "-A", "app.celery_app", "worker", "-Q", "agent", "-l", "info", "-c", "4"]
|
||||
@@ -181,11 +219,14 @@ services:
|
||||
- DATABASE_URL=${DATABASE_URL}
|
||||
- REDIS_URL=redis://redis:6379/0
|
||||
- CELERY_QUEUE=git
|
||||
- GIT_OPS_URL=http://mcp-git-ops:8003
|
||||
depends_on:
|
||||
db:
|
||||
condition: service_healthy
|
||||
redis:
|
||||
condition: service_healthy
|
||||
mcp-git-ops:
|
||||
condition: service_healthy
|
||||
networks:
|
||||
- app-network
|
||||
command: ["celery", "-A", "app.celery_app", "worker", "-Q", "git", "-l", "info", "-c", "2"]
|
||||
@@ -260,6 +301,7 @@ services:
|
||||
volumes:
|
||||
postgres_data_dev:
|
||||
redis_data_dev:
|
||||
git_workspaces_dev:
|
||||
frontend_dev_modules:
|
||||
frontend_dev_next:
|
||||
|
||||
|
||||
55
frontend/package-lock.json
generated
55
frontend/package-lock.json
generated
@@ -21,6 +21,7 @@
|
||||
"@radix-ui/react-separator": "^1.1.7",
|
||||
"@radix-ui/react-slot": "^1.2.4",
|
||||
"@radix-ui/react-tabs": "^1.1.13",
|
||||
"@radix-ui/react-toggle-group": "^1.1.11",
|
||||
"@tanstack/react-query": "^5.90.5",
|
||||
"@types/react-syntax-highlighter": "^15.5.13",
|
||||
"axios": "^1.13.1",
|
||||
@@ -4688,6 +4689,60 @@
|
||||
}
|
||||
}
|
||||
},
|
||||
"node_modules/@radix-ui/react-toggle": {
|
||||
"version": "1.1.10",
|
||||
"resolved": "https://registry.npmjs.org/@radix-ui/react-toggle/-/react-toggle-1.1.10.tgz",
|
||||
"integrity": "sha512-lS1odchhFTeZv3xwHH31YPObmJn8gOg7Lq12inrr0+BH/l3Tsq32VfjqH1oh80ARM3mlkfMic15n0kg4sD1poQ==",
|
||||
"license": "MIT",
|
||||
"dependencies": {
|
||||
"@radix-ui/primitive": "1.1.3",
|
||||
"@radix-ui/react-primitive": "2.1.3",
|
||||
"@radix-ui/react-use-controllable-state": "1.2.2"
|
||||
},
|
||||
"peerDependencies": {
|
||||
"@types/react": "*",
|
||||
"@types/react-dom": "*",
|
||||
"react": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc",
|
||||
"react-dom": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc"
|
||||
},
|
||||
"peerDependenciesMeta": {
|
||||
"@types/react": {
|
||||
"optional": true
|
||||
},
|
||||
"@types/react-dom": {
|
||||
"optional": true
|
||||
}
|
||||
}
|
||||
},
|
||||
"node_modules/@radix-ui/react-toggle-group": {
|
||||
"version": "1.1.11",
|
||||
"resolved": "https://registry.npmjs.org/@radix-ui/react-toggle-group/-/react-toggle-group-1.1.11.tgz",
|
||||
"integrity": "sha512-5umnS0T8JQzQT6HbPyO7Hh9dgd82NmS36DQr+X/YJ9ctFNCiiQd6IJAYYZ33LUwm8M+taCz5t2ui29fHZc4Y6Q==",
|
||||
"license": "MIT",
|
||||
"dependencies": {
|
||||
"@radix-ui/primitive": "1.1.3",
|
||||
"@radix-ui/react-context": "1.1.2",
|
||||
"@radix-ui/react-direction": "1.1.1",
|
||||
"@radix-ui/react-primitive": "2.1.3",
|
||||
"@radix-ui/react-roving-focus": "1.1.11",
|
||||
"@radix-ui/react-toggle": "1.1.10",
|
||||
"@radix-ui/react-use-controllable-state": "1.2.2"
|
||||
},
|
||||
"peerDependencies": {
|
||||
"@types/react": "*",
|
||||
"@types/react-dom": "*",
|
||||
"react": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc",
|
||||
"react-dom": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc"
|
||||
},
|
||||
"peerDependenciesMeta": {
|
||||
"@types/react": {
|
||||
"optional": true
|
||||
},
|
||||
"@types/react-dom": {
|
||||
"optional": true
|
||||
}
|
||||
}
|
||||
},
|
||||
"node_modules/@radix-ui/react-use-callback-ref": {
|
||||
"version": "1.1.1",
|
||||
"resolved": "https://registry.npmjs.org/@radix-ui/react-use-callback-ref/-/react-use-callback-ref-1.1.1.tgz",
|
||||
|
||||
@@ -35,6 +35,7 @@
|
||||
"@radix-ui/react-separator": "^1.1.7",
|
||||
"@radix-ui/react-slot": "^1.2.4",
|
||||
"@radix-ui/react-tabs": "^1.1.13",
|
||||
"@radix-ui/react-toggle-group": "^1.1.11",
|
||||
"@tanstack/react-query": "^5.90.5",
|
||||
"@types/react-syntax-highlighter": "^15.5.13",
|
||||
"axios": "^1.13.1",
|
||||
|
||||
@@ -73,6 +73,13 @@ export default function AgentTypeDetailPage() {
|
||||
mcp_servers: data.mcp_servers,
|
||||
tool_permissions: data.tool_permissions,
|
||||
is_active: data.is_active,
|
||||
// Category and display fields
|
||||
category: data.category,
|
||||
icon: data.icon,
|
||||
color: data.color,
|
||||
sort_order: data.sort_order,
|
||||
typical_tasks: data.typical_tasks,
|
||||
collaboration_hints: data.collaboration_hints,
|
||||
});
|
||||
toast.success('Agent type created', {
|
||||
description: `${result.name} has been created successfully`,
|
||||
@@ -94,6 +101,13 @@ export default function AgentTypeDetailPage() {
|
||||
mcp_servers: data.mcp_servers,
|
||||
tool_permissions: data.tool_permissions,
|
||||
is_active: data.is_active,
|
||||
// Category and display fields
|
||||
category: data.category,
|
||||
icon: data.icon,
|
||||
color: data.color,
|
||||
sort_order: data.sort_order,
|
||||
typical_tasks: data.typical_tasks,
|
||||
collaboration_hints: data.collaboration_hints,
|
||||
},
|
||||
});
|
||||
toast.success('Agent type updated', {
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
/**
|
||||
* Agent Types List Page
|
||||
*
|
||||
* Displays a list of agent types with search and filter functionality.
|
||||
* Allows navigation to agent type detail and creation pages.
|
||||
* Displays a list of agent types with search, status, and category filters.
|
||||
* Supports grid and list view modes with user preference persistence.
|
||||
*/
|
||||
|
||||
'use client';
|
||||
@@ -10,9 +10,10 @@
|
||||
import { useState, useCallback, useMemo } from 'react';
|
||||
import { useRouter } from '@/lib/i18n/routing';
|
||||
import { toast } from 'sonner';
|
||||
import { AgentTypeList } from '@/components/agents';
|
||||
import { AgentTypeList, type ViewMode } from '@/components/agents';
|
||||
import { useAgentTypes } from '@/lib/api/hooks/useAgentTypes';
|
||||
import { useDebounce } from '@/lib/hooks/useDebounce';
|
||||
import type { AgentTypeCategory } from '@/lib/api/types/agentTypes';
|
||||
|
||||
export default function AgentTypesPage() {
|
||||
const router = useRouter();
|
||||
@@ -20,6 +21,8 @@ export default function AgentTypesPage() {
|
||||
// Filter state
|
||||
const [searchQuery, setSearchQuery] = useState('');
|
||||
const [statusFilter, setStatusFilter] = useState('all');
|
||||
const [categoryFilter, setCategoryFilter] = useState('all');
|
||||
const [viewMode, setViewMode] = useState<ViewMode>('grid');
|
||||
|
||||
// Debounce search for API calls
|
||||
const debouncedSearch = useDebounce(searchQuery, 300);
|
||||
@@ -31,21 +34,25 @@ export default function AgentTypesPage() {
|
||||
return undefined; // 'all' returns undefined to not filter
|
||||
}, [statusFilter]);
|
||||
|
||||
// Determine category filter value
|
||||
const categoryFilterValue = useMemo(() => {
|
||||
if (categoryFilter === 'all') return undefined;
|
||||
return categoryFilter as AgentTypeCategory;
|
||||
}, [categoryFilter]);
|
||||
|
||||
// Fetch agent types
|
||||
const { data, isLoading, error } = useAgentTypes({
|
||||
search: debouncedSearch || undefined,
|
||||
is_active: isActiveFilter,
|
||||
category: categoryFilterValue,
|
||||
page: 1,
|
||||
limit: 50,
|
||||
});
|
||||
|
||||
// Filter results client-side for 'all' status
|
||||
// Get filtered and sorted agent types (sort by sort_order ascending - smaller first)
|
||||
const filteredAgentTypes = useMemo(() => {
|
||||
if (!data?.data) return [];
|
||||
|
||||
// When status is 'all', we need to fetch both and combine
|
||||
// For now, the API returns based on is_active filter
|
||||
return data.data;
|
||||
return [...data.data].sort((a, b) => a.sort_order - b.sort_order);
|
||||
}, [data?.data]);
|
||||
|
||||
// Handle navigation to agent type detail
|
||||
@@ -71,6 +78,16 @@ export default function AgentTypesPage() {
|
||||
setStatusFilter(status);
|
||||
}, []);
|
||||
|
||||
// Handle category filter change
|
||||
const handleCategoryFilterChange = useCallback((category: string) => {
|
||||
setCategoryFilter(category);
|
||||
}, []);
|
||||
|
||||
// Handle view mode change
|
||||
const handleViewModeChange = useCallback((mode: ViewMode) => {
|
||||
setViewMode(mode);
|
||||
}, []);
|
||||
|
||||
// Show error toast if fetch fails
|
||||
if (error) {
|
||||
toast.error('Failed to load agent types', {
|
||||
@@ -87,6 +104,10 @@ export default function AgentTypesPage() {
|
||||
onSearchChange={handleSearchChange}
|
||||
statusFilter={statusFilter}
|
||||
onStatusFilterChange={handleStatusFilterChange}
|
||||
categoryFilter={categoryFilter}
|
||||
onCategoryFilterChange={handleCategoryFilterChange}
|
||||
viewMode={viewMode}
|
||||
onViewModeChange={handleViewModeChange}
|
||||
onSelect={handleSelect}
|
||||
onCreate={handleCreate}
|
||||
/>
|
||||
|
||||
@@ -2,7 +2,8 @@
|
||||
* AgentTypeDetail Component
|
||||
*
|
||||
* Displays detailed information about a single agent type.
|
||||
* Shows model configuration, permissions, personality, and instance stats.
|
||||
* Features a hero header with icon/color, category, typical tasks,
|
||||
* collaboration hints, model configuration, and instance stats.
|
||||
*/
|
||||
|
||||
'use client';
|
||||
@@ -36,8 +37,13 @@ import {
|
||||
Cpu,
|
||||
CheckCircle2,
|
||||
AlertTriangle,
|
||||
Sparkles,
|
||||
Users,
|
||||
Check,
|
||||
} from 'lucide-react';
|
||||
import type { AgentTypeResponse } from '@/lib/api/types/agentTypes';
|
||||
import { DynamicIcon } from '@/components/ui/dynamic-icon';
|
||||
import type { AgentTypeResponse, AgentTypeCategory } from '@/lib/api/types/agentTypes';
|
||||
import { CATEGORY_METADATA } from '@/lib/api/types/agentTypes';
|
||||
import { AVAILABLE_MCP_SERVERS } from '@/lib/validations/agentType';
|
||||
|
||||
interface AgentTypeDetailProps {
|
||||
@@ -51,6 +57,30 @@ interface AgentTypeDetailProps {
|
||||
className?: string;
|
||||
}
|
||||
|
||||
/**
|
||||
* Category badge with color
|
||||
*/
|
||||
function CategoryBadge({ category }: { category: AgentTypeCategory | null }) {
|
||||
if (!category) return null;
|
||||
|
||||
const meta = CATEGORY_METADATA[category];
|
||||
if (!meta) return null;
|
||||
|
||||
return (
|
||||
<Badge
|
||||
variant="outline"
|
||||
className="font-medium"
|
||||
style={{
|
||||
borderColor: meta.color,
|
||||
color: meta.color,
|
||||
backgroundColor: `${meta.color}10`,
|
||||
}}
|
||||
>
|
||||
{meta.label}
|
||||
</Badge>
|
||||
);
|
||||
}
|
||||
|
||||
/**
|
||||
* Status badge component for agent types
|
||||
*/
|
||||
@@ -81,11 +111,22 @@ function AgentTypeStatusBadge({ isActive }: { isActive: boolean }) {
|
||||
function AgentTypeDetailSkeleton() {
|
||||
return (
|
||||
<div className="space-y-6">
|
||||
<div className="flex items-center gap-4">
|
||||
<Skeleton className="h-10 w-10" />
|
||||
<div className="flex-1">
|
||||
<Skeleton className="h-8 w-64" />
|
||||
<Skeleton className="mt-2 h-4 w-48" />
|
||||
{/* Hero skeleton */}
|
||||
<div className="rounded-xl border p-6">
|
||||
<div className="flex items-start gap-6">
|
||||
<Skeleton className="h-20 w-20 rounded-xl" />
|
||||
<div className="flex-1 space-y-3">
|
||||
<Skeleton className="h-8 w-64" />
|
||||
<Skeleton className="h-4 w-96" />
|
||||
<div className="flex gap-2">
|
||||
<Skeleton className="h-6 w-20" />
|
||||
<Skeleton className="h-6 w-24" />
|
||||
</div>
|
||||
</div>
|
||||
<div className="flex gap-2">
|
||||
<Skeleton className="h-9 w-24" />
|
||||
<Skeleton className="h-9 w-20" />
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
<div className="grid gap-6 lg:grid-cols-3">
|
||||
@@ -161,57 +202,134 @@ export function AgentTypeDetail({
|
||||
top_p?: number;
|
||||
};
|
||||
|
||||
const agentColor = agentType.color || '#3B82F6';
|
||||
|
||||
return (
|
||||
<div className={className}>
|
||||
{/* Header */}
|
||||
<div className="mb-6 flex items-center gap-4">
|
||||
<Button variant="ghost" size="icon" onClick={onBack}>
|
||||
<ArrowLeft className="h-4 w-4" />
|
||||
<span className="sr-only">Go back</span>
|
||||
</Button>
|
||||
<div className="flex-1">
|
||||
<div className="flex items-center gap-3">
|
||||
<h1 className="text-3xl font-bold">{agentType.name}</h1>
|
||||
<AgentTypeStatusBadge isActive={agentType.is_active} />
|
||||
{/* Back button */}
|
||||
<Button variant="ghost" size="sm" onClick={onBack} className="mb-4">
|
||||
<ArrowLeft className="mr-2 h-4 w-4" />
|
||||
Back to Agent Types
|
||||
</Button>
|
||||
|
||||
{/* Hero Header */}
|
||||
<div
|
||||
className="mb-6 overflow-hidden rounded-xl border"
|
||||
style={{
|
||||
background: `linear-gradient(135deg, ${agentColor}08 0%, transparent 60%)`,
|
||||
borderColor: `${agentColor}30`,
|
||||
}}
|
||||
>
|
||||
<div
|
||||
className="h-1.5 w-full"
|
||||
style={{ background: `linear-gradient(90deg, ${agentColor}, ${agentColor}60)` }}
|
||||
/>
|
||||
<div className="p-6">
|
||||
<div className="flex flex-col gap-6 md:flex-row md:items-start">
|
||||
{/* Icon */}
|
||||
<div
|
||||
className="flex h-20 w-20 shrink-0 items-center justify-center rounded-xl"
|
||||
style={{
|
||||
backgroundColor: `${agentColor}15`,
|
||||
boxShadow: `0 8px 32px ${agentColor}20`,
|
||||
}}
|
||||
>
|
||||
<DynamicIcon
|
||||
name={agentType.icon}
|
||||
className="h-10 w-10"
|
||||
style={{ color: agentColor }}
|
||||
fallback="bot"
|
||||
/>
|
||||
</div>
|
||||
|
||||
{/* Info */}
|
||||
<div className="flex-1 space-y-3">
|
||||
<div>
|
||||
<h1 className="text-3xl font-bold">{agentType.name}</h1>
|
||||
<p className="mt-1 text-muted-foreground">
|
||||
{agentType.description || 'No description provided'}
|
||||
</p>
|
||||
</div>
|
||||
<div className="flex flex-wrap items-center gap-2">
|
||||
<AgentTypeStatusBadge isActive={agentType.is_active} />
|
||||
<CategoryBadge category={agentType.category} />
|
||||
<span className="text-sm text-muted-foreground">
|
||||
Last updated:{' '}
|
||||
{new Date(agentType.updated_at).toLocaleDateString('en-US', {
|
||||
year: 'numeric',
|
||||
month: 'short',
|
||||
day: 'numeric',
|
||||
})}
|
||||
</span>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* Actions */}
|
||||
<div className="flex shrink-0 gap-2">
|
||||
<Button variant="outline" size="sm" onClick={onDuplicate}>
|
||||
<Copy className="mr-2 h-4 w-4" />
|
||||
Duplicate
|
||||
</Button>
|
||||
<Button size="sm" onClick={onEdit}>
|
||||
<Edit className="mr-2 h-4 w-4" />
|
||||
Edit
|
||||
</Button>
|
||||
</div>
|
||||
</div>
|
||||
<p className="text-muted-foreground">
|
||||
Last modified:{' '}
|
||||
{new Date(agentType.updated_at).toLocaleDateString('en-US', {
|
||||
year: 'numeric',
|
||||
month: 'long',
|
||||
day: 'numeric',
|
||||
})}
|
||||
</p>
|
||||
</div>
|
||||
<div className="flex gap-2">
|
||||
<Button variant="outline" size="sm" onClick={onDuplicate}>
|
||||
<Copy className="mr-2 h-4 w-4" />
|
||||
Duplicate
|
||||
</Button>
|
||||
<Button size="sm" onClick={onEdit}>
|
||||
<Edit className="mr-2 h-4 w-4" />
|
||||
Edit
|
||||
</Button>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div className="grid gap-6 lg:grid-cols-3">
|
||||
{/* Main Content */}
|
||||
<div className="space-y-6 lg:col-span-2">
|
||||
{/* Description Card */}
|
||||
<Card>
|
||||
<CardHeader>
|
||||
<CardTitle className="flex items-center gap-2">
|
||||
<FileText className="h-5 w-5" />
|
||||
Description
|
||||
</CardTitle>
|
||||
</CardHeader>
|
||||
<CardContent>
|
||||
<p className="text-muted-foreground">
|
||||
{agentType.description || 'No description provided'}
|
||||
</p>
|
||||
</CardContent>
|
||||
</Card>
|
||||
{/* What This Agent Does Best */}
|
||||
{agentType.typical_tasks.length > 0 && (
|
||||
<Card className="border-primary/20 bg-gradient-to-br from-primary/5 to-transparent">
|
||||
<CardHeader className="pb-3">
|
||||
<CardTitle className="flex items-center gap-2 text-lg">
|
||||
<Sparkles className="h-5 w-5 text-primary" />
|
||||
What This Agent Does Best
|
||||
</CardTitle>
|
||||
</CardHeader>
|
||||
<CardContent>
|
||||
<ul className="space-y-2">
|
||||
{agentType.typical_tasks.map((task, index) => (
|
||||
<li key={index} className="flex items-start gap-2">
|
||||
<Check
|
||||
className="mt-0.5 h-4 w-4 shrink-0 text-primary"
|
||||
style={{ color: agentColor }}
|
||||
/>
|
||||
<span className="text-sm">{task}</span>
|
||||
</li>
|
||||
))}
|
||||
</ul>
|
||||
</CardContent>
|
||||
</Card>
|
||||
)}
|
||||
|
||||
{/* Works Well With */}
|
||||
{agentType.collaboration_hints.length > 0 && (
|
||||
<Card>
|
||||
<CardHeader className="pb-3">
|
||||
<CardTitle className="flex items-center gap-2 text-lg">
|
||||
<Users className="h-5 w-5" />
|
||||
Works Well With
|
||||
</CardTitle>
|
||||
<CardDescription>
|
||||
Agents that complement this type for effective collaboration
|
||||
</CardDescription>
|
||||
</CardHeader>
|
||||
<CardContent>
|
||||
<div className="flex flex-wrap gap-2">
|
||||
{agentType.collaboration_hints.map((hint, index) => (
|
||||
<Badge key={index} variant="secondary" className="text-sm">
|
||||
{hint}
|
||||
</Badge>
|
||||
))}
|
||||
</div>
|
||||
</CardContent>
|
||||
</Card>
|
||||
)}
|
||||
|
||||
{/* Expertise Card */}
|
||||
<Card>
|
||||
@@ -355,7 +473,9 @@ export function AgentTypeDetail({
|
||||
</CardHeader>
|
||||
<CardContent>
|
||||
<div className="text-center">
|
||||
<p className="text-4xl font-bold text-primary">{agentType.instance_count}</p>
|
||||
<p className="text-4xl font-bold" style={{ color: agentColor }}>
|
||||
{agentType.instance_count}
|
||||
</p>
|
||||
<p className="text-sm text-muted-foreground">Active instances</p>
|
||||
</div>
|
||||
<Button variant="outline" className="mt-4 w-full" size="sm" disabled>
|
||||
@@ -364,6 +484,36 @@ export function AgentTypeDetail({
|
||||
</CardContent>
|
||||
</Card>
|
||||
|
||||
{/* Agent Info */}
|
||||
<Card>
|
||||
<CardHeader>
|
||||
<CardTitle className="flex items-center gap-2 text-lg">
|
||||
<FileText className="h-5 w-5" />
|
||||
Details
|
||||
</CardTitle>
|
||||
</CardHeader>
|
||||
<CardContent className="space-y-3 text-sm">
|
||||
<div className="flex justify-between">
|
||||
<span className="text-muted-foreground">Slug</span>
|
||||
<code className="rounded bg-muted px-1.5 py-0.5 text-xs">{agentType.slug}</code>
|
||||
</div>
|
||||
<div className="flex justify-between">
|
||||
<span className="text-muted-foreground">Sort Order</span>
|
||||
<span>{agentType.sort_order}</span>
|
||||
</div>
|
||||
<div className="flex justify-between">
|
||||
<span className="text-muted-foreground">Created</span>
|
||||
<span>
|
||||
{new Date(agentType.created_at).toLocaleDateString('en-US', {
|
||||
year: 'numeric',
|
||||
month: 'short',
|
||||
day: 'numeric',
|
||||
})}
|
||||
</span>
|
||||
</div>
|
||||
</CardContent>
|
||||
</Card>
|
||||
|
||||
{/* Danger Zone */}
|
||||
<Card className="border-destructive/50">
|
||||
<CardHeader>
|
||||
|
||||
@@ -36,6 +36,7 @@ import {
|
||||
type AgentTypeCreateFormValues,
|
||||
AVAILABLE_MODELS,
|
||||
AVAILABLE_MCP_SERVERS,
|
||||
AGENT_TYPE_CATEGORIES,
|
||||
defaultAgentTypeValues,
|
||||
generateSlug,
|
||||
} from '@/lib/validations/agentType';
|
||||
@@ -57,6 +58,13 @@ const TAB_FIELD_MAPPING = {
|
||||
description: 'basic',
|
||||
expertise: 'basic',
|
||||
is_active: 'basic',
|
||||
// Category and display fields
|
||||
category: 'basic',
|
||||
icon: 'basic',
|
||||
color: 'basic',
|
||||
sort_order: 'basic',
|
||||
typical_tasks: 'basic',
|
||||
collaboration_hints: 'basic',
|
||||
primary_model: 'model',
|
||||
fallback_models: 'model',
|
||||
model_params: 'model',
|
||||
@@ -96,6 +104,13 @@ function transformAgentTypeToFormValues(
|
||||
mcp_servers: agentType.mcp_servers,
|
||||
tool_permissions: agentType.tool_permissions,
|
||||
is_active: agentType.is_active,
|
||||
// Category and display fields
|
||||
category: agentType.category,
|
||||
icon: agentType.icon,
|
||||
color: agentType.color,
|
||||
sort_order: agentType.sort_order ?? 0,
|
||||
typical_tasks: agentType.typical_tasks ?? [],
|
||||
collaboration_hints: agentType.collaboration_hints ?? [],
|
||||
});
|
||||
|
||||
return {
|
||||
@@ -114,6 +129,8 @@ export function AgentTypeForm({
|
||||
const isEditing = !!agentType;
|
||||
const [activeTab, setActiveTab] = useState('basic');
|
||||
const [expertiseInput, setExpertiseInput] = useState('');
|
||||
const [typicalTaskInput, setTypicalTaskInput] = useState('');
|
||||
const [collaborationHintInput, setCollaborationHintInput] = useState('');
|
||||
|
||||
// Memoize initial values transformation
|
||||
const initialValues = useMemo(() => transformAgentTypeToFormValues(agentType), [agentType]);
|
||||
@@ -144,6 +161,10 @@ export function AgentTypeForm({
|
||||
const watchExpertise = watch('expertise') || [];
|
||||
/* istanbul ignore next -- defensive fallback, mcp_servers always has default */
|
||||
const watchMcpServers = watch('mcp_servers') || [];
|
||||
/* istanbul ignore next -- defensive fallback, typical_tasks always has default */
|
||||
const watchTypicalTasks = watch('typical_tasks') || [];
|
||||
/* istanbul ignore next -- defensive fallback, collaboration_hints always has default */
|
||||
const watchCollaborationHints = watch('collaboration_hints') || [];
|
||||
|
||||
// Reset form when agentType changes (e.g., switching to edit mode)
|
||||
useEffect(() => {
|
||||
@@ -189,6 +210,40 @@ export function AgentTypeForm({
|
||||
}
|
||||
};
|
||||
|
||||
const handleAddTypicalTask = () => {
|
||||
if (typicalTaskInput.trim()) {
|
||||
const newTask = typicalTaskInput.trim();
|
||||
if (!watchTypicalTasks.includes(newTask)) {
|
||||
setValue('typical_tasks', [...watchTypicalTasks, newTask]);
|
||||
}
|
||||
setTypicalTaskInput('');
|
||||
}
|
||||
};
|
||||
|
||||
const handleRemoveTypicalTask = (task: string) => {
|
||||
setValue(
|
||||
'typical_tasks',
|
||||
watchTypicalTasks.filter((t) => t !== task)
|
||||
);
|
||||
};
|
||||
|
||||
const handleAddCollaborationHint = () => {
|
||||
if (collaborationHintInput.trim()) {
|
||||
const newHint = collaborationHintInput.trim().toLowerCase();
|
||||
if (!watchCollaborationHints.includes(newHint)) {
|
||||
setValue('collaboration_hints', [...watchCollaborationHints, newHint]);
|
||||
}
|
||||
setCollaborationHintInput('');
|
||||
}
|
||||
};
|
||||
|
||||
const handleRemoveCollaborationHint = (hint: string) => {
|
||||
setValue(
|
||||
'collaboration_hints',
|
||||
watchCollaborationHints.filter((h) => h !== hint)
|
||||
);
|
||||
};
|
||||
|
||||
// Handle form submission with validation
|
||||
const onFormSubmit = useCallback(
|
||||
(e: React.FormEvent<HTMLFormElement>) => {
|
||||
@@ -376,6 +431,188 @@ export function AgentTypeForm({
|
||||
</div>
|
||||
</CardContent>
|
||||
</Card>
|
||||
|
||||
{/* Category & Display Card */}
|
||||
<Card>
|
||||
<CardHeader>
|
||||
<CardTitle>Category & Display</CardTitle>
|
||||
<CardDescription>
|
||||
Organize and customize how this agent type appears in the UI
|
||||
</CardDescription>
|
||||
</CardHeader>
|
||||
<CardContent className="space-y-6">
|
||||
<div className="grid gap-4 md:grid-cols-2">
|
||||
<div className="space-y-2">
|
||||
<Label htmlFor="category">Category</Label>
|
||||
<Controller
|
||||
name="category"
|
||||
control={control}
|
||||
render={({ field }) => (
|
||||
<Select
|
||||
value={field.value ?? ''}
|
||||
onValueChange={(val) => field.onChange(val || null)}
|
||||
>
|
||||
<SelectTrigger id="category">
|
||||
<SelectValue placeholder="Select category" />
|
||||
</SelectTrigger>
|
||||
<SelectContent>
|
||||
{AGENT_TYPE_CATEGORIES.map((cat) => (
|
||||
<SelectItem key={cat.value} value={cat.value}>
|
||||
{cat.label}
|
||||
</SelectItem>
|
||||
))}
|
||||
</SelectContent>
|
||||
</Select>
|
||||
)}
|
||||
/>
|
||||
<p className="text-xs text-muted-foreground">
|
||||
Group agents by their primary role
|
||||
</p>
|
||||
</div>
|
||||
<div className="space-y-2">
|
||||
<Label htmlFor="sort_order">Sort Order</Label>
|
||||
<Input
|
||||
id="sort_order"
|
||||
type="number"
|
||||
min={0}
|
||||
max={1000}
|
||||
{...register('sort_order', { valueAsNumber: true })}
|
||||
aria-invalid={!!errors.sort_order}
|
||||
/>
|
||||
{errors.sort_order && (
|
||||
<p className="text-sm text-destructive" role="alert">
|
||||
{errors.sort_order.message}
|
||||
</p>
|
||||
)}
|
||||
<p className="text-xs text-muted-foreground">Display order within category</p>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div className="grid gap-4 md:grid-cols-2">
|
||||
<div className="space-y-2">
|
||||
<Label htmlFor="icon">Icon</Label>
|
||||
<Input
|
||||
id="icon"
|
||||
placeholder="e.g., git-branch"
|
||||
{...register('icon')}
|
||||
aria-invalid={!!errors.icon}
|
||||
/>
|
||||
{errors.icon && (
|
||||
<p className="text-sm text-destructive" role="alert">
|
||||
{errors.icon.message}
|
||||
</p>
|
||||
)}
|
||||
<p className="text-xs text-muted-foreground">Lucide icon name for UI display</p>
|
||||
</div>
|
||||
<div className="space-y-2">
|
||||
<Label htmlFor="color">Color</Label>
|
||||
<div className="flex gap-2">
|
||||
<Input
|
||||
id="color"
|
||||
placeholder="#3B82F6"
|
||||
{...register('color')}
|
||||
aria-invalid={!!errors.color}
|
||||
className="flex-1"
|
||||
/>
|
||||
<Controller
|
||||
name="color"
|
||||
control={control}
|
||||
render={({ field }) => (
|
||||
<input
|
||||
type="color"
|
||||
value={field.value ?? '#3B82F6'}
|
||||
onChange={(e) => field.onChange(e.target.value)}
|
||||
className="h-9 w-9 cursor-pointer rounded border"
|
||||
/>
|
||||
)}
|
||||
/>
|
||||
</div>
|
||||
{errors.color && (
|
||||
<p className="text-sm text-destructive" role="alert">
|
||||
{errors.color.message}
|
||||
</p>
|
||||
)}
|
||||
<p className="text-xs text-muted-foreground">Hex color for visual distinction</p>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<Separator />
|
||||
|
||||
<div className="space-y-2">
|
||||
<Label>Typical Tasks</Label>
|
||||
<p className="text-sm text-muted-foreground">Tasks this agent type excels at</p>
|
||||
<div className="flex gap-2">
|
||||
<Input
|
||||
placeholder="e.g., Design system architecture"
|
||||
value={typicalTaskInput}
|
||||
onChange={(e) => setTypicalTaskInput(e.target.value)}
|
||||
onKeyDown={(e) => {
|
||||
if (e.key === 'Enter') {
|
||||
e.preventDefault();
|
||||
handleAddTypicalTask();
|
||||
}
|
||||
}}
|
||||
/>
|
||||
<Button type="button" variant="outline" onClick={handleAddTypicalTask}>
|
||||
Add
|
||||
</Button>
|
||||
</div>
|
||||
<div className="flex flex-wrap gap-2 pt-2">
|
||||
{watchTypicalTasks.map((task) => (
|
||||
<Badge key={task} variant="secondary" className="gap-1">
|
||||
{task}
|
||||
<button
|
||||
type="button"
|
||||
className="ml-1 rounded-full hover:bg-muted"
|
||||
onClick={() => handleRemoveTypicalTask(task)}
|
||||
aria-label={`Remove ${task}`}
|
||||
>
|
||||
<X className="h-3 w-3" />
|
||||
</button>
|
||||
</Badge>
|
||||
))}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div className="space-y-2">
|
||||
<Label>Collaboration Hints</Label>
|
||||
<p className="text-sm text-muted-foreground">
|
||||
Agent slugs that work well with this type
|
||||
</p>
|
||||
<div className="flex gap-2">
|
||||
<Input
|
||||
placeholder="e.g., backend-engineer"
|
||||
value={collaborationHintInput}
|
||||
onChange={(e) => setCollaborationHintInput(e.target.value)}
|
||||
onKeyDown={(e) => {
|
||||
if (e.key === 'Enter') {
|
||||
e.preventDefault();
|
||||
handleAddCollaborationHint();
|
||||
}
|
||||
}}
|
||||
/>
|
||||
<Button type="button" variant="outline" onClick={handleAddCollaborationHint}>
|
||||
Add
|
||||
</Button>
|
||||
</div>
|
||||
<div className="flex flex-wrap gap-2 pt-2">
|
||||
{watchCollaborationHints.map((hint) => (
|
||||
<Badge key={hint} variant="outline" className="gap-1">
|
||||
{hint}
|
||||
<button
|
||||
type="button"
|
||||
className="ml-1 rounded-full hover:bg-muted"
|
||||
onClick={() => handleRemoveCollaborationHint(hint)}
|
||||
aria-label={`Remove ${hint}`}
|
||||
>
|
||||
<X className="h-3 w-3" />
|
||||
</button>
|
||||
</Badge>
|
||||
))}
|
||||
</div>
|
||||
</div>
|
||||
</CardContent>
|
||||
</Card>
|
||||
</TabsContent>
|
||||
|
||||
{/* Model Configuration Tab */}
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
/**
|
||||
* AgentTypeList Component
|
||||
*
|
||||
* Displays a grid of agent type cards with search and filter functionality.
|
||||
* Used on the main agent types page for browsing and selecting agent types.
|
||||
* Displays agent types in grid or list view with search, status, and category filters.
|
||||
* Shows icon, color accent, and category for each agent type.
|
||||
*/
|
||||
|
||||
'use client';
|
||||
@@ -20,8 +20,14 @@ import {
|
||||
SelectTrigger,
|
||||
SelectValue,
|
||||
} from '@/components/ui/select';
|
||||
import { Bot, Plus, Search, Cpu } from 'lucide-react';
|
||||
import type { AgentTypeResponse } from '@/lib/api/types/agentTypes';
|
||||
import { ToggleGroup, ToggleGroupItem } from '@/components/ui/toggle-group';
|
||||
import { Bot, Plus, Search, Cpu, LayoutGrid, List } from 'lucide-react';
|
||||
import { DynamicIcon } from '@/components/ui/dynamic-icon';
|
||||
import type { AgentTypeResponse, AgentTypeCategory } from '@/lib/api/types/agentTypes';
|
||||
import { CATEGORY_METADATA } from '@/lib/api/types/agentTypes';
|
||||
import { AGENT_TYPE_CATEGORIES } from '@/lib/validations/agentType';
|
||||
|
||||
export type ViewMode = 'grid' | 'list';
|
||||
|
||||
interface AgentTypeListProps {
|
||||
agentTypes: AgentTypeResponse[];
|
||||
@@ -30,6 +36,10 @@ interface AgentTypeListProps {
|
||||
onSearchChange: (query: string) => void;
|
||||
statusFilter: string;
|
||||
onStatusFilterChange: (status: string) => void;
|
||||
categoryFilter: string;
|
||||
onCategoryFilterChange: (category: string) => void;
|
||||
viewMode: ViewMode;
|
||||
onViewModeChange: (mode: ViewMode) => void;
|
||||
onSelect: (id: string) => void;
|
||||
onCreate: () => void;
|
||||
className?: string;
|
||||
@@ -60,11 +70,36 @@ function AgentTypeStatusBadge({ isActive }: { isActive: boolean }) {
|
||||
}
|
||||
|
||||
/**
|
||||
* Loading skeleton for agent type cards
|
||||
* Category badge with color
|
||||
*/
|
||||
function CategoryBadge({ category }: { category: AgentTypeCategory | null }) {
|
||||
if (!category) return null;
|
||||
|
||||
const meta = CATEGORY_METADATA[category];
|
||||
if (!meta) return null;
|
||||
|
||||
return (
|
||||
<Badge
|
||||
variant="outline"
|
||||
className="text-xs font-medium"
|
||||
style={{
|
||||
borderColor: meta.color,
|
||||
color: meta.color,
|
||||
backgroundColor: `${meta.color}10`,
|
||||
}}
|
||||
>
|
||||
{meta.label}
|
||||
</Badge>
|
||||
);
|
||||
}
|
||||
|
||||
/**
|
||||
* Loading skeleton for agent type cards (grid view)
|
||||
*/
|
||||
function AgentTypeCardSkeleton() {
|
||||
return (
|
||||
<Card className="h-[200px]">
|
||||
<Card className="h-[220px] overflow-hidden">
|
||||
<div className="h-1 w-full bg-muted" />
|
||||
<CardHeader className="pb-3">
|
||||
<div className="flex items-start justify-between">
|
||||
<Skeleton className="h-10 w-10 rounded-lg" />
|
||||
@@ -91,6 +126,23 @@ function AgentTypeCardSkeleton() {
|
||||
);
|
||||
}
|
||||
|
||||
/**
|
||||
* Loading skeleton for list view
|
||||
*/
|
||||
function AgentTypeListSkeleton() {
|
||||
return (
|
||||
<div className="flex items-center gap-4 rounded-lg border p-4">
|
||||
<Skeleton className="h-12 w-12 rounded-lg" />
|
||||
<div className="flex-1 space-y-2">
|
||||
<Skeleton className="h-5 w-48" />
|
||||
<Skeleton className="h-4 w-96" />
|
||||
</div>
|
||||
<Skeleton className="h-5 w-20" />
|
||||
<Skeleton className="h-5 w-16" />
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
/**
|
||||
* Extract model display name from model ID
|
||||
*/
|
||||
@@ -103,6 +155,169 @@ function getModelDisplayName(modelId: string): string {
|
||||
return modelId;
|
||||
}
|
||||
|
||||
/**
|
||||
* Grid card view for agent type
|
||||
*/
|
||||
function AgentTypeGridCard({
|
||||
type,
|
||||
onSelect,
|
||||
}: {
|
||||
type: AgentTypeResponse;
|
||||
onSelect: (id: string) => void;
|
||||
}) {
|
||||
const agentColor = type.color || '#3B82F6';
|
||||
|
||||
return (
|
||||
<Card
|
||||
className="cursor-pointer overflow-hidden transition-all hover:shadow-lg"
|
||||
onClick={() => onSelect(type.id)}
|
||||
role="button"
|
||||
tabIndex={0}
|
||||
onKeyDown={(e) => {
|
||||
if (e.key === 'Enter' || e.key === ' ') {
|
||||
e.preventDefault();
|
||||
onSelect(type.id);
|
||||
}
|
||||
}}
|
||||
aria-label={`View ${type.name} agent type`}
|
||||
style={{
|
||||
borderTopColor: agentColor,
|
||||
borderTopWidth: '3px',
|
||||
}}
|
||||
>
|
||||
<CardHeader className="pb-3">
|
||||
<div className="flex items-start justify-between">
|
||||
<div
|
||||
className="flex h-11 w-11 items-center justify-center rounded-lg"
|
||||
style={{
|
||||
backgroundColor: `${agentColor}15`,
|
||||
}}
|
||||
>
|
||||
<DynamicIcon
|
||||
name={type.icon}
|
||||
className="h-5 w-5"
|
||||
style={{ color: agentColor }}
|
||||
fallback="bot"
|
||||
/>
|
||||
</div>
|
||||
<div className="flex flex-col items-end gap-1">
|
||||
<AgentTypeStatusBadge isActive={type.is_active} />
|
||||
<CategoryBadge category={type.category} />
|
||||
</div>
|
||||
</div>
|
||||
<CardTitle className="mt-3 line-clamp-1">{type.name}</CardTitle>
|
||||
<CardDescription className="line-clamp-2">
|
||||
{type.description || 'No description provided'}
|
||||
</CardDescription>
|
||||
</CardHeader>
|
||||
<CardContent>
|
||||
<div className="space-y-3">
|
||||
{/* Expertise tags */}
|
||||
<div className="flex flex-wrap gap-1">
|
||||
{type.expertise.slice(0, 3).map((skill) => (
|
||||
<Badge key={skill} variant="secondary" className="text-xs">
|
||||
{skill}
|
||||
</Badge>
|
||||
))}
|
||||
{type.expertise.length > 3 && (
|
||||
<Badge variant="outline" className="text-xs">
|
||||
+{type.expertise.length - 3}
|
||||
</Badge>
|
||||
)}
|
||||
{type.expertise.length === 0 && (
|
||||
<span className="text-xs text-muted-foreground">No expertise defined</span>
|
||||
)}
|
||||
</div>
|
||||
|
||||
<Separator />
|
||||
|
||||
{/* Metadata */}
|
||||
<div className="flex items-center justify-between text-sm text-muted-foreground">
|
||||
<div className="flex items-center gap-1">
|
||||
<Cpu className="h-3.5 w-3.5" />
|
||||
<span className="text-xs">{getModelDisplayName(type.primary_model)}</span>
|
||||
</div>
|
||||
<div className="flex items-center gap-1">
|
||||
<Bot className="h-3.5 w-3.5" />
|
||||
<span className="text-xs">{type.instance_count} instances</span>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</CardContent>
|
||||
</Card>
|
||||
);
|
||||
}
|
||||
|
||||
/**
|
||||
* List row view for agent type
|
||||
*/
|
||||
function AgentTypeListRow({
|
||||
type,
|
||||
onSelect,
|
||||
}: {
|
||||
type: AgentTypeResponse;
|
||||
onSelect: (id: string) => void;
|
||||
}) {
|
||||
const agentColor = type.color || '#3B82F6';
|
||||
|
||||
return (
|
||||
<div
|
||||
className="flex cursor-pointer items-center gap-4 rounded-lg border p-4 transition-all hover:border-primary hover:shadow-md"
|
||||
onClick={() => onSelect(type.id)}
|
||||
role="button"
|
||||
tabIndex={0}
|
||||
onKeyDown={(e) => {
|
||||
if (e.key === 'Enter' || e.key === ' ') {
|
||||
e.preventDefault();
|
||||
onSelect(type.id);
|
||||
}
|
||||
}}
|
||||
aria-label={`View ${type.name} agent type`}
|
||||
style={{
|
||||
borderLeftColor: agentColor,
|
||||
borderLeftWidth: '4px',
|
||||
}}
|
||||
>
|
||||
{/* Icon */}
|
||||
<div
|
||||
className="flex h-12 w-12 shrink-0 items-center justify-center rounded-lg"
|
||||
style={{ backgroundColor: `${agentColor}15` }}
|
||||
>
|
||||
<DynamicIcon
|
||||
name={type.icon}
|
||||
className="h-6 w-6"
|
||||
style={{ color: agentColor }}
|
||||
fallback="bot"
|
||||
/>
|
||||
</div>
|
||||
|
||||
{/* Main content */}
|
||||
<div className="min-w-0 flex-1">
|
||||
<div className="flex items-center gap-2">
|
||||
<h3 className="font-semibold">{type.name}</h3>
|
||||
<CategoryBadge category={type.category} />
|
||||
</div>
|
||||
<p className="line-clamp-1 text-sm text-muted-foreground">
|
||||
{type.description || 'No description'}
|
||||
</p>
|
||||
<div className="mt-1 flex items-center gap-3 text-xs text-muted-foreground">
|
||||
<span className="flex items-center gap-1">
|
||||
<Cpu className="h-3 w-3" />
|
||||
{getModelDisplayName(type.primary_model)}
|
||||
</span>
|
||||
<span>{type.expertise.length} expertise areas</span>
|
||||
<span>{type.instance_count} instances</span>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* Status */}
|
||||
<div className="shrink-0">
|
||||
<AgentTypeStatusBadge isActive={type.is_active} />
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
export function AgentTypeList({
|
||||
agentTypes,
|
||||
isLoading = false,
|
||||
@@ -110,6 +325,10 @@ export function AgentTypeList({
|
||||
onSearchChange,
|
||||
statusFilter,
|
||||
onStatusFilterChange,
|
||||
categoryFilter,
|
||||
onCategoryFilterChange,
|
||||
viewMode,
|
||||
onViewModeChange,
|
||||
onSelect,
|
||||
onCreate,
|
||||
className,
|
||||
@@ -131,7 +350,7 @@ export function AgentTypeList({
|
||||
</div>
|
||||
|
||||
{/* Filters */}
|
||||
<div className="mb-6 flex flex-col gap-4 sm:flex-row">
|
||||
<div className="mb-6 flex flex-col gap-4 sm:flex-row sm:items-center">
|
||||
<div className="relative flex-1">
|
||||
<Search className="absolute left-3 top-1/2 h-4 w-4 -translate-y-1/2 text-muted-foreground" />
|
||||
<Input
|
||||
@@ -142,8 +361,25 @@ export function AgentTypeList({
|
||||
aria-label="Search agent types"
|
||||
/>
|
||||
</div>
|
||||
|
||||
{/* Category Filter */}
|
||||
<Select value={categoryFilter} onValueChange={onCategoryFilterChange}>
|
||||
<SelectTrigger className="w-full sm:w-44" aria-label="Filter by category">
|
||||
<SelectValue placeholder="All Categories" />
|
||||
</SelectTrigger>
|
||||
<SelectContent>
|
||||
<SelectItem value="all">All Categories</SelectItem>
|
||||
{AGENT_TYPE_CATEGORIES.map((cat) => (
|
||||
<SelectItem key={cat.value} value={cat.value}>
|
||||
{cat.label}
|
||||
</SelectItem>
|
||||
))}
|
||||
</SelectContent>
|
||||
</Select>
|
||||
|
||||
{/* Status Filter */}
|
||||
<Select value={statusFilter} onValueChange={onStatusFilterChange}>
|
||||
<SelectTrigger className="w-full sm:w-40" aria-label="Filter by status">
|
||||
<SelectTrigger className="w-full sm:w-36" aria-label="Filter by status">
|
||||
<SelectValue placeholder="Status" />
|
||||
</SelectTrigger>
|
||||
<SelectContent>
|
||||
@@ -152,10 +388,25 @@ export function AgentTypeList({
|
||||
<SelectItem value="inactive">Inactive</SelectItem>
|
||||
</SelectContent>
|
||||
</Select>
|
||||
|
||||
{/* View Mode Toggle */}
|
||||
<ToggleGroup
|
||||
type="single"
|
||||
value={viewMode}
|
||||
onValueChange={(value: string) => value && onViewModeChange(value as ViewMode)}
|
||||
className="hidden sm:flex"
|
||||
>
|
||||
<ToggleGroupItem value="grid" aria-label="Grid view" size="sm">
|
||||
<LayoutGrid className="h-4 w-4" />
|
||||
</ToggleGroupItem>
|
||||
<ToggleGroupItem value="list" aria-label="List view" size="sm">
|
||||
<List className="h-4 w-4" />
|
||||
</ToggleGroupItem>
|
||||
</ToggleGroup>
|
||||
</div>
|
||||
|
||||
{/* Loading State */}
|
||||
{isLoading && (
|
||||
{/* Loading State - Grid */}
|
||||
{isLoading && viewMode === 'grid' && (
|
||||
<div className="grid gap-4 md:grid-cols-2 lg:grid-cols-3">
|
||||
{[1, 2, 3, 4, 5, 6].map((i) => (
|
||||
<AgentTypeCardSkeleton key={i} />
|
||||
@@ -163,71 +414,29 @@ export function AgentTypeList({
|
||||
</div>
|
||||
)}
|
||||
|
||||
{/* Agent Type Grid */}
|
||||
{!isLoading && agentTypes.length > 0 && (
|
||||
{/* Loading State - List */}
|
||||
{isLoading && viewMode === 'list' && (
|
||||
<div className="space-y-3">
|
||||
{[1, 2, 3, 4, 5, 6].map((i) => (
|
||||
<AgentTypeListSkeleton key={i} />
|
||||
))}
|
||||
</div>
|
||||
)}
|
||||
|
||||
{/* Agent Type Grid View */}
|
||||
{!isLoading && agentTypes.length > 0 && viewMode === 'grid' && (
|
||||
<div className="grid gap-4 md:grid-cols-2 lg:grid-cols-3">
|
||||
{agentTypes.map((type) => (
|
||||
<Card
|
||||
key={type.id}
|
||||
className="cursor-pointer transition-all hover:border-primary hover:shadow-md"
|
||||
onClick={() => onSelect(type.id)}
|
||||
role="button"
|
||||
tabIndex={0}
|
||||
onKeyDown={(e) => {
|
||||
if (e.key === 'Enter' || e.key === ' ') {
|
||||
e.preventDefault();
|
||||
onSelect(type.id);
|
||||
}
|
||||
}}
|
||||
aria-label={`View ${type.name} agent type`}
|
||||
>
|
||||
<CardHeader className="pb-3">
|
||||
<div className="flex items-start justify-between">
|
||||
<div className="flex h-10 w-10 items-center justify-center rounded-lg bg-primary/10">
|
||||
<Bot className="h-5 w-5 text-primary" />
|
||||
</div>
|
||||
<AgentTypeStatusBadge isActive={type.is_active} />
|
||||
</div>
|
||||
<CardTitle className="mt-3">{type.name}</CardTitle>
|
||||
<CardDescription className="line-clamp-2">
|
||||
{type.description || 'No description provided'}
|
||||
</CardDescription>
|
||||
</CardHeader>
|
||||
<CardContent>
|
||||
<div className="space-y-3">
|
||||
{/* Expertise tags */}
|
||||
<div className="flex flex-wrap gap-1">
|
||||
{type.expertise.slice(0, 3).map((skill) => (
|
||||
<Badge key={skill} variant="secondary" className="text-xs">
|
||||
{skill}
|
||||
</Badge>
|
||||
))}
|
||||
{type.expertise.length > 3 && (
|
||||
<Badge variant="outline" className="text-xs">
|
||||
+{type.expertise.length - 3}
|
||||
</Badge>
|
||||
)}
|
||||
{type.expertise.length === 0 && (
|
||||
<span className="text-xs text-muted-foreground">No expertise defined</span>
|
||||
)}
|
||||
</div>
|
||||
<AgentTypeGridCard key={type.id} type={type} onSelect={onSelect} />
|
||||
))}
|
||||
</div>
|
||||
)}
|
||||
|
||||
<Separator />
|
||||
|
||||
{/* Metadata */}
|
||||
<div className="flex items-center justify-between text-sm text-muted-foreground">
|
||||
<div className="flex items-center gap-1">
|
||||
<Cpu className="h-3.5 w-3.5" />
|
||||
<span className="text-xs">{getModelDisplayName(type.primary_model)}</span>
|
||||
</div>
|
||||
<div className="flex items-center gap-1">
|
||||
<Bot className="h-3.5 w-3.5" />
|
||||
<span className="text-xs">{type.instance_count} instances</span>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</CardContent>
|
||||
</Card>
|
||||
{/* Agent Type List View */}
|
||||
{!isLoading && agentTypes.length > 0 && viewMode === 'list' && (
|
||||
<div className="space-y-3">
|
||||
{agentTypes.map((type) => (
|
||||
<AgentTypeListRow key={type.id} type={type} onSelect={onSelect} />
|
||||
))}
|
||||
</div>
|
||||
)}
|
||||
@@ -238,11 +447,11 @@ export function AgentTypeList({
|
||||
<Bot className="mx-auto h-12 w-12 text-muted-foreground" />
|
||||
<h3 className="mt-4 font-semibold">No agent types found</h3>
|
||||
<p className="text-muted-foreground">
|
||||
{searchQuery || statusFilter !== 'all'
|
||||
{searchQuery || statusFilter !== 'all' || categoryFilter !== 'all'
|
||||
? 'Try adjusting your search or filters'
|
||||
: 'Create your first agent type to get started'}
|
||||
</p>
|
||||
{!searchQuery && statusFilter === 'all' && (
|
||||
{!searchQuery && statusFilter === 'all' && categoryFilter === 'all' && (
|
||||
<Button onClick={onCreate} className="mt-4">
|
||||
<Plus className="mr-2 h-4 w-4" />
|
||||
Create Agent Type
|
||||
|
||||
@@ -5,5 +5,5 @@
|
||||
*/
|
||||
|
||||
export { AgentTypeForm } from './AgentTypeForm';
|
||||
export { AgentTypeList } from './AgentTypeList';
|
||||
export { AgentTypeList, type ViewMode } from './AgentTypeList';
|
||||
export { AgentTypeDetail } from './AgentTypeDetail';
|
||||
|
||||
84
frontend/src/components/ui/dynamic-icon.tsx
Normal file
84
frontend/src/components/ui/dynamic-icon.tsx
Normal file
@@ -0,0 +1,84 @@
|
||||
/**
|
||||
* DynamicIcon Component
|
||||
*
|
||||
* Renders Lucide icons dynamically by name string.
|
||||
* Useful when icon names come from data (e.g., database).
|
||||
*/
|
||||
|
||||
import * as LucideIcons from 'lucide-react';
|
||||
import type { LucideProps } from 'lucide-react';
|
||||
|
||||
/**
|
||||
* Map of icon names to their components.
|
||||
* Uses kebab-case names (e.g., 'clipboard-check') as keys.
|
||||
*/
|
||||
const iconMap: Record<string, React.ComponentType<LucideProps>> = {
|
||||
// Development
|
||||
'clipboard-check': LucideIcons.ClipboardCheck,
|
||||
briefcase: LucideIcons.Briefcase,
|
||||
'file-text': LucideIcons.FileText,
|
||||
'git-branch': LucideIcons.GitBranch,
|
||||
code: LucideIcons.Code,
|
||||
server: LucideIcons.Server,
|
||||
layout: LucideIcons.Layout,
|
||||
smartphone: LucideIcons.Smartphone,
|
||||
// Design
|
||||
palette: LucideIcons.Palette,
|
||||
search: LucideIcons.Search,
|
||||
// Quality
|
||||
shield: LucideIcons.Shield,
|
||||
'shield-check': LucideIcons.ShieldCheck,
|
||||
// Operations
|
||||
settings: LucideIcons.Settings,
|
||||
'settings-2': LucideIcons.Settings2,
|
||||
// AI/ML
|
||||
brain: LucideIcons.Brain,
|
||||
microscope: LucideIcons.Microscope,
|
||||
eye: LucideIcons.Eye,
|
||||
'message-square': LucideIcons.MessageSquare,
|
||||
// Data
|
||||
'bar-chart': LucideIcons.BarChart,
|
||||
database: LucideIcons.Database,
|
||||
// Leadership
|
||||
users: LucideIcons.Users,
|
||||
target: LucideIcons.Target,
|
||||
// Domain Expert
|
||||
calculator: LucideIcons.Calculator,
|
||||
'heart-pulse': LucideIcons.HeartPulse,
|
||||
'flask-conical': LucideIcons.FlaskConical,
|
||||
lightbulb: LucideIcons.Lightbulb,
|
||||
'book-open': LucideIcons.BookOpen,
|
||||
// Generic
|
||||
bot: LucideIcons.Bot,
|
||||
cpu: LucideIcons.Cpu,
|
||||
};
|
||||
|
||||
interface DynamicIconProps extends Omit<LucideProps, 'name'> {
|
||||
/** Icon name in kebab-case (e.g., 'clipboard-check', 'bot') */
|
||||
name: string | null | undefined;
|
||||
/** Fallback icon name if the specified icon is not found */
|
||||
fallback?: string;
|
||||
}
|
||||
|
||||
/**
|
||||
* Renders a Lucide icon dynamically by name.
|
||||
*
|
||||
* @example
|
||||
* ```tsx
|
||||
* <DynamicIcon name="clipboard-check" className="h-5 w-5" />
|
||||
* <DynamicIcon name={agent.icon} fallback="bot" />
|
||||
* ```
|
||||
*/
|
||||
export function DynamicIcon({ name, fallback = 'bot', ...props }: DynamicIconProps) {
|
||||
const iconName = name || fallback;
|
||||
const IconComponent = iconMap[iconName] || iconMap[fallback] || LucideIcons.Bot;
|
||||
|
||||
return <IconComponent {...props} />;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get available icon names for validation or display
|
||||
*/
|
||||
export function getAvailableIconNames(): string[] {
|
||||
return Object.keys(iconMap);
|
||||
}
|
||||
93
frontend/src/components/ui/toggle-group.tsx
Normal file
93
frontend/src/components/ui/toggle-group.tsx
Normal file
@@ -0,0 +1,93 @@
|
||||
'use client';
|
||||
|
||||
import * as React from 'react';
|
||||
import * as ToggleGroupPrimitive from '@radix-ui/react-toggle-group';
|
||||
import { type VariantProps, cva } from 'class-variance-authority';
|
||||
|
||||
import { cn } from '@/lib/utils';
|
||||
|
||||
const toggleGroupVariants = cva(
|
||||
'inline-flex items-center justify-center rounded-md border bg-transparent',
|
||||
{
|
||||
variants: {
|
||||
variant: {
|
||||
default: 'bg-transparent',
|
||||
outline: 'border border-input',
|
||||
},
|
||||
},
|
||||
defaultVariants: {
|
||||
variant: 'outline',
|
||||
},
|
||||
}
|
||||
);
|
||||
|
||||
const toggleGroupItemVariants = cva(
|
||||
'inline-flex items-center justify-center whitespace-nowrap text-sm font-medium ring-offset-background transition-all focus-visible:outline-none focus-visible:ring-2 focus-visible:ring-ring focus-visible:ring-offset-2 disabled:pointer-events-none disabled:opacity-50 data-[state=on]:bg-accent data-[state=on]:text-accent-foreground',
|
||||
{
|
||||
variants: {
|
||||
variant: {
|
||||
default: 'bg-transparent hover:bg-muted hover:text-muted-foreground',
|
||||
outline: 'bg-transparent hover:bg-accent hover:text-accent-foreground',
|
||||
},
|
||||
size: {
|
||||
default: 'h-10 px-3',
|
||||
sm: 'h-9 px-2.5',
|
||||
lg: 'h-11 px-5',
|
||||
},
|
||||
},
|
||||
defaultVariants: {
|
||||
variant: 'default',
|
||||
size: 'default',
|
||||
},
|
||||
}
|
||||
);
|
||||
|
||||
const ToggleGroupContext = React.createContext<VariantProps<typeof toggleGroupItemVariants>>({
|
||||
size: 'default',
|
||||
variant: 'default',
|
||||
});
|
||||
|
||||
const ToggleGroup = React.forwardRef<
|
||||
React.ElementRef<typeof ToggleGroupPrimitive.Root>,
|
||||
React.ComponentPropsWithoutRef<typeof ToggleGroupPrimitive.Root> &
|
||||
VariantProps<typeof toggleGroupVariants> &
|
||||
VariantProps<typeof toggleGroupItemVariants>
|
||||
>(({ className, variant, size, children, ...props }, ref) => (
|
||||
<ToggleGroupPrimitive.Root
|
||||
ref={ref}
|
||||
className={cn(toggleGroupVariants({ variant }), className)}
|
||||
{...props}
|
||||
>
|
||||
<ToggleGroupContext.Provider value={{ variant, size }}>{children}</ToggleGroupContext.Provider>
|
||||
</ToggleGroupPrimitive.Root>
|
||||
));
|
||||
|
||||
ToggleGroup.displayName = ToggleGroupPrimitive.Root.displayName;
|
||||
|
||||
const ToggleGroupItem = React.forwardRef<
|
||||
React.ElementRef<typeof ToggleGroupPrimitive.Item>,
|
||||
React.ComponentPropsWithoutRef<typeof ToggleGroupPrimitive.Item> &
|
||||
VariantProps<typeof toggleGroupItemVariants>
|
||||
>(({ className, children, variant, size, ...props }, ref) => {
|
||||
const context = React.useContext(ToggleGroupContext);
|
||||
|
||||
return (
|
||||
<ToggleGroupPrimitive.Item
|
||||
ref={ref}
|
||||
className={cn(
|
||||
toggleGroupItemVariants({
|
||||
variant: context.variant || variant,
|
||||
size: context.size || size,
|
||||
}),
|
||||
className
|
||||
)}
|
||||
{...props}
|
||||
>
|
||||
{children}
|
||||
</ToggleGroupPrimitive.Item>
|
||||
);
|
||||
});
|
||||
|
||||
ToggleGroupItem.displayName = ToggleGroupPrimitive.Item.displayName;
|
||||
|
||||
export { ToggleGroup, ToggleGroupItem };
|
||||
@@ -44,10 +44,10 @@ const DEFAULT_PAGE_LIMIT = 20;
|
||||
export function useAgentTypes(params: AgentTypeListParams = {}) {
|
||||
const { user } = useAuth();
|
||||
|
||||
const { page = 1, limit = DEFAULT_PAGE_LIMIT, is_active = true, search } = params;
|
||||
const { page = 1, limit = DEFAULT_PAGE_LIMIT, is_active = true, search, category } = params;
|
||||
|
||||
return useQuery({
|
||||
queryKey: agentTypeKeys.list({ page, limit, is_active, search }),
|
||||
queryKey: agentTypeKeys.list({ page, limit, is_active, search, category }),
|
||||
queryFn: async (): Promise<AgentTypeListResponse> => {
|
||||
const response = await apiClient.instance.get('/api/v1/agent-types', {
|
||||
params: {
|
||||
@@ -55,6 +55,7 @@ export function useAgentTypes(params: AgentTypeListParams = {}) {
|
||||
limit,
|
||||
is_active,
|
||||
...(search ? { search } : {}),
|
||||
...(category ? { category } : {}),
|
||||
},
|
||||
});
|
||||
return response.data;
|
||||
|
||||
@@ -5,6 +5,68 @@
|
||||
* Used for type-safe API communication with the agent-types endpoints.
|
||||
*/
|
||||
|
||||
/**
|
||||
* Category classification for agent types
|
||||
*/
|
||||
export type AgentTypeCategory =
|
||||
| 'development'
|
||||
| 'design'
|
||||
| 'quality'
|
||||
| 'operations'
|
||||
| 'ai_ml'
|
||||
| 'data'
|
||||
| 'leadership'
|
||||
| 'domain_expert';
|
||||
|
||||
/**
|
||||
* Metadata for each category including display label and description
|
||||
*/
|
||||
export const CATEGORY_METADATA: Record<
|
||||
AgentTypeCategory,
|
||||
{ label: string; description: string; color: string }
|
||||
> = {
|
||||
development: {
|
||||
label: 'Development',
|
||||
description: 'Product, project, and engineering roles',
|
||||
color: '#3B82F6',
|
||||
},
|
||||
design: {
|
||||
label: 'Design',
|
||||
description: 'UI/UX and design research',
|
||||
color: '#EC4899',
|
||||
},
|
||||
quality: {
|
||||
label: 'Quality',
|
||||
description: 'QA and security assurance',
|
||||
color: '#10B981',
|
||||
},
|
||||
operations: {
|
||||
label: 'Operations',
|
||||
description: 'DevOps and MLOps engineering',
|
||||
color: '#F59E0B',
|
||||
},
|
||||
ai_ml: {
|
||||
label: 'AI & ML',
|
||||
description: 'Machine learning specialists',
|
||||
color: '#8B5CF6',
|
||||
},
|
||||
data: {
|
||||
label: 'Data',
|
||||
description: 'Data science and engineering',
|
||||
color: '#06B6D4',
|
||||
},
|
||||
leadership: {
|
||||
label: 'Leadership',
|
||||
description: 'Technical leadership and facilitation',
|
||||
color: '#F97316',
|
||||
},
|
||||
domain_expert: {
|
||||
label: 'Domain Experts',
|
||||
description: 'Industry and domain specialists',
|
||||
color: '#84CC16',
|
||||
},
|
||||
};
|
||||
|
||||
/**
|
||||
* Base agent type fields shared across create, update, and response schemas
|
||||
*/
|
||||
@@ -20,6 +82,13 @@ export interface AgentTypeBase {
|
||||
mcp_servers: string[];
|
||||
tool_permissions: Record<string, unknown>;
|
||||
is_active: boolean;
|
||||
// Category and display fields
|
||||
category?: AgentTypeCategory | null;
|
||||
icon?: string | null;
|
||||
color?: string | null;
|
||||
sort_order: number;
|
||||
typical_tasks: string[];
|
||||
collaboration_hints: string[];
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -37,6 +106,13 @@ export interface AgentTypeCreate {
|
||||
mcp_servers?: string[];
|
||||
tool_permissions?: Record<string, unknown>;
|
||||
is_active?: boolean;
|
||||
// Category and display fields
|
||||
category?: AgentTypeCategory | null;
|
||||
icon?: string | null;
|
||||
color?: string | null;
|
||||
sort_order?: number;
|
||||
typical_tasks?: string[];
|
||||
collaboration_hints?: string[];
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -54,6 +130,13 @@ export interface AgentTypeUpdate {
|
||||
mcp_servers?: string[] | null;
|
||||
tool_permissions?: Record<string, unknown> | null;
|
||||
is_active?: boolean | null;
|
||||
// Category and display fields
|
||||
category?: AgentTypeCategory | null;
|
||||
icon?: string | null;
|
||||
color?: string | null;
|
||||
sort_order?: number | null;
|
||||
typical_tasks?: string[] | null;
|
||||
collaboration_hints?: string[] | null;
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -72,6 +155,13 @@ export interface AgentTypeResponse {
|
||||
mcp_servers: string[];
|
||||
tool_permissions: Record<string, unknown>;
|
||||
is_active: boolean;
|
||||
// Category and display fields
|
||||
category: AgentTypeCategory | null;
|
||||
icon: string | null;
|
||||
color: string | null;
|
||||
sort_order: number;
|
||||
typical_tasks: string[];
|
||||
collaboration_hints: string[];
|
||||
created_at: string;
|
||||
updated_at: string;
|
||||
instance_count: number;
|
||||
@@ -104,9 +194,15 @@ export interface AgentTypeListParams {
|
||||
page?: number;
|
||||
limit?: number;
|
||||
is_active?: boolean;
|
||||
category?: AgentTypeCategory;
|
||||
search?: string;
|
||||
}
|
||||
|
||||
/**
|
||||
* Response type for grouped agent types by category
|
||||
*/
|
||||
export type AgentTypeGroupedResponse = Record<string, AgentTypeResponse[]>;
|
||||
|
||||
/**
|
||||
* Model parameter configuration with typed fields
|
||||
*/
|
||||
|
||||
@@ -6,12 +6,18 @@
|
||||
*/
|
||||
|
||||
import { z } from 'zod';
|
||||
import type { AgentTypeCategory } from '@/lib/api/types/agentTypes';
|
||||
|
||||
/**
|
||||
* Slug validation regex: lowercase letters, numbers, and hyphens only
|
||||
*/
|
||||
const slugRegex = /^[a-z0-9-]+$/;
|
||||
|
||||
/**
|
||||
* Hex color validation regex
|
||||
*/
|
||||
const hexColorRegex = /^#[0-9A-Fa-f]{6}$/;
|
||||
|
||||
/**
|
||||
* Available AI models for agent types
|
||||
*/
|
||||
@@ -43,6 +49,84 @@ export const AGENT_TYPE_STATUS = [
|
||||
{ value: false, label: 'Inactive' },
|
||||
] as const;
|
||||
|
||||
/**
|
||||
* Agent type categories for organizing agents
|
||||
*/
|
||||
/* istanbul ignore next -- constant declaration */
|
||||
export const AGENT_TYPE_CATEGORIES: {
|
||||
value: AgentTypeCategory;
|
||||
label: string;
|
||||
description: string;
|
||||
}[] = [
|
||||
{ value: 'development', label: 'Development', description: 'Product, project, and engineering' },
|
||||
{ value: 'design', label: 'Design', description: 'UI/UX and design research' },
|
||||
{ value: 'quality', label: 'Quality', description: 'QA and security assurance' },
|
||||
{ value: 'operations', label: 'Operations', description: 'DevOps and MLOps engineering' },
|
||||
{ value: 'ai_ml', label: 'AI & ML', description: 'Machine learning specialists' },
|
||||
{ value: 'data', label: 'Data', description: 'Data science and engineering' },
|
||||
{ value: 'leadership', label: 'Leadership', description: 'Technical leadership' },
|
||||
{ value: 'domain_expert', label: 'Domain Experts', description: 'Industry specialists' },
|
||||
];
|
||||
|
||||
/**
|
||||
* Available Lucide icons for agent types
|
||||
*/
|
||||
/* istanbul ignore next -- constant declaration */
|
||||
export const AVAILABLE_ICONS = [
|
||||
// Development
|
||||
{ value: 'clipboard-check', label: 'Clipboard Check', category: 'development' },
|
||||
{ value: 'briefcase', label: 'Briefcase', category: 'development' },
|
||||
{ value: 'file-text', label: 'File Text', category: 'development' },
|
||||
{ value: 'git-branch', label: 'Git Branch', category: 'development' },
|
||||
{ value: 'code', label: 'Code', category: 'development' },
|
||||
{ value: 'server', label: 'Server', category: 'development' },
|
||||
{ value: 'layout', label: 'Layout', category: 'development' },
|
||||
{ value: 'smartphone', label: 'Smartphone', category: 'development' },
|
||||
// Design
|
||||
{ value: 'palette', label: 'Palette', category: 'design' },
|
||||
{ value: 'search', label: 'Search', category: 'design' },
|
||||
// Quality
|
||||
{ value: 'shield', label: 'Shield', category: 'quality' },
|
||||
{ value: 'shield-check', label: 'Shield Check', category: 'quality' },
|
||||
// Operations
|
||||
{ value: 'settings', label: 'Settings', category: 'operations' },
|
||||
{ value: 'settings-2', label: 'Settings 2', category: 'operations' },
|
||||
// AI/ML
|
||||
{ value: 'brain', label: 'Brain', category: 'ai_ml' },
|
||||
{ value: 'microscope', label: 'Microscope', category: 'ai_ml' },
|
||||
{ value: 'eye', label: 'Eye', category: 'ai_ml' },
|
||||
{ value: 'message-square', label: 'Message Square', category: 'ai_ml' },
|
||||
// Data
|
||||
{ value: 'bar-chart', label: 'Bar Chart', category: 'data' },
|
||||
{ value: 'database', label: 'Database', category: 'data' },
|
||||
// Leadership
|
||||
{ value: 'users', label: 'Users', category: 'leadership' },
|
||||
{ value: 'target', label: 'Target', category: 'leadership' },
|
||||
// Domain Expert
|
||||
{ value: 'calculator', label: 'Calculator', category: 'domain_expert' },
|
||||
{ value: 'heart-pulse', label: 'Heart Pulse', category: 'domain_expert' },
|
||||
{ value: 'flask-conical', label: 'Flask', category: 'domain_expert' },
|
||||
{ value: 'lightbulb', label: 'Lightbulb', category: 'domain_expert' },
|
||||
{ value: 'book-open', label: 'Book Open', category: 'domain_expert' },
|
||||
// Generic
|
||||
{ value: 'bot', label: 'Bot', category: 'generic' },
|
||||
] as const;
|
||||
|
||||
/**
|
||||
* Color palette for agent type visual distinction
|
||||
*/
|
||||
/* istanbul ignore next -- constant declaration */
|
||||
export const COLOR_PALETTE = [
|
||||
{ value: '#3B82F6', label: 'Blue', category: 'development' },
|
||||
{ value: '#EC4899', label: 'Pink', category: 'design' },
|
||||
{ value: '#10B981', label: 'Green', category: 'quality' },
|
||||
{ value: '#F59E0B', label: 'Amber', category: 'operations' },
|
||||
{ value: '#8B5CF6', label: 'Purple', category: 'ai_ml' },
|
||||
{ value: '#06B6D4', label: 'Cyan', category: 'data' },
|
||||
{ value: '#F97316', label: 'Orange', category: 'leadership' },
|
||||
{ value: '#84CC16', label: 'Lime', category: 'domain_expert' },
|
||||
] as const;
|
||||
|
||||
/**
|
||||
* Model params schema
|
||||
*/
|
||||
@@ -52,6 +136,20 @@ const modelParamsSchema = z.object({
|
||||
top_p: z.number().min(0).max(1),
|
||||
});
|
||||
|
||||
/**
|
||||
* Agent type category enum values
|
||||
*/
|
||||
const agentTypeCategoryValues = [
|
||||
'development',
|
||||
'design',
|
||||
'quality',
|
||||
'operations',
|
||||
'ai_ml',
|
||||
'data',
|
||||
'leadership',
|
||||
'domain_expert',
|
||||
] as const;
|
||||
|
||||
/**
|
||||
* Schema for agent type form fields
|
||||
*/
|
||||
@@ -96,6 +194,23 @@ export const agentTypeFormSchema = z.object({
|
||||
tool_permissions: z.record(z.string(), z.unknown()),
|
||||
|
||||
is_active: z.boolean(),
|
||||
|
||||
// Category and display fields
|
||||
category: z.enum(agentTypeCategoryValues).nullable().optional(),
|
||||
|
||||
icon: z.string().max(50, 'Icon must be less than 50 characters').nullable().optional(),
|
||||
|
||||
color: z
|
||||
.string()
|
||||
.regex(hexColorRegex, 'Color must be a valid hex code (e.g., #3B82F6)')
|
||||
.nullable()
|
||||
.optional(),
|
||||
|
||||
sort_order: z.number().int().min(0).max(1000),
|
||||
|
||||
typical_tasks: z.array(z.string()),
|
||||
|
||||
collaboration_hints: z.array(z.string()),
|
||||
});
|
||||
|
||||
/**
|
||||
@@ -138,6 +253,13 @@ export const defaultAgentTypeValues: AgentTypeCreateFormValues = {
|
||||
mcp_servers: [],
|
||||
tool_permissions: {},
|
||||
is_active: false, // Start as draft
|
||||
// Category and display fields
|
||||
category: null,
|
||||
icon: 'bot',
|
||||
color: '#3B82F6',
|
||||
sort_order: 0,
|
||||
typical_tasks: [],
|
||||
collaboration_hints: [],
|
||||
};
|
||||
|
||||
/**
|
||||
|
||||
@@ -21,6 +21,13 @@ Your approach is:
|
||||
mcp_servers: ['gitea', 'knowledge', 'filesystem'],
|
||||
tool_permissions: {},
|
||||
is_active: true,
|
||||
// Category and display fields
|
||||
category: 'development',
|
||||
icon: 'git-branch',
|
||||
color: '#3B82F6',
|
||||
sort_order: 40,
|
||||
typical_tasks: ['Design system architecture', 'Create ADRs'],
|
||||
collaboration_hints: ['backend-engineer', 'frontend-engineer'],
|
||||
created_at: '2025-01-10T00:00:00Z',
|
||||
updated_at: '2025-01-18T00:00:00Z',
|
||||
instance_count: 2,
|
||||
@@ -58,9 +65,8 @@ describe('AgentTypeDetail', () => {
|
||||
expect(screen.getByText('Inactive')).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it('renders description card', () => {
|
||||
it('renders description in hero header', () => {
|
||||
render(<AgentTypeDetail {...defaultProps} />);
|
||||
expect(screen.getByText('Description')).toBeInTheDocument();
|
||||
expect(
|
||||
screen.getByText('Designs system architecture and makes technology decisions')
|
||||
).toBeInTheDocument();
|
||||
@@ -130,7 +136,7 @@ describe('AgentTypeDetail', () => {
|
||||
const user = userEvent.setup();
|
||||
render(<AgentTypeDetail {...defaultProps} />);
|
||||
|
||||
await user.click(screen.getByRole('button', { name: /go back/i }));
|
||||
await user.click(screen.getByRole('button', { name: /back to agent types/i }));
|
||||
expect(defaultProps.onBack).toHaveBeenCalledTimes(1);
|
||||
});
|
||||
|
||||
@@ -211,4 +217,146 @@ describe('AgentTypeDetail', () => {
|
||||
);
|
||||
expect(screen.getByText('None configured')).toBeInTheDocument();
|
||||
});
|
||||
|
||||
describe('Hero Header', () => {
|
||||
it('renders hero header with agent name', () => {
|
||||
render(<AgentTypeDetail {...defaultProps} />);
|
||||
expect(
|
||||
screen.getByRole('heading', { level: 1, name: 'Software Architect' })
|
||||
).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it('renders dynamic icon in hero header', () => {
|
||||
const { container } = render(<AgentTypeDetail {...defaultProps} />);
|
||||
expect(container.querySelector('svg.lucide-git-branch')).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it('applies agent color to hero header gradient', () => {
|
||||
const { container } = render(<AgentTypeDetail {...defaultProps} />);
|
||||
const heroHeader = container.querySelector('[style*="linear-gradient"]');
|
||||
expect(heroHeader).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it('renders category badge in hero header', () => {
|
||||
render(<AgentTypeDetail {...defaultProps} />);
|
||||
expect(screen.getByText('Development')).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it('shows last updated date in hero header', () => {
|
||||
render(<AgentTypeDetail {...defaultProps} />);
|
||||
expect(screen.getByText(/Last updated:/)).toBeInTheDocument();
|
||||
expect(screen.getByText(/Jan 18, 2025/)).toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
|
||||
describe('Typical Tasks Card', () => {
|
||||
it('renders "What This Agent Does Best" card', () => {
|
||||
render(<AgentTypeDetail {...defaultProps} />);
|
||||
expect(screen.getByText('What This Agent Does Best')).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it('displays all typical tasks', () => {
|
||||
render(<AgentTypeDetail {...defaultProps} />);
|
||||
expect(screen.getByText('Design system architecture')).toBeInTheDocument();
|
||||
expect(screen.getByText('Create ADRs')).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it('does not render typical tasks card when empty', () => {
|
||||
render(
|
||||
<AgentTypeDetail {...defaultProps} agentType={{ ...mockAgentType, typical_tasks: [] }} />
|
||||
);
|
||||
expect(screen.queryByText('What This Agent Does Best')).not.toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
|
||||
describe('Collaboration Hints Card', () => {
|
||||
it('renders "Works Well With" card', () => {
|
||||
render(<AgentTypeDetail {...defaultProps} />);
|
||||
expect(screen.getByText('Works Well With')).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it('displays collaboration hints as badges', () => {
|
||||
render(<AgentTypeDetail {...defaultProps} />);
|
||||
expect(screen.getByText('backend-engineer')).toBeInTheDocument();
|
||||
expect(screen.getByText('frontend-engineer')).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it('does not render collaboration hints card when empty', () => {
|
||||
render(
|
||||
<AgentTypeDetail
|
||||
{...defaultProps}
|
||||
agentType={{ ...mockAgentType, collaboration_hints: [] }}
|
||||
/>
|
||||
);
|
||||
expect(screen.queryByText('Works Well With')).not.toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
|
||||
describe('Category Badge', () => {
|
||||
it('renders category badge with correct label', () => {
|
||||
render(<AgentTypeDetail {...defaultProps} />);
|
||||
expect(screen.getByText('Development')).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it('does not render category badge when category is null', () => {
|
||||
render(
|
||||
<AgentTypeDetail {...defaultProps} agentType={{ ...mockAgentType, category: null }} />
|
||||
);
|
||||
// Should not have a Development badge in the hero header area
|
||||
// The word "Development" should not appear
|
||||
expect(screen.queryByText('Development')).not.toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
|
||||
describe('Details Card', () => {
|
||||
it('renders details card with slug', () => {
|
||||
render(<AgentTypeDetail {...defaultProps} />);
|
||||
expect(screen.getByText('Slug')).toBeInTheDocument();
|
||||
expect(screen.getByText('software-architect')).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it('renders details card with sort order', () => {
|
||||
render(<AgentTypeDetail {...defaultProps} />);
|
||||
expect(screen.getByText('Sort Order')).toBeInTheDocument();
|
||||
expect(screen.getByText('40')).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it('renders details card with creation date', () => {
|
||||
render(<AgentTypeDetail {...defaultProps} />);
|
||||
expect(screen.getByText('Created')).toBeInTheDocument();
|
||||
expect(screen.getByText(/Jan 10, 2025/)).toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
|
||||
describe('Dynamic Icon', () => {
|
||||
it('renders fallback icon when icon is null', () => {
|
||||
const { container } = render(
|
||||
<AgentTypeDetail {...defaultProps} agentType={{ ...mockAgentType, icon: null }} />
|
||||
);
|
||||
// Should fall back to 'bot' icon
|
||||
expect(container.querySelector('svg.lucide-bot')).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it('renders correct icon based on agent type', () => {
|
||||
const agentWithBrainIcon = { ...mockAgentType, icon: 'brain' };
|
||||
const { container } = render(
|
||||
<AgentTypeDetail {...defaultProps} agentType={agentWithBrainIcon} />
|
||||
);
|
||||
expect(container.querySelector('svg.lucide-brain')).toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
|
||||
describe('Color Styling', () => {
|
||||
it('applies custom color to instance count', () => {
|
||||
render(<AgentTypeDetail {...defaultProps} />);
|
||||
const instanceCount = screen.getByText('2');
|
||||
expect(instanceCount).toHaveStyle({ color: 'rgb(59, 130, 246)' });
|
||||
});
|
||||
|
||||
it('uses default color when color is null', () => {
|
||||
render(<AgentTypeDetail {...defaultProps} agentType={{ ...mockAgentType, color: null }} />);
|
||||
// Should still render without errors
|
||||
expect(screen.getByText('Software Architect')).toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
@@ -16,6 +16,13 @@ const mockAgentType: AgentTypeResponse = {
|
||||
mcp_servers: ['gitea'],
|
||||
tool_permissions: {},
|
||||
is_active: true,
|
||||
// Category and display fields
|
||||
category: 'development',
|
||||
icon: 'git-branch',
|
||||
color: '#3B82F6',
|
||||
sort_order: 40,
|
||||
typical_tasks: ['Design system architecture'],
|
||||
collaboration_hints: ['backend-engineer'],
|
||||
created_at: '2025-01-10T00:00:00Z',
|
||||
updated_at: '2025-01-18T00:00:00Z',
|
||||
instance_count: 2,
|
||||
@@ -192,7 +199,8 @@ describe('AgentTypeForm', () => {
|
||||
|
||||
const expertiseInput = screen.getByPlaceholderText(/e.g., system design/i);
|
||||
await user.type(expertiseInput, 'new skill');
|
||||
await user.click(screen.getByRole('button', { name: /^add$/i }));
|
||||
// Click the first "Add" button (for expertise)
|
||||
await user.click(screen.getAllByRole('button', { name: /^add$/i })[0]);
|
||||
|
||||
expect(screen.getByText('new skill')).toBeInTheDocument();
|
||||
});
|
||||
@@ -454,7 +462,8 @@ describe('AgentTypeForm', () => {
|
||||
// Agent type already has 'system design'
|
||||
const expertiseInput = screen.getByPlaceholderText(/e.g., system design/i);
|
||||
await user.type(expertiseInput, 'system design');
|
||||
await user.click(screen.getByRole('button', { name: /^add$/i }));
|
||||
// Click the first "Add" button (for expertise)
|
||||
await user.click(screen.getAllByRole('button', { name: /^add$/i })[0]);
|
||||
|
||||
// Should still only have one 'system design' badge
|
||||
const badges = screen.getAllByText('system design');
|
||||
@@ -465,7 +474,8 @@ describe('AgentTypeForm', () => {
|
||||
const user = userEvent.setup();
|
||||
render(<AgentTypeForm {...defaultProps} />);
|
||||
|
||||
const addButton = screen.getByRole('button', { name: /^add$/i });
|
||||
// Click the first "Add" button (for expertise)
|
||||
const addButton = screen.getAllByRole('button', { name: /^add$/i })[0];
|
||||
await user.click(addButton);
|
||||
|
||||
// No badges should be added
|
||||
@@ -478,7 +488,8 @@ describe('AgentTypeForm', () => {
|
||||
|
||||
const expertiseInput = screen.getByPlaceholderText(/e.g., system design/i);
|
||||
await user.type(expertiseInput, 'API Design');
|
||||
await user.click(screen.getByRole('button', { name: /^add$/i }));
|
||||
// Click the first "Add" button (for expertise)
|
||||
await user.click(screen.getAllByRole('button', { name: /^add$/i })[0]);
|
||||
|
||||
expect(screen.getByText('api design')).toBeInTheDocument();
|
||||
});
|
||||
@@ -489,7 +500,8 @@ describe('AgentTypeForm', () => {
|
||||
|
||||
const expertiseInput = screen.getByPlaceholderText(/e.g., system design/i);
|
||||
await user.type(expertiseInput, ' testing ');
|
||||
await user.click(screen.getByRole('button', { name: /^add$/i }));
|
||||
// Click the first "Add" button (for expertise)
|
||||
await user.click(screen.getAllByRole('button', { name: /^add$/i })[0]);
|
||||
|
||||
expect(screen.getByText('testing')).toBeInTheDocument();
|
||||
});
|
||||
@@ -502,7 +514,8 @@ describe('AgentTypeForm', () => {
|
||||
/e.g., system design/i
|
||||
) as HTMLInputElement;
|
||||
await user.type(expertiseInput, 'new skill');
|
||||
await user.click(screen.getByRole('button', { name: /^add$/i }));
|
||||
// Click the first "Add" button (for expertise)
|
||||
await user.click(screen.getAllByRole('button', { name: /^add$/i })[0]);
|
||||
|
||||
expect(expertiseInput.value).toBe('');
|
||||
});
|
||||
@@ -562,4 +575,213 @@ describe('AgentTypeForm', () => {
|
||||
expect(screen.getByText('Edit Agent Type')).toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
|
||||
describe('Category & Display Fields', () => {
|
||||
it('renders category and display section', () => {
|
||||
render(<AgentTypeForm {...defaultProps} />);
|
||||
expect(screen.getByText('Category & Display')).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it('shows category select', () => {
|
||||
render(<AgentTypeForm {...defaultProps} />);
|
||||
expect(screen.getByLabelText(/category/i)).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it('shows sort order input', () => {
|
||||
render(<AgentTypeForm {...defaultProps} />);
|
||||
expect(screen.getByLabelText(/sort order/i)).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it('shows icon input', () => {
|
||||
render(<AgentTypeForm {...defaultProps} />);
|
||||
expect(screen.getByLabelText(/icon/i)).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it('shows color input', () => {
|
||||
render(<AgentTypeForm {...defaultProps} />);
|
||||
expect(screen.getByLabelText(/color/i)).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it('pre-fills category fields in edit mode', () => {
|
||||
render(<AgentTypeForm {...defaultProps} agentType={mockAgentType} />);
|
||||
|
||||
const iconInput = screen.getByLabelText(/icon/i) as HTMLInputElement;
|
||||
expect(iconInput.value).toBe('git-branch');
|
||||
|
||||
const sortOrderInput = screen.getByLabelText(/sort order/i) as HTMLInputElement;
|
||||
expect(sortOrderInput.value).toBe('40');
|
||||
});
|
||||
});
|
||||
|
||||
describe('Typical Tasks Management', () => {
|
||||
it('shows typical tasks section', () => {
|
||||
render(<AgentTypeForm {...defaultProps} />);
|
||||
expect(screen.getByText('Typical Tasks')).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it('adds typical task when add button is clicked', async () => {
|
||||
const user = userEvent.setup();
|
||||
render(<AgentTypeForm {...defaultProps} />);
|
||||
|
||||
const taskInput = screen.getByPlaceholderText(/e.g., design system architecture/i);
|
||||
await user.type(taskInput, 'Write documentation');
|
||||
// Click the second "Add" button (for typical tasks)
|
||||
const addButtons = screen.getAllByRole('button', { name: /^add$/i });
|
||||
await user.click(addButtons[1]);
|
||||
|
||||
expect(screen.getByText('Write documentation')).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it('adds typical task on enter key', async () => {
|
||||
const user = userEvent.setup();
|
||||
render(<AgentTypeForm {...defaultProps} />);
|
||||
|
||||
const taskInput = screen.getByPlaceholderText(/e.g., design system architecture/i);
|
||||
await user.type(taskInput, 'Write documentation{Enter}');
|
||||
|
||||
expect(screen.getByText('Write documentation')).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it('removes typical task when X button is clicked', async () => {
|
||||
const user = userEvent.setup();
|
||||
render(<AgentTypeForm {...defaultProps} agentType={mockAgentType} />);
|
||||
|
||||
// Should have existing typical task
|
||||
expect(screen.getByText('Design system architecture')).toBeInTheDocument();
|
||||
|
||||
// Click remove button
|
||||
const removeButton = screen.getByRole('button', {
|
||||
name: /remove design system architecture/i,
|
||||
});
|
||||
await user.click(removeButton);
|
||||
|
||||
expect(screen.queryByText('Design system architecture')).not.toBeInTheDocument();
|
||||
});
|
||||
|
||||
it('does not add duplicate typical tasks', async () => {
|
||||
const user = userEvent.setup();
|
||||
render(<AgentTypeForm {...defaultProps} agentType={mockAgentType} />);
|
||||
|
||||
// Agent type already has 'Design system architecture'
|
||||
const taskInput = screen.getByPlaceholderText(/e.g., design system architecture/i);
|
||||
await user.type(taskInput, 'Design system architecture');
|
||||
const addButtons = screen.getAllByRole('button', { name: /^add$/i });
|
||||
await user.click(addButtons[1]);
|
||||
|
||||
// Should still only have one badge
|
||||
const badges = screen.getAllByText('Design system architecture');
|
||||
expect(badges).toHaveLength(1);
|
||||
});
|
||||
|
||||
it('does not add empty typical task', async () => {
|
||||
const user = userEvent.setup();
|
||||
render(<AgentTypeForm {...defaultProps} />);
|
||||
|
||||
// Click the second "Add" button (for typical tasks) without typing
|
||||
const addButtons = screen.getAllByRole('button', { name: /^add$/i });
|
||||
await user.click(addButtons[1]);
|
||||
|
||||
// No badges should be added (check that there's no remove button for typical tasks)
|
||||
expect(
|
||||
screen.queryByRole('button', { name: /remove write documentation/i })
|
||||
).not.toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
|
||||
describe('Collaboration Hints Management', () => {
|
||||
it('shows collaboration hints section', () => {
|
||||
render(<AgentTypeForm {...defaultProps} />);
|
||||
expect(screen.getByText('Collaboration Hints')).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it('adds collaboration hint when add button is clicked', async () => {
|
||||
const user = userEvent.setup();
|
||||
render(<AgentTypeForm {...defaultProps} />);
|
||||
|
||||
const hintInput = screen.getByPlaceholderText(/e.g., backend-engineer/i);
|
||||
await user.type(hintInput, 'devops-engineer');
|
||||
// Click the third "Add" button (for collaboration hints)
|
||||
const addButtons = screen.getAllByRole('button', { name: /^add$/i });
|
||||
await user.click(addButtons[2]);
|
||||
|
||||
expect(screen.getByText('devops-engineer')).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it('adds collaboration hint on enter key', async () => {
|
||||
const user = userEvent.setup();
|
||||
render(<AgentTypeForm {...defaultProps} />);
|
||||
|
||||
const hintInput = screen.getByPlaceholderText(/e.g., backend-engineer/i);
|
||||
await user.type(hintInput, 'devops-engineer{Enter}');
|
||||
|
||||
expect(screen.getByText('devops-engineer')).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it('removes collaboration hint when X button is clicked', async () => {
|
||||
const user = userEvent.setup();
|
||||
render(<AgentTypeForm {...defaultProps} agentType={mockAgentType} />);
|
||||
|
||||
// Should have existing collaboration hint
|
||||
expect(screen.getByText('backend-engineer')).toBeInTheDocument();
|
||||
|
||||
// Click remove button
|
||||
const removeButton = screen.getByRole('button', { name: /remove backend-engineer/i });
|
||||
await user.click(removeButton);
|
||||
|
||||
expect(screen.queryByText('backend-engineer')).not.toBeInTheDocument();
|
||||
});
|
||||
|
||||
it('converts collaboration hints to lowercase', async () => {
|
||||
const user = userEvent.setup();
|
||||
render(<AgentTypeForm {...defaultProps} />);
|
||||
|
||||
const hintInput = screen.getByPlaceholderText(/e.g., backend-engineer/i);
|
||||
await user.type(hintInput, 'DevOps-Engineer');
|
||||
const addButtons = screen.getAllByRole('button', { name: /^add$/i });
|
||||
await user.click(addButtons[2]);
|
||||
|
||||
expect(screen.getByText('devops-engineer')).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it('does not add duplicate collaboration hints', async () => {
|
||||
const user = userEvent.setup();
|
||||
render(<AgentTypeForm {...defaultProps} agentType={mockAgentType} />);
|
||||
|
||||
// Agent type already has 'backend-engineer'
|
||||
const hintInput = screen.getByPlaceholderText(/e.g., backend-engineer/i);
|
||||
await user.type(hintInput, 'backend-engineer');
|
||||
const addButtons = screen.getAllByRole('button', { name: /^add$/i });
|
||||
await user.click(addButtons[2]);
|
||||
|
||||
// Should still only have one badge
|
||||
const badges = screen.getAllByText('backend-engineer');
|
||||
expect(badges).toHaveLength(1);
|
||||
});
|
||||
|
||||
it('does not add empty collaboration hint', async () => {
|
||||
const user = userEvent.setup();
|
||||
render(<AgentTypeForm {...defaultProps} />);
|
||||
|
||||
// Click the third "Add" button (for collaboration hints) without typing
|
||||
const addButtons = screen.getAllByRole('button', { name: /^add$/i });
|
||||
await user.click(addButtons[2]);
|
||||
|
||||
// No badges should be added
|
||||
expect(
|
||||
screen.queryByRole('button', { name: /remove devops-engineer/i })
|
||||
).not.toBeInTheDocument();
|
||||
});
|
||||
|
||||
it('clears input after adding collaboration hint', async () => {
|
||||
const user = userEvent.setup();
|
||||
render(<AgentTypeForm {...defaultProps} />);
|
||||
|
||||
const hintInput = screen.getByPlaceholderText(/e.g., backend-engineer/i) as HTMLInputElement;
|
||||
await user.type(hintInput, 'devops-engineer');
|
||||
const addButtons = screen.getAllByRole('button', { name: /^add$/i });
|
||||
await user.click(addButtons[2]);
|
||||
|
||||
expect(hintInput.value).toBe('');
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
@@ -17,6 +17,13 @@ const mockAgentTypes: AgentTypeResponse[] = [
|
||||
mcp_servers: ['gitea', 'knowledge'],
|
||||
tool_permissions: {},
|
||||
is_active: true,
|
||||
// Category and display fields
|
||||
category: 'development',
|
||||
icon: 'clipboard-check',
|
||||
color: '#3B82F6',
|
||||
sort_order: 10,
|
||||
typical_tasks: ['Manage backlog', 'Write user stories'],
|
||||
collaboration_hints: ['business-analyst', 'scrum-master'],
|
||||
created_at: '2025-01-15T00:00:00Z',
|
||||
updated_at: '2025-01-20T00:00:00Z',
|
||||
instance_count: 3,
|
||||
@@ -34,6 +41,13 @@ const mockAgentTypes: AgentTypeResponse[] = [
|
||||
mcp_servers: ['gitea'],
|
||||
tool_permissions: {},
|
||||
is_active: false,
|
||||
// Category and display fields
|
||||
category: 'development',
|
||||
icon: 'git-branch',
|
||||
color: '#3B82F6',
|
||||
sort_order: 40,
|
||||
typical_tasks: ['Design architecture', 'Create ADRs'],
|
||||
collaboration_hints: ['backend-engineer', 'devops-engineer'],
|
||||
created_at: '2025-01-10T00:00:00Z',
|
||||
updated_at: '2025-01-18T00:00:00Z',
|
||||
instance_count: 0,
|
||||
@@ -48,6 +62,10 @@ describe('AgentTypeList', () => {
|
||||
onSearchChange: jest.fn(),
|
||||
statusFilter: 'all',
|
||||
onStatusFilterChange: jest.fn(),
|
||||
categoryFilter: 'all',
|
||||
onCategoryFilterChange: jest.fn(),
|
||||
viewMode: 'grid' as const,
|
||||
onViewModeChange: jest.fn(),
|
||||
onSelect: jest.fn(),
|
||||
onCreate: jest.fn(),
|
||||
};
|
||||
@@ -194,4 +212,158 @@ describe('AgentTypeList', () => {
|
||||
const { container } = render(<AgentTypeList {...defaultProps} className="custom-class" />);
|
||||
expect(container.firstChild).toHaveClass('custom-class');
|
||||
});
|
||||
|
||||
describe('Category Filter', () => {
|
||||
it('renders category filter dropdown', () => {
|
||||
render(<AgentTypeList {...defaultProps} />);
|
||||
expect(screen.getByRole('combobox', { name: /filter by category/i })).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it('shows "All Categories" as default option', () => {
|
||||
render(<AgentTypeList {...defaultProps} categoryFilter="all" />);
|
||||
expect(screen.getByText('All Categories')).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it('displays category badge on agent cards', () => {
|
||||
render(<AgentTypeList {...defaultProps} />);
|
||||
// Both agents have 'development' category
|
||||
const developmentBadges = screen.getAllByText('Development');
|
||||
expect(developmentBadges.length).toBe(2);
|
||||
});
|
||||
|
||||
it('shows filter hint in empty state when category filter is applied', () => {
|
||||
render(<AgentTypeList {...defaultProps} agentTypes={[]} categoryFilter="design" />);
|
||||
expect(screen.getByText('Try adjusting your search or filters')).toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
|
||||
describe('View Mode Toggle', () => {
|
||||
it('renders view mode toggle buttons', () => {
|
||||
render(<AgentTypeList {...defaultProps} />);
|
||||
expect(screen.getByRole('radio', { name: /grid view/i })).toBeInTheDocument();
|
||||
expect(screen.getByRole('radio', { name: /list view/i })).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it('renders grid view by default', () => {
|
||||
const { container } = render(<AgentTypeList {...defaultProps} viewMode="grid" />);
|
||||
// Grid view uses CSS grid
|
||||
expect(container.querySelector('.grid')).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it('renders list view when viewMode is list', () => {
|
||||
const { container } = render(<AgentTypeList {...defaultProps} viewMode="list" />);
|
||||
// List view uses space-y-3 for vertical stacking
|
||||
expect(container.querySelector('.space-y-3')).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it('calls onViewModeChange when grid toggle is clicked', async () => {
|
||||
const user = userEvent.setup();
|
||||
const onViewModeChange = jest.fn();
|
||||
render(
|
||||
<AgentTypeList {...defaultProps} viewMode="list" onViewModeChange={onViewModeChange} />
|
||||
);
|
||||
|
||||
await user.click(screen.getByRole('radio', { name: /grid view/i }));
|
||||
expect(onViewModeChange).toHaveBeenCalledWith('grid');
|
||||
});
|
||||
|
||||
it('calls onViewModeChange when list toggle is clicked', async () => {
|
||||
const user = userEvent.setup();
|
||||
const onViewModeChange = jest.fn();
|
||||
render(
|
||||
<AgentTypeList {...defaultProps} viewMode="grid" onViewModeChange={onViewModeChange} />
|
||||
);
|
||||
|
||||
await user.click(screen.getByRole('radio', { name: /list view/i }));
|
||||
expect(onViewModeChange).toHaveBeenCalledWith('list');
|
||||
});
|
||||
|
||||
it('shows list-specific loading skeletons when viewMode is list', () => {
|
||||
const { container } = render(
|
||||
<AgentTypeList {...defaultProps} agentTypes={[]} isLoading={true} viewMode="list" />
|
||||
);
|
||||
expect(container.querySelectorAll('.animate-pulse').length).toBeGreaterThan(0);
|
||||
});
|
||||
});
|
||||
|
||||
describe('List View', () => {
|
||||
it('shows agent info in list rows', () => {
|
||||
render(<AgentTypeList {...defaultProps} viewMode="list" />);
|
||||
expect(screen.getByText('Product Owner')).toBeInTheDocument();
|
||||
expect(screen.getByText('Software Architect')).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it('shows category badge in list view', () => {
|
||||
render(<AgentTypeList {...defaultProps} viewMode="list" />);
|
||||
const developmentBadges = screen.getAllByText('Development');
|
||||
expect(developmentBadges.length).toBe(2);
|
||||
});
|
||||
|
||||
it('shows expertise count in list view', () => {
|
||||
render(<AgentTypeList {...defaultProps} viewMode="list" />);
|
||||
// Both agents have 3 expertise areas
|
||||
const expertiseTexts = screen.getAllByText('3 expertise areas');
|
||||
expect(expertiseTexts.length).toBe(2);
|
||||
});
|
||||
|
||||
it('calls onSelect when list row is clicked', async () => {
|
||||
const user = userEvent.setup();
|
||||
const onSelect = jest.fn();
|
||||
render(<AgentTypeList {...defaultProps} viewMode="list" onSelect={onSelect} />);
|
||||
|
||||
await user.click(screen.getByText('Product Owner'));
|
||||
expect(onSelect).toHaveBeenCalledWith('type-001');
|
||||
});
|
||||
|
||||
it('supports keyboard navigation on list rows', async () => {
|
||||
const user = userEvent.setup();
|
||||
const onSelect = jest.fn();
|
||||
render(<AgentTypeList {...defaultProps} viewMode="list" onSelect={onSelect} />);
|
||||
|
||||
const rows = screen.getAllByRole('button', { name: /view .* agent type/i });
|
||||
rows[0].focus();
|
||||
await user.keyboard('{Enter}');
|
||||
expect(onSelect).toHaveBeenCalledWith('type-001');
|
||||
});
|
||||
});
|
||||
|
||||
describe('Dynamic Icons', () => {
|
||||
it('renders agent icon in grid view', () => {
|
||||
const { container } = render(<AgentTypeList {...defaultProps} viewMode="grid" />);
|
||||
// Check for svg icons with lucide classes
|
||||
const icons = container.querySelectorAll('svg.lucide-clipboard-check, svg.lucide-git-branch');
|
||||
expect(icons.length).toBeGreaterThan(0);
|
||||
});
|
||||
|
||||
it('renders agent icon in list view', () => {
|
||||
const { container } = render(<AgentTypeList {...defaultProps} viewMode="list" />);
|
||||
const icons = container.querySelectorAll('svg.lucide-clipboard-check, svg.lucide-git-branch');
|
||||
expect(icons.length).toBeGreaterThan(0);
|
||||
});
|
||||
});
|
||||
|
||||
describe('Color Accent', () => {
|
||||
it('applies color to card border in grid view', () => {
|
||||
const { container } = render(<AgentTypeList {...defaultProps} viewMode="grid" />);
|
||||
const card = container.querySelector('[style*="border-top-color"]');
|
||||
expect(card).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it('applies color to row border in list view', () => {
|
||||
const { container } = render(<AgentTypeList {...defaultProps} viewMode="list" />);
|
||||
const row = container.querySelector('[style*="border-left-color"]');
|
||||
expect(row).toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
|
||||
describe('Category Badge Component', () => {
|
||||
it('does not render category badge when category is null', () => {
|
||||
const agentWithNoCategory: AgentTypeResponse = {
|
||||
...mockAgentTypes[0],
|
||||
category: null,
|
||||
};
|
||||
render(<AgentTypeList {...defaultProps} agentTypes={[agentWithNoCategory]} />);
|
||||
expect(screen.queryByText('Development')).not.toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
448
frontend/tests/components/forms/FormSelect.test.tsx
Normal file
448
frontend/tests/components/forms/FormSelect.test.tsx
Normal file
@@ -0,0 +1,448 @@
|
||||
/**
|
||||
* Tests for FormSelect Component
|
||||
* Verifies select field rendering, accessibility, and error handling
|
||||
*/
|
||||
|
||||
import React from 'react';
|
||||
import { render, screen, fireEvent, waitFor } from '@testing-library/react';
|
||||
import { useForm, FormProvider } from 'react-hook-form';
|
||||
import { FormSelect, type SelectOption } from '@/components/forms/FormSelect';
|
||||
|
||||
// Polyfill for Radix UI Select - jsdom doesn't support these browser APIs
|
||||
beforeAll(() => {
|
||||
Element.prototype.hasPointerCapture = jest.fn(() => false);
|
||||
Element.prototype.setPointerCapture = jest.fn();
|
||||
Element.prototype.releasePointerCapture = jest.fn();
|
||||
Element.prototype.scrollIntoView = jest.fn();
|
||||
window.HTMLElement.prototype.scrollIntoView = jest.fn();
|
||||
});
|
||||
|
||||
// Helper wrapper component to provide form context
|
||||
interface TestFormValues {
|
||||
model: string;
|
||||
category: string;
|
||||
}
|
||||
|
||||
function TestWrapper({
|
||||
children,
|
||||
defaultValues = { model: '', category: '' },
|
||||
}: {
|
||||
children: (props: {
|
||||
control: ReturnType<typeof useForm<TestFormValues>>['control'];
|
||||
}) => React.ReactNode;
|
||||
defaultValues?: Partial<TestFormValues>;
|
||||
}) {
|
||||
const form = useForm<TestFormValues>({
|
||||
defaultValues: { model: '', category: '', ...defaultValues },
|
||||
});
|
||||
|
||||
return <FormProvider {...form}>{children({ control: form.control })}</FormProvider>;
|
||||
}
|
||||
|
||||
const mockOptions: SelectOption[] = [
|
||||
{ value: 'claude-opus', label: 'Claude Opus' },
|
||||
{ value: 'claude-sonnet', label: 'Claude Sonnet' },
|
||||
{ value: 'claude-haiku', label: 'Claude Haiku' },
|
||||
];
|
||||
|
||||
describe('FormSelect', () => {
|
||||
describe('Basic Rendering', () => {
|
||||
it('renders with label and select trigger', () => {
|
||||
render(
|
||||
<TestWrapper>
|
||||
{({ control }) => (
|
||||
<FormSelect
|
||||
name="model"
|
||||
control={control}
|
||||
label="Primary Model"
|
||||
options={mockOptions}
|
||||
/>
|
||||
)}
|
||||
</TestWrapper>
|
||||
);
|
||||
|
||||
expect(screen.getByText('Primary Model')).toBeInTheDocument();
|
||||
expect(screen.getByRole('combobox')).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it('renders with description', () => {
|
||||
render(
|
||||
<TestWrapper>
|
||||
{({ control }) => (
|
||||
<FormSelect
|
||||
name="model"
|
||||
control={control}
|
||||
label="Primary Model"
|
||||
options={mockOptions}
|
||||
description="Main model used for this agent"
|
||||
/>
|
||||
)}
|
||||
</TestWrapper>
|
||||
);
|
||||
|
||||
expect(screen.getByText('Main model used for this agent')).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it('renders with custom placeholder', () => {
|
||||
render(
|
||||
<TestWrapper>
|
||||
{({ control }) => (
|
||||
<FormSelect
|
||||
name="model"
|
||||
control={control}
|
||||
label="Primary Model"
|
||||
options={mockOptions}
|
||||
placeholder="Choose a model"
|
||||
/>
|
||||
)}
|
||||
</TestWrapper>
|
||||
);
|
||||
|
||||
expect(screen.getByText('Choose a model')).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it('renders default placeholder when none provided', () => {
|
||||
render(
|
||||
<TestWrapper>
|
||||
{({ control }) => (
|
||||
<FormSelect
|
||||
name="model"
|
||||
control={control}
|
||||
label="Primary Model"
|
||||
options={mockOptions}
|
||||
/>
|
||||
)}
|
||||
</TestWrapper>
|
||||
);
|
||||
|
||||
expect(screen.getByText('Select primary model')).toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
|
||||
describe('Required Field', () => {
|
||||
it('shows asterisk when required is true', () => {
|
||||
render(
|
||||
<TestWrapper>
|
||||
{({ control }) => (
|
||||
<FormSelect
|
||||
name="model"
|
||||
control={control}
|
||||
label="Primary Model"
|
||||
options={mockOptions}
|
||||
required
|
||||
/>
|
||||
)}
|
||||
</TestWrapper>
|
||||
);
|
||||
|
||||
expect(screen.getByText('*')).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it('does not show asterisk when required is false', () => {
|
||||
render(
|
||||
<TestWrapper>
|
||||
{({ control }) => (
|
||||
<FormSelect
|
||||
name="model"
|
||||
control={control}
|
||||
label="Primary Model"
|
||||
options={mockOptions}
|
||||
required={false}
|
||||
/>
|
||||
)}
|
||||
</TestWrapper>
|
||||
);
|
||||
|
||||
expect(screen.queryByText('*')).not.toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
|
||||
describe('Options Rendering', () => {
|
||||
it('renders all options when opened', async () => {
|
||||
render(
|
||||
<TestWrapper>
|
||||
{({ control }) => (
|
||||
<FormSelect
|
||||
name="model"
|
||||
control={control}
|
||||
label="Primary Model"
|
||||
options={mockOptions}
|
||||
/>
|
||||
)}
|
||||
</TestWrapper>
|
||||
);
|
||||
|
||||
// Open the select using fireEvent (works better with Radix UI)
|
||||
fireEvent.click(screen.getByRole('combobox'));
|
||||
|
||||
// Check all options are rendered
|
||||
await waitFor(() => {
|
||||
expect(screen.getByRole('option', { name: 'Claude Opus' })).toBeInTheDocument();
|
||||
});
|
||||
expect(screen.getByRole('option', { name: 'Claude Sonnet' })).toBeInTheDocument();
|
||||
expect(screen.getByRole('option', { name: 'Claude Haiku' })).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it('selects option when clicked', async () => {
|
||||
render(
|
||||
<TestWrapper>
|
||||
{({ control }) => (
|
||||
<FormSelect
|
||||
name="model"
|
||||
control={control}
|
||||
label="Primary Model"
|
||||
options={mockOptions}
|
||||
/>
|
||||
)}
|
||||
</TestWrapper>
|
||||
);
|
||||
|
||||
// Open the select and choose an option
|
||||
fireEvent.click(screen.getByRole('combobox'));
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.getByRole('option', { name: 'Claude Sonnet' })).toBeInTheDocument();
|
||||
});
|
||||
|
||||
fireEvent.click(screen.getByRole('option', { name: 'Claude Sonnet' }));
|
||||
|
||||
// The selected value should now be displayed
|
||||
await waitFor(() => {
|
||||
expect(screen.getByRole('combobox')).toHaveTextContent('Claude Sonnet');
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
describe('Disabled State', () => {
|
||||
it('disables select when disabled prop is true', () => {
|
||||
render(
|
||||
<TestWrapper>
|
||||
{({ control }) => (
|
||||
<FormSelect
|
||||
name="model"
|
||||
control={control}
|
||||
label="Primary Model"
|
||||
options={mockOptions}
|
||||
disabled
|
||||
/>
|
||||
)}
|
||||
</TestWrapper>
|
||||
);
|
||||
|
||||
expect(screen.getByRole('combobox')).toBeDisabled();
|
||||
});
|
||||
|
||||
it('enables select when disabled prop is false', () => {
|
||||
render(
|
||||
<TestWrapper>
|
||||
{({ control }) => (
|
||||
<FormSelect
|
||||
name="model"
|
||||
control={control}
|
||||
label="Primary Model"
|
||||
options={mockOptions}
|
||||
disabled={false}
|
||||
/>
|
||||
)}
|
||||
</TestWrapper>
|
||||
);
|
||||
|
||||
expect(screen.getByRole('combobox')).not.toBeDisabled();
|
||||
});
|
||||
});
|
||||
|
||||
describe('Pre-selected Value', () => {
|
||||
it('displays pre-selected value', () => {
|
||||
render(
|
||||
<TestWrapper defaultValues={{ model: 'claude-opus' }}>
|
||||
{({ control }) => (
|
||||
<FormSelect
|
||||
name="model"
|
||||
control={control}
|
||||
label="Primary Model"
|
||||
options={mockOptions}
|
||||
/>
|
||||
)}
|
||||
</TestWrapper>
|
||||
);
|
||||
|
||||
expect(screen.getByRole('combobox')).toHaveTextContent('Claude Opus');
|
||||
});
|
||||
});
|
||||
|
||||
describe('Accessibility', () => {
|
||||
it('links label to select via htmlFor/id', () => {
|
||||
render(
|
||||
<TestWrapper>
|
||||
{({ control }) => (
|
||||
<FormSelect
|
||||
name="model"
|
||||
control={control}
|
||||
label="Primary Model"
|
||||
options={mockOptions}
|
||||
/>
|
||||
)}
|
||||
</TestWrapper>
|
||||
);
|
||||
|
||||
const label = screen.getByText('Primary Model');
|
||||
const select = screen.getByRole('combobox');
|
||||
|
||||
expect(label).toHaveAttribute('for', 'model');
|
||||
expect(select).toHaveAttribute('id', 'model');
|
||||
});
|
||||
|
||||
it('sets aria-describedby with description ID when description exists', () => {
|
||||
render(
|
||||
<TestWrapper>
|
||||
{({ control }) => (
|
||||
<FormSelect
|
||||
name="model"
|
||||
control={control}
|
||||
label="Primary Model"
|
||||
options={mockOptions}
|
||||
description="Choose the main model"
|
||||
/>
|
||||
)}
|
||||
</TestWrapper>
|
||||
);
|
||||
|
||||
const select = screen.getByRole('combobox');
|
||||
expect(select).toHaveAttribute('aria-describedby', 'model-description');
|
||||
});
|
||||
});
|
||||
|
||||
describe('Custom ClassName', () => {
|
||||
it('applies custom className to wrapper', () => {
|
||||
const { container } = render(
|
||||
<TestWrapper>
|
||||
{({ control }) => (
|
||||
<FormSelect
|
||||
name="model"
|
||||
control={control}
|
||||
label="Primary Model"
|
||||
options={mockOptions}
|
||||
className="custom-class"
|
||||
/>
|
||||
)}
|
||||
</TestWrapper>
|
||||
);
|
||||
|
||||
expect(container.querySelector('.custom-class')).toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
|
||||
describe('Error Handling', () => {
|
||||
it('displays error message when field has error', () => {
|
||||
function TestComponent() {
|
||||
const form = useForm<TestFormValues>({
|
||||
defaultValues: { model: '', category: '' },
|
||||
});
|
||||
|
||||
React.useEffect(() => {
|
||||
form.setError('model', { type: 'required', message: 'Model is required' });
|
||||
}, [form]);
|
||||
|
||||
return (
|
||||
<FormProvider {...form}>
|
||||
<FormSelect
|
||||
name="model"
|
||||
control={form.control}
|
||||
label="Primary Model"
|
||||
options={mockOptions}
|
||||
/>
|
||||
</FormProvider>
|
||||
);
|
||||
}
|
||||
|
||||
render(<TestComponent />);
|
||||
|
||||
expect(screen.getByRole('alert')).toHaveTextContent('Model is required');
|
||||
});
|
||||
|
||||
it('sets aria-invalid when error exists', () => {
|
||||
function TestComponent() {
|
||||
const form = useForm<TestFormValues>({
|
||||
defaultValues: { model: '', category: '' },
|
||||
});
|
||||
|
||||
React.useEffect(() => {
|
||||
form.setError('model', { type: 'required', message: 'Model is required' });
|
||||
}, [form]);
|
||||
|
||||
return (
|
||||
<FormProvider {...form}>
|
||||
<FormSelect
|
||||
name="model"
|
||||
control={form.control}
|
||||
label="Primary Model"
|
||||
options={mockOptions}
|
||||
/>
|
||||
</FormProvider>
|
||||
);
|
||||
}
|
||||
|
||||
render(<TestComponent />);
|
||||
|
||||
expect(screen.getByRole('combobox')).toHaveAttribute('aria-invalid', 'true');
|
||||
});
|
||||
|
||||
it('sets aria-describedby with error ID when error exists', () => {
|
||||
function TestComponent() {
|
||||
const form = useForm<TestFormValues>({
|
||||
defaultValues: { model: '', category: '' },
|
||||
});
|
||||
|
||||
React.useEffect(() => {
|
||||
form.setError('model', { type: 'required', message: 'Model is required' });
|
||||
}, [form]);
|
||||
|
||||
return (
|
||||
<FormProvider {...form}>
|
||||
<FormSelect
|
||||
name="model"
|
||||
control={form.control}
|
||||
label="Primary Model"
|
||||
options={mockOptions}
|
||||
/>
|
||||
</FormProvider>
|
||||
);
|
||||
}
|
||||
|
||||
render(<TestComponent />);
|
||||
|
||||
expect(screen.getByRole('combobox')).toHaveAttribute('aria-describedby', 'model-error');
|
||||
});
|
||||
|
||||
it('combines error and description IDs in aria-describedby', () => {
|
||||
function TestComponent() {
|
||||
const form = useForm<TestFormValues>({
|
||||
defaultValues: { model: '', category: '' },
|
||||
});
|
||||
|
||||
React.useEffect(() => {
|
||||
form.setError('model', { type: 'required', message: 'Model is required' });
|
||||
}, [form]);
|
||||
|
||||
return (
|
||||
<FormProvider {...form}>
|
||||
<FormSelect
|
||||
name="model"
|
||||
control={form.control}
|
||||
label="Primary Model"
|
||||
options={mockOptions}
|
||||
description="Choose the main model"
|
||||
/>
|
||||
</FormProvider>
|
||||
);
|
||||
}
|
||||
|
||||
render(<TestComponent />);
|
||||
|
||||
expect(screen.getByRole('combobox')).toHaveAttribute(
|
||||
'aria-describedby',
|
||||
'model-error model-description'
|
||||
);
|
||||
});
|
||||
});
|
||||
});
|
||||
281
frontend/tests/components/forms/FormTextarea.test.tsx
Normal file
281
frontend/tests/components/forms/FormTextarea.test.tsx
Normal file
@@ -0,0 +1,281 @@
|
||||
/**
|
||||
* Tests for FormTextarea Component
|
||||
* Verifies textarea field rendering, accessibility, and error handling
|
||||
*/
|
||||
|
||||
import { render, screen } from '@testing-library/react';
|
||||
import { FormTextarea } from '@/components/forms/FormTextarea';
|
||||
import type { FieldError } from 'react-hook-form';
|
||||
|
||||
describe('FormTextarea', () => {
|
||||
describe('Basic Rendering', () => {
|
||||
it('renders with label and textarea', () => {
|
||||
render(<FormTextarea label="Description" name="description" />);
|
||||
|
||||
expect(screen.getByLabelText('Description')).toBeInTheDocument();
|
||||
expect(screen.getByRole('textbox')).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it('renders with description', () => {
|
||||
render(
|
||||
<FormTextarea
|
||||
label="Personality Prompt"
|
||||
name="personality"
|
||||
description="Define the agent's personality and behavior"
|
||||
/>
|
||||
);
|
||||
|
||||
expect(screen.getByText("Define the agent's personality and behavior")).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it('renders description before textarea', () => {
|
||||
const { container } = render(
|
||||
<FormTextarea label="Description" name="description" description="Helper text" />
|
||||
);
|
||||
|
||||
const description = container.querySelector('#description-description');
|
||||
const textarea = container.querySelector('textarea');
|
||||
|
||||
// Get positions
|
||||
const descriptionRect = description?.getBoundingClientRect();
|
||||
const textareaRect = textarea?.getBoundingClientRect();
|
||||
|
||||
// Description should appear (both should exist)
|
||||
expect(description).toBeInTheDocument();
|
||||
expect(textarea).toBeInTheDocument();
|
||||
|
||||
// In the DOM order, description comes before textarea
|
||||
expect(descriptionRect).toBeDefined();
|
||||
expect(textareaRect).toBeDefined();
|
||||
});
|
||||
});
|
||||
|
||||
describe('Required Field', () => {
|
||||
it('shows asterisk when required is true', () => {
|
||||
render(<FormTextarea label="Description" name="description" required />);
|
||||
|
||||
expect(screen.getByText('*')).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it('does not show asterisk when required is false', () => {
|
||||
render(<FormTextarea label="Description" name="description" required={false} />);
|
||||
|
||||
expect(screen.queryByText('*')).not.toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
|
||||
describe('Error Handling', () => {
|
||||
it('displays error message when error prop is provided', () => {
|
||||
const error: FieldError = {
|
||||
type: 'required',
|
||||
message: 'Description is required',
|
||||
};
|
||||
|
||||
render(<FormTextarea label="Description" name="description" error={error} />);
|
||||
|
||||
expect(screen.getByText('Description is required')).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it('sets aria-invalid when error exists', () => {
|
||||
const error: FieldError = {
|
||||
type: 'required',
|
||||
message: 'Description is required',
|
||||
};
|
||||
|
||||
render(<FormTextarea label="Description" name="description" error={error} />);
|
||||
|
||||
const textarea = screen.getByRole('textbox');
|
||||
expect(textarea).toHaveAttribute('aria-invalid', 'true');
|
||||
});
|
||||
|
||||
it('sets aria-describedby with error ID when error exists', () => {
|
||||
const error: FieldError = {
|
||||
type: 'required',
|
||||
message: 'Description is required',
|
||||
};
|
||||
|
||||
render(<FormTextarea label="Description" name="description" error={error} />);
|
||||
|
||||
const textarea = screen.getByRole('textbox');
|
||||
expect(textarea).toHaveAttribute('aria-describedby', 'description-error');
|
||||
});
|
||||
|
||||
it('renders error with role="alert"', () => {
|
||||
const error: FieldError = {
|
||||
type: 'required',
|
||||
message: 'Description is required',
|
||||
};
|
||||
|
||||
render(<FormTextarea label="Description" name="description" error={error} />);
|
||||
|
||||
const errorElement = screen.getByRole('alert');
|
||||
expect(errorElement).toHaveTextContent('Description is required');
|
||||
});
|
||||
});
|
||||
|
||||
describe('Accessibility', () => {
|
||||
it('links label to textarea via htmlFor/id', () => {
|
||||
render(<FormTextarea label="Description" name="description" />);
|
||||
|
||||
const label = screen.getByText('Description');
|
||||
const textarea = screen.getByRole('textbox');
|
||||
|
||||
expect(label).toHaveAttribute('for', 'description');
|
||||
expect(textarea).toHaveAttribute('id', 'description');
|
||||
});
|
||||
|
||||
it('sets aria-describedby with description ID when description exists', () => {
|
||||
render(
|
||||
<FormTextarea
|
||||
label="Description"
|
||||
name="description"
|
||||
description="Enter a detailed description"
|
||||
/>
|
||||
);
|
||||
|
||||
const textarea = screen.getByRole('textbox');
|
||||
expect(textarea).toHaveAttribute('aria-describedby', 'description-description');
|
||||
});
|
||||
|
||||
it('combines error and description IDs in aria-describedby', () => {
|
||||
const error: FieldError = {
|
||||
type: 'required',
|
||||
message: 'Description is required',
|
||||
};
|
||||
|
||||
render(
|
||||
<FormTextarea
|
||||
label="Description"
|
||||
name="description"
|
||||
description="Enter a detailed description"
|
||||
error={error}
|
||||
/>
|
||||
);
|
||||
|
||||
const textarea = screen.getByRole('textbox');
|
||||
expect(textarea).toHaveAttribute(
|
||||
'aria-describedby',
|
||||
'description-error description-description'
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
describe('Textarea Props Forwarding', () => {
|
||||
it('forwards textarea props correctly', () => {
|
||||
render(
|
||||
<FormTextarea
|
||||
label="Description"
|
||||
name="description"
|
||||
placeholder="Enter description"
|
||||
rows={5}
|
||||
disabled
|
||||
/>
|
||||
);
|
||||
|
||||
const textarea = screen.getByRole('textbox');
|
||||
expect(textarea).toHaveAttribute('placeholder', 'Enter description');
|
||||
expect(textarea).toHaveAttribute('rows', '5');
|
||||
expect(textarea).toBeDisabled();
|
||||
});
|
||||
|
||||
it('accepts register() props via registration', () => {
|
||||
const registerProps = {
|
||||
name: 'description',
|
||||
onChange: jest.fn(),
|
||||
onBlur: jest.fn(),
|
||||
ref: jest.fn(),
|
||||
};
|
||||
|
||||
render(<FormTextarea label="Description" registration={registerProps} />);
|
||||
|
||||
const textarea = screen.getByRole('textbox');
|
||||
expect(textarea).toBeInTheDocument();
|
||||
expect(textarea).toHaveAttribute('id', 'description');
|
||||
});
|
||||
|
||||
it('extracts name from spread props', () => {
|
||||
const spreadProps = {
|
||||
name: 'content',
|
||||
onChange: jest.fn(),
|
||||
};
|
||||
|
||||
render(<FormTextarea label="Content" {...spreadProps} />);
|
||||
|
||||
const textarea = screen.getByRole('textbox');
|
||||
expect(textarea).toHaveAttribute('id', 'content');
|
||||
});
|
||||
});
|
||||
|
||||
describe('Error Cases', () => {
|
||||
it('throws error when name is not provided', () => {
|
||||
// Suppress console.error for this test
|
||||
const consoleError = jest.spyOn(console, 'error').mockImplementation(() => {});
|
||||
|
||||
expect(() => {
|
||||
render(<FormTextarea label="Description" />);
|
||||
}).toThrow('FormTextarea: name must be provided either explicitly or via register()');
|
||||
|
||||
consoleError.mockRestore();
|
||||
});
|
||||
});
|
||||
|
||||
describe('Layout and Styling', () => {
|
||||
it('applies correct spacing classes', () => {
|
||||
const { container } = render(<FormTextarea label="Description" name="description" />);
|
||||
|
||||
const wrapper = container.firstChild as HTMLElement;
|
||||
expect(wrapper).toHaveClass('space-y-2');
|
||||
});
|
||||
|
||||
it('applies correct error styling', () => {
|
||||
const error: FieldError = {
|
||||
type: 'required',
|
||||
message: 'Description is required',
|
||||
};
|
||||
|
||||
render(<FormTextarea label="Description" name="description" error={error} />);
|
||||
|
||||
const errorElement = screen.getByRole('alert');
|
||||
expect(errorElement).toHaveClass('text-sm', 'text-destructive');
|
||||
});
|
||||
|
||||
it('applies correct description styling', () => {
|
||||
const { container } = render(
|
||||
<FormTextarea label="Description" name="description" description="Helper text" />
|
||||
);
|
||||
|
||||
const description = container.querySelector('#description-description');
|
||||
expect(description).toHaveClass('text-sm', 'text-muted-foreground');
|
||||
});
|
||||
});
|
||||
|
||||
describe('Name Priority', () => {
|
||||
it('uses explicit name over registration name', () => {
|
||||
const registerProps = {
|
||||
name: 'fromRegister',
|
||||
onChange: jest.fn(),
|
||||
onBlur: jest.fn(),
|
||||
ref: jest.fn(),
|
||||
};
|
||||
|
||||
render(<FormTextarea label="Content" name="explicit" registration={registerProps} />);
|
||||
|
||||
const textarea = screen.getByRole('textbox');
|
||||
expect(textarea).toHaveAttribute('id', 'explicit');
|
||||
});
|
||||
|
||||
it('uses registration name when explicit name not provided', () => {
|
||||
const registerProps = {
|
||||
name: 'fromRegister',
|
||||
onChange: jest.fn(),
|
||||
onBlur: jest.fn(),
|
||||
ref: jest.fn(),
|
||||
};
|
||||
|
||||
render(<FormTextarea label="Content" registration={registerProps} />);
|
||||
|
||||
const textarea = screen.getByRole('textbox');
|
||||
expect(textarea).toHaveAttribute('id', 'fromRegister');
|
||||
});
|
||||
});
|
||||
});
|
||||
158
frontend/tests/components/ui/DynamicIcon.test.tsx
Normal file
158
frontend/tests/components/ui/DynamicIcon.test.tsx
Normal file
@@ -0,0 +1,158 @@
|
||||
/**
|
||||
* Tests for DynamicIcon Component
|
||||
* Verifies dynamic icon rendering by name string
|
||||
*/
|
||||
|
||||
import { render, screen } from '@testing-library/react';
|
||||
import { DynamicIcon, getAvailableIconNames } from '@/components/ui/dynamic-icon';
|
||||
|
||||
describe('DynamicIcon', () => {
|
||||
describe('Basic Rendering', () => {
|
||||
it('renders an icon by name', () => {
|
||||
render(<DynamicIcon name="bot" data-testid="icon" />);
|
||||
const icon = screen.getByTestId('icon');
|
||||
expect(icon).toBeInTheDocument();
|
||||
expect(icon.tagName).toBe('svg');
|
||||
});
|
||||
|
||||
it('renders different icons by name', () => {
|
||||
const { rerender } = render(<DynamicIcon name="code" data-testid="icon" />);
|
||||
expect(screen.getByTestId('icon')).toHaveClass('lucide-code');
|
||||
|
||||
rerender(<DynamicIcon name="brain" data-testid="icon" />);
|
||||
expect(screen.getByTestId('icon')).toHaveClass('lucide-brain');
|
||||
|
||||
rerender(<DynamicIcon name="shield" data-testid="icon" />);
|
||||
expect(screen.getByTestId('icon')).toHaveClass('lucide-shield');
|
||||
});
|
||||
|
||||
it('renders kebab-case icon names correctly', () => {
|
||||
render(<DynamicIcon name="clipboard-check" data-testid="icon" />);
|
||||
expect(screen.getByTestId('icon')).toHaveClass('lucide-clipboard-check');
|
||||
});
|
||||
});
|
||||
|
||||
describe('Fallback Behavior', () => {
|
||||
it('renders fallback icon when name is null', () => {
|
||||
render(<DynamicIcon name={null} data-testid="icon" />);
|
||||
expect(screen.getByTestId('icon')).toHaveClass('lucide-bot');
|
||||
});
|
||||
|
||||
it('renders fallback icon when name is undefined', () => {
|
||||
render(<DynamicIcon name={undefined} data-testid="icon" />);
|
||||
expect(screen.getByTestId('icon')).toHaveClass('lucide-bot');
|
||||
});
|
||||
|
||||
it('renders fallback icon when name is not found', () => {
|
||||
render(<DynamicIcon name="nonexistent-icon" data-testid="icon" />);
|
||||
expect(screen.getByTestId('icon')).toHaveClass('lucide-bot');
|
||||
});
|
||||
|
||||
it('uses custom fallback when specified', () => {
|
||||
render(<DynamicIcon name={null} fallback="code" data-testid="icon" />);
|
||||
expect(screen.getByTestId('icon')).toHaveClass('lucide-code');
|
||||
});
|
||||
|
||||
it('falls back to bot when custom fallback is also invalid', () => {
|
||||
render(<DynamicIcon name="invalid" fallback="also-invalid" data-testid="icon" />);
|
||||
expect(screen.getByTestId('icon')).toHaveClass('lucide-bot');
|
||||
});
|
||||
});
|
||||
|
||||
describe('Props Forwarding', () => {
|
||||
it('forwards className to icon', () => {
|
||||
render(<DynamicIcon name="bot" className="h-5 w-5 text-primary" data-testid="icon" />);
|
||||
const icon = screen.getByTestId('icon');
|
||||
expect(icon).toHaveClass('h-5');
|
||||
expect(icon).toHaveClass('w-5');
|
||||
expect(icon).toHaveClass('text-primary');
|
||||
});
|
||||
|
||||
it('forwards style to icon', () => {
|
||||
render(<DynamicIcon name="bot" style={{ color: 'red' }} data-testid="icon" />);
|
||||
const icon = screen.getByTestId('icon');
|
||||
expect(icon).toHaveStyle({ color: 'rgb(255, 0, 0)' });
|
||||
});
|
||||
|
||||
it('forwards aria-hidden to icon', () => {
|
||||
render(<DynamicIcon name="bot" aria-hidden="true" data-testid="icon" />);
|
||||
const icon = screen.getByTestId('icon');
|
||||
expect(icon).toHaveAttribute('aria-hidden', 'true');
|
||||
});
|
||||
});
|
||||
|
||||
describe('Available Icons', () => {
|
||||
it('includes development icons', () => {
|
||||
const icons = getAvailableIconNames();
|
||||
expect(icons).toContain('clipboard-check');
|
||||
expect(icons).toContain('briefcase');
|
||||
expect(icons).toContain('code');
|
||||
expect(icons).toContain('server');
|
||||
});
|
||||
|
||||
it('includes design icons', () => {
|
||||
const icons = getAvailableIconNames();
|
||||
expect(icons).toContain('palette');
|
||||
expect(icons).toContain('search');
|
||||
});
|
||||
|
||||
it('includes quality icons', () => {
|
||||
const icons = getAvailableIconNames();
|
||||
expect(icons).toContain('shield');
|
||||
expect(icons).toContain('shield-check');
|
||||
});
|
||||
|
||||
it('includes ai_ml icons', () => {
|
||||
const icons = getAvailableIconNames();
|
||||
expect(icons).toContain('brain');
|
||||
expect(icons).toContain('microscope');
|
||||
expect(icons).toContain('eye');
|
||||
});
|
||||
|
||||
it('includes data icons', () => {
|
||||
const icons = getAvailableIconNames();
|
||||
expect(icons).toContain('bar-chart');
|
||||
expect(icons).toContain('database');
|
||||
});
|
||||
|
||||
it('includes domain expert icons', () => {
|
||||
const icons = getAvailableIconNames();
|
||||
expect(icons).toContain('calculator');
|
||||
expect(icons).toContain('heart-pulse');
|
||||
expect(icons).toContain('flask-conical');
|
||||
expect(icons).toContain('lightbulb');
|
||||
expect(icons).toContain('book-open');
|
||||
});
|
||||
|
||||
it('includes generic icons', () => {
|
||||
const icons = getAvailableIconNames();
|
||||
expect(icons).toContain('bot');
|
||||
expect(icons).toContain('cpu');
|
||||
});
|
||||
});
|
||||
|
||||
describe('Icon Categories Coverage', () => {
|
||||
const iconTestCases = [
|
||||
// Development
|
||||
{ name: 'clipboard-check', expectedClass: 'lucide-clipboard-check' },
|
||||
{ name: 'briefcase', expectedClass: 'lucide-briefcase' },
|
||||
{ name: 'file-text', expectedClass: 'lucide-file-text' },
|
||||
{ name: 'git-branch', expectedClass: 'lucide-git-branch' },
|
||||
{ name: 'layout', expectedClass: 'lucide-panels-top-left' },
|
||||
{ name: 'smartphone', expectedClass: 'lucide-smartphone' },
|
||||
// Operations
|
||||
{ name: 'settings', expectedClass: 'lucide-settings' },
|
||||
{ name: 'settings-2', expectedClass: 'lucide-settings-2' },
|
||||
// AI/ML
|
||||
{ name: 'message-square', expectedClass: 'lucide-message-square' },
|
||||
// Leadership
|
||||
{ name: 'users', expectedClass: 'lucide-users' },
|
||||
{ name: 'target', expectedClass: 'lucide-target' },
|
||||
];
|
||||
|
||||
it.each(iconTestCases)('renders $name icon correctly', ({ name, expectedClass }) => {
|
||||
render(<DynamicIcon name={name} data-testid="icon" />);
|
||||
expect(screen.getByTestId('icon')).toHaveClass(expectedClass);
|
||||
});
|
||||
});
|
||||
});
|
||||
@@ -27,6 +27,9 @@ jest.mock('@/config/app.config', () => ({
|
||||
debug: {
|
||||
api: false,
|
||||
},
|
||||
demo: {
|
||||
enabled: false,
|
||||
},
|
||||
},
|
||||
}));
|
||||
|
||||
@@ -649,6 +652,9 @@ describe('useProjectEvents', () => {
|
||||
debug: {
|
||||
api: true,
|
||||
},
|
||||
demo: {
|
||||
enabled: false,
|
||||
},
|
||||
},
|
||||
}));
|
||||
|
||||
|
||||
67
mcp-servers/git-ops/Dockerfile
Normal file
67
mcp-servers/git-ops/Dockerfile
Normal file
@@ -0,0 +1,67 @@
|
||||
# Git Operations MCP Server Dockerfile
|
||||
# Multi-stage build for smaller production image
|
||||
|
||||
FROM python:3.12-slim AS builder
|
||||
|
||||
# Install build dependencies
|
||||
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||
build-essential \
|
||||
git \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# Install uv for fast package management
|
||||
RUN pip install --no-cache-dir uv
|
||||
|
||||
# Create app directory
|
||||
WORKDIR /app
|
||||
|
||||
# Copy dependency files
|
||||
COPY pyproject.toml .
|
||||
|
||||
# Install dependencies with uv
|
||||
RUN uv pip install --system --no-cache .
|
||||
|
||||
# Production stage
|
||||
FROM python:3.12-slim
|
||||
|
||||
# Install runtime dependencies
|
||||
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||
git \
|
||||
openssh-client \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# Create non-root user
|
||||
RUN useradd --create-home --shell /bin/bash syndarix
|
||||
|
||||
# Create workspace directory
|
||||
RUN mkdir -p /var/syndarix/workspaces && chown -R syndarix:syndarix /var/syndarix
|
||||
|
||||
# Create app directory
|
||||
WORKDIR /app
|
||||
|
||||
# Copy installed packages from builder
|
||||
COPY --from=builder /usr/local/lib/python3.12/site-packages /usr/local/lib/python3.12/site-packages
|
||||
COPY --from=builder /usr/local/bin /usr/local/bin
|
||||
|
||||
# Copy application code
|
||||
COPY --chown=syndarix:syndarix . .
|
||||
|
||||
# Set Python path
|
||||
ENV PYTHONPATH=/app
|
||||
ENV PYTHONUNBUFFERED=1
|
||||
|
||||
# Configure git for the container
|
||||
RUN git config --global --add safe.directory '*'
|
||||
|
||||
# Switch to non-root user
|
||||
USER syndarix
|
||||
|
||||
# Expose port
|
||||
EXPOSE 8003
|
||||
|
||||
# Health check
|
||||
HEALTHCHECK --interval=30s --timeout=10s --start-period=5s --retries=3 \
|
||||
CMD python -c "import httpx; httpx.get('http://localhost:8003/health').raise_for_status()" || exit 1
|
||||
|
||||
# Run the server
|
||||
CMD ["python", "server.py"]
|
||||
88
mcp-servers/git-ops/Makefile
Normal file
88
mcp-servers/git-ops/Makefile
Normal file
@@ -0,0 +1,88 @@
|
||||
.PHONY: help install install-dev lint lint-fix format format-check type-check test test-cov validate clean run
|
||||
|
||||
# Ensure commands in this project don't inherit an external Python virtualenv
|
||||
# (prevents uv warnings about mismatched VIRTUAL_ENV when running from repo root)
|
||||
unexport VIRTUAL_ENV
|
||||
|
||||
# Default target
|
||||
help:
|
||||
@echo "Git Operations MCP Server - Development Commands"
|
||||
@echo ""
|
||||
@echo "Setup:"
|
||||
@echo " make install - Install production dependencies"
|
||||
@echo " make install-dev - Install development dependencies"
|
||||
@echo ""
|
||||
@echo "Quality Checks:"
|
||||
@echo " make lint - Run Ruff linter"
|
||||
@echo " make lint-fix - Run Ruff linter with auto-fix"
|
||||
@echo " make format - Format code with Ruff"
|
||||
@echo " make format-check - Check if code is formatted"
|
||||
@echo " make type-check - Run mypy type checker"
|
||||
@echo ""
|
||||
@echo "Testing:"
|
||||
@echo " make test - Run pytest"
|
||||
@echo " make test-cov - Run pytest with coverage"
|
||||
@echo ""
|
||||
@echo "All-in-one:"
|
||||
@echo " make validate - Run all checks (lint + format + types)"
|
||||
@echo ""
|
||||
@echo "Running:"
|
||||
@echo " make run - Run the server locally"
|
||||
@echo ""
|
||||
@echo "Cleanup:"
|
||||
@echo " make clean - Remove cache and build artifacts"
|
||||
|
||||
# Setup
|
||||
install:
|
||||
@echo "Installing production dependencies..."
|
||||
@uv pip install -e .
|
||||
|
||||
install-dev:
|
||||
@echo "Installing development dependencies..."
|
||||
@uv pip install -e ".[dev]"
|
||||
|
||||
# Quality checks
|
||||
lint:
|
||||
@echo "Running Ruff linter..."
|
||||
@uv run ruff check .
|
||||
|
||||
lint-fix:
|
||||
@echo "Running Ruff linter with auto-fix..."
|
||||
@uv run ruff check --fix .
|
||||
|
||||
format:
|
||||
@echo "Formatting code..."
|
||||
@uv run ruff format .
|
||||
|
||||
format-check:
|
||||
@echo "Checking code formatting..."
|
||||
@uv run ruff format --check .
|
||||
|
||||
type-check:
|
||||
@echo "Running mypy..."
|
||||
@uv run python -m mypy server.py config.py models.py exceptions.py git_wrapper.py workspace.py providers/ --explicit-package-bases
|
||||
|
||||
# Testing
|
||||
test:
|
||||
@echo "Running tests..."
|
||||
@IS_TEST=True uv run pytest tests/ -v
|
||||
|
||||
test-cov:
|
||||
@echo "Running tests with coverage..."
|
||||
@IS_TEST=True uv run pytest tests/ -v --cov=. --cov-report=term-missing --cov-report=html
|
||||
|
||||
# All-in-one validation
|
||||
validate: lint format-check type-check
|
||||
@echo "All validations passed!"
|
||||
|
||||
# Running
|
||||
run:
|
||||
@echo "Starting Git Operations server..."
|
||||
@uv run python server.py
|
||||
|
||||
# Cleanup
|
||||
clean:
|
||||
@echo "Cleaning up..."
|
||||
@rm -rf __pycache__ .pytest_cache .mypy_cache .ruff_cache .coverage htmlcov
|
||||
@find . -type d -name "__pycache__" -exec rm -rf {} + 2>/dev/null || true
|
||||
@find . -type f -name "*.pyc" -delete 2>/dev/null || true
|
||||
179
mcp-servers/git-ops/__init__.py
Normal file
179
mcp-servers/git-ops/__init__.py
Normal file
@@ -0,0 +1,179 @@
|
||||
"""
|
||||
Git Operations MCP Server.
|
||||
|
||||
Provides git repository management, branching, commits, and PR workflows
|
||||
for Syndarix AI agents.
|
||||
"""
|
||||
|
||||
__version__ = "0.1.0"
|
||||
|
||||
from config import Settings, get_settings, is_test_mode, reset_settings
|
||||
from exceptions import (
|
||||
APIError,
|
||||
AuthenticationError,
|
||||
BranchExistsError,
|
||||
BranchNotFoundError,
|
||||
CheckoutError,
|
||||
CloneError,
|
||||
CommitError,
|
||||
CredentialError,
|
||||
CredentialNotFoundError,
|
||||
DirtyWorkspaceError,
|
||||
ErrorCode,
|
||||
GitError,
|
||||
GitOpsError,
|
||||
InvalidRefError,
|
||||
MergeConflictError,
|
||||
PRError,
|
||||
PRNotFoundError,
|
||||
ProviderError,
|
||||
ProviderNotFoundError,
|
||||
PullError,
|
||||
PushError,
|
||||
WorkspaceError,
|
||||
WorkspaceLockedError,
|
||||
WorkspaceNotFoundError,
|
||||
WorkspaceSizeExceededError,
|
||||
)
|
||||
from models import (
|
||||
BranchInfo,
|
||||
BranchRequest,
|
||||
BranchResult,
|
||||
CheckoutRequest,
|
||||
CheckoutResult,
|
||||
CloneRequest,
|
||||
CloneResult,
|
||||
CommitInfo,
|
||||
CommitRequest,
|
||||
CommitResult,
|
||||
CreatePRRequest,
|
||||
CreatePRResult,
|
||||
DiffHunk,
|
||||
DiffRequest,
|
||||
DiffResult,
|
||||
FileChange,
|
||||
FileChangeType,
|
||||
FileDiff,
|
||||
GetPRRequest,
|
||||
GetPRResult,
|
||||
GetWorkspaceRequest,
|
||||
GetWorkspaceResult,
|
||||
HealthStatus,
|
||||
ListBranchesRequest,
|
||||
ListBranchesResult,
|
||||
ListPRsRequest,
|
||||
ListPRsResult,
|
||||
LockWorkspaceRequest,
|
||||
LockWorkspaceResult,
|
||||
LogRequest,
|
||||
LogResult,
|
||||
MergePRRequest,
|
||||
MergePRResult,
|
||||
MergeStrategy,
|
||||
PRInfo,
|
||||
ProviderStatus,
|
||||
ProviderType,
|
||||
PRState,
|
||||
PullRequest,
|
||||
PullResult,
|
||||
PushRequest,
|
||||
PushResult,
|
||||
StatusRequest,
|
||||
StatusResult,
|
||||
UnlockWorkspaceRequest,
|
||||
UnlockWorkspaceResult,
|
||||
UpdatePRRequest,
|
||||
UpdatePRResult,
|
||||
WorkspaceInfo,
|
||||
WorkspaceState,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
# Version
|
||||
"__version__",
|
||||
# Config
|
||||
"Settings",
|
||||
"get_settings",
|
||||
"reset_settings",
|
||||
"is_test_mode",
|
||||
# Error codes
|
||||
"ErrorCode",
|
||||
# Exceptions
|
||||
"GitOpsError",
|
||||
"WorkspaceError",
|
||||
"WorkspaceNotFoundError",
|
||||
"WorkspaceLockedError",
|
||||
"WorkspaceSizeExceededError",
|
||||
"GitError",
|
||||
"CloneError",
|
||||
"CheckoutError",
|
||||
"CommitError",
|
||||
"PushError",
|
||||
"PullError",
|
||||
"MergeConflictError",
|
||||
"BranchExistsError",
|
||||
"BranchNotFoundError",
|
||||
"InvalidRefError",
|
||||
"DirtyWorkspaceError",
|
||||
"ProviderError",
|
||||
"AuthenticationError",
|
||||
"ProviderNotFoundError",
|
||||
"PRError",
|
||||
"PRNotFoundError",
|
||||
"APIError",
|
||||
"CredentialError",
|
||||
"CredentialNotFoundError",
|
||||
# Enums
|
||||
"FileChangeType",
|
||||
"MergeStrategy",
|
||||
"PRState",
|
||||
"ProviderType",
|
||||
"WorkspaceState",
|
||||
# Dataclasses
|
||||
"FileChange",
|
||||
"BranchInfo",
|
||||
"CommitInfo",
|
||||
"DiffHunk",
|
||||
"FileDiff",
|
||||
"PRInfo",
|
||||
"WorkspaceInfo",
|
||||
# Request/Response models
|
||||
"CloneRequest",
|
||||
"CloneResult",
|
||||
"StatusRequest",
|
||||
"StatusResult",
|
||||
"BranchRequest",
|
||||
"BranchResult",
|
||||
"ListBranchesRequest",
|
||||
"ListBranchesResult",
|
||||
"CheckoutRequest",
|
||||
"CheckoutResult",
|
||||
"CommitRequest",
|
||||
"CommitResult",
|
||||
"PushRequest",
|
||||
"PushResult",
|
||||
"PullRequest",
|
||||
"PullResult",
|
||||
"DiffRequest",
|
||||
"DiffResult",
|
||||
"LogRequest",
|
||||
"LogResult",
|
||||
"CreatePRRequest",
|
||||
"CreatePRResult",
|
||||
"GetPRRequest",
|
||||
"GetPRResult",
|
||||
"ListPRsRequest",
|
||||
"ListPRsResult",
|
||||
"MergePRRequest",
|
||||
"MergePRResult",
|
||||
"UpdatePRRequest",
|
||||
"UpdatePRResult",
|
||||
"GetWorkspaceRequest",
|
||||
"GetWorkspaceResult",
|
||||
"LockWorkspaceRequest",
|
||||
"LockWorkspaceResult",
|
||||
"UnlockWorkspaceRequest",
|
||||
"UnlockWorkspaceResult",
|
||||
"HealthStatus",
|
||||
"ProviderStatus",
|
||||
]
|
||||
155
mcp-servers/git-ops/config.py
Normal file
155
mcp-servers/git-ops/config.py
Normal file
@@ -0,0 +1,155 @@
|
||||
"""
|
||||
Configuration for Git Operations MCP Server.
|
||||
|
||||
Uses pydantic-settings for environment variable loading.
|
||||
"""
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
from pydantic import Field
|
||||
from pydantic_settings import BaseSettings
|
||||
|
||||
|
||||
class Settings(BaseSettings):
    """Application settings loaded from environment.

    Values are read from environment variables with the ``GIT_OPS_`` prefix
    (and optionally from a ``.env`` file) — see ``model_config`` below.
    Obtain the shared instance via ``get_settings()`` rather than
    constructing this class directly.
    """

    # Server settings
    host: str = Field(default="0.0.0.0", description="Server host")
    port: int = Field(default=8003, description="Server port")
    debug: bool = Field(default=False, description="Debug mode")

    # Workspace settings
    workspace_base_path: Path = Field(
        default=Path("/var/syndarix/workspaces"),
        description="Base path for git workspaces",
    )
    workspace_max_size_gb: float = Field(
        default=10.0,
        description="Maximum size per workspace in GB",
    )
    workspace_stale_days: int = Field(
        default=7,
        description="Days after which unused workspace is considered stale",
    )
    workspace_lock_timeout: int = Field(
        default=300,
        description="Workspace lock timeout in seconds",
    )

    # Git settings
    git_timeout: int = Field(
        default=120,
        description="Default timeout for git operations in seconds",
    )
    git_clone_timeout: int = Field(
        default=600,
        description="Timeout for clone operations in seconds",
    )
    git_author_name: str = Field(
        default="Syndarix Agent",
        description="Default author name for commits",
    )
    git_author_email: str = Field(
        default="agent@syndarix.ai",
        description="Default author email for commits",
    )
    git_max_diff_lines: int = Field(
        default=10000,
        description="Maximum lines in diff output",
    )

    # Redis settings (for distributed locking)
    redis_url: str = Field(
        default="redis://localhost:6379/0",
        description="Redis connection URL",
    )

    # Provider settings
    gitea_base_url: str = Field(
        default="",
        description="Gitea API base URL (e.g., https://gitea.example.com)",
    )
    gitea_token: str = Field(
        default="",
        description="Gitea API token",
    )
    github_token: str = Field(
        default="",
        description="GitHub API token",
    )
    github_api_url: str = Field(
        default="https://api.github.com",
        description="GitHub API URL (for Enterprise)",
    )
    gitlab_token: str = Field(
        default="",
        description="GitLab API token",
    )
    gitlab_url: str = Field(
        default="https://gitlab.com",
        description="GitLab URL (for self-hosted)",
    )

    # Rate limiting
    rate_limit_requests: int = Field(
        default=100,
        description="Max API requests per minute per provider",
    )
    rate_limit_window: int = Field(
        default=60,
        description="Rate limit window in seconds",
    )

    # Retry settings
    retry_attempts: int = Field(
        default=3,
        description="Number of retry attempts for failed operations",
    )
    retry_delay: float = Field(
        default=1.0,
        description="Initial retry delay in seconds",
    )
    retry_max_delay: float = Field(
        default=30.0,
        description="Maximum retry delay in seconds",
    )

    # Security settings
    allowed_hosts: list[str] = Field(
        default_factory=list,
        description="Allowed git host domains (empty = all)",
    )
    max_clone_size_mb: int = Field(
        default=500,
        description="Maximum repository size for clone in MB",
    )
    enable_force_push: bool = Field(
        default=False,
        description="Allow force push operations",
    )

    # Env vars are matched with the GIT_OPS_ prefix (e.g. GIT_OPS_PORT);
    # unknown keys in the environment/.env are ignored rather than rejected.
    model_config = {"env_prefix": "GIT_OPS_", "env_file": ".env", "extra": "ignore"}
|
||||
|
||||
|
||||
# Global settings instance (lazy initialization)
# Module-level cache read/written by get_settings() and reset_settings().
_settings: Settings | None = None
|
||||
|
||||
|
||||
def get_settings() -> Settings:
    """Return the process-wide ``Settings``, constructing it on first use.

    Subsequent calls return the cached instance until ``reset_settings()``
    clears it.
    """
    global _settings
    cached = _settings
    if cached is None:
        cached = Settings()
        _settings = cached
    return cached
|
||||
|
||||
|
||||
def reset_settings() -> None:
    """Drop the cached ``Settings`` so the next ``get_settings()`` reloads.

    Intended for tests that mutate the environment between cases.
    """
    global _settings
    _settings = None
|
||||
|
||||
|
||||
def is_test_mode() -> bool:
    """Report whether the ``IS_TEST`` environment variable is truthy.

    Accepted (case-insensitive) truthy markers: ``true``, ``1``, ``yes``.
    Any other value, or an unset variable, counts as non-test mode.
    """
    flag = os.getenv("IS_TEST", "")
    return flag.lower() in {"true", "1", "yes"}
|
||||
359
mcp-servers/git-ops/exceptions.py
Normal file
359
mcp-servers/git-ops/exceptions.py
Normal file
@@ -0,0 +1,359 @@
|
||||
"""
|
||||
Exception hierarchy for Git Operations MCP Server.
|
||||
|
||||
Provides structured error handling with error codes for MCP responses.
|
||||
"""
|
||||
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
|
||||
class ErrorCode(str, Enum):
    """Error codes for Git Operations errors.

    String values follow the ``GIT_NNNN`` scheme and are grouped by the
    thousands digit (see the section comments below). Because of the
    ``str`` mixin, members serialize directly as their string value.
    """

    # General errors (1xxx)
    INTERNAL_ERROR = "GIT_1000"
    INVALID_REQUEST = "GIT_1001"
    NOT_FOUND = "GIT_1002"
    PERMISSION_DENIED = "GIT_1003"
    TIMEOUT = "GIT_1004"
    RATE_LIMITED = "GIT_1005"

    # Workspace errors (2xxx)
    WORKSPACE_NOT_FOUND = "GIT_2000"
    WORKSPACE_LOCKED = "GIT_2001"
    WORKSPACE_SIZE_EXCEEDED = "GIT_2002"
    WORKSPACE_CREATE_FAILED = "GIT_2003"
    WORKSPACE_DELETE_FAILED = "GIT_2004"

    # Git operation errors (3xxx)
    CLONE_FAILED = "GIT_3000"
    CHECKOUT_FAILED = "GIT_3001"
    COMMIT_FAILED = "GIT_3002"
    PUSH_FAILED = "GIT_3003"
    PULL_FAILED = "GIT_3004"
    MERGE_CONFLICT = "GIT_3005"
    BRANCH_EXISTS = "GIT_3006"
    BRANCH_NOT_FOUND = "GIT_3007"
    INVALID_REF = "GIT_3008"
    DIRTY_WORKSPACE = "GIT_3009"
    UNCOMMITTED_CHANGES = "GIT_3010"
    FETCH_FAILED = "GIT_3011"
    RESET_FAILED = "GIT_3012"

    # Provider errors (4xxx)
    PROVIDER_ERROR = "GIT_4000"
    PROVIDER_AUTH_FAILED = "GIT_4001"
    PROVIDER_NOT_FOUND = "GIT_4002"
    PR_CREATE_FAILED = "GIT_4003"
    PR_MERGE_FAILED = "GIT_4004"
    PR_NOT_FOUND = "GIT_4005"
    API_ERROR = "GIT_4006"

    # Credential errors (5xxx)
    CREDENTIAL_ERROR = "GIT_5000"
    CREDENTIAL_NOT_FOUND = "GIT_5001"
    CREDENTIAL_INVALID = "GIT_5002"
    SSH_KEY_ERROR = "GIT_5003"
|
||||
|
||||
|
||||
class GitOpsError(Exception):
    """Root of the git-ops exception hierarchy.

    Carries a human-readable message, a machine-readable :class:`ErrorCode`,
    and an optional ``details`` mapping for structured MCP error responses.
    """

    def __init__(
        self,
        message: str,
        code: ErrorCode = ErrorCode.INTERNAL_ERROR,
        details: dict[str, Any] | None = None,
    ) -> None:
        super().__init__(message)
        self.message = message
        self.code = code
        # Normalize a missing/empty mapping to a fresh dict.
        self.details: dict[str, Any] = details or {}

    def to_dict(self) -> dict[str, Any]:
        """Render the error as an MCP response payload.

        ``details`` is included only when non-empty.
        """
        payload: dict[str, Any] = {
            "error": self.message,
            "code": self.code.value,
        }
        if self.details:
            payload["details"] = self.details
        return payload
|
||||
|
||||
|
||||
# Workspace Errors
|
||||
|
||||
|
||||
class WorkspaceError(GitOpsError):
    """Base class for workspace lifecycle failures."""

    def __init__(
        self,
        message: str,
        code: ErrorCode = ErrorCode.WORKSPACE_NOT_FOUND,
        details: dict[str, Any] | None = None,
    ) -> None:
        # Same contract as GitOpsError, with a workspace-flavored default code.
        super().__init__(message=message, code=code, details=details)
|
||||
|
||||
|
||||
class WorkspaceNotFoundError(WorkspaceError):
    """Raised when no workspace exists for the given project."""

    def __init__(self, project_id: str) -> None:
        info = {"project_id": project_id}
        super().__init__(
            f"Workspace not found for project: {project_id}",
            ErrorCode.WORKSPACE_NOT_FOUND,
            info,
        )
|
||||
|
||||
|
||||
class WorkspaceLockedError(WorkspaceError):
    """Raised when a workspace is held by another operation's lock."""

    def __init__(self, project_id: str, holder: str | None = None) -> None:
        info: dict[str, Any] = {"project_id": project_id}
        # Record the lock holder only when the caller knows it.
        if holder:
            info["locked_by"] = holder
        super().__init__(
            f"Workspace is locked for project: {project_id}",
            ErrorCode.WORKSPACE_LOCKED,
            info,
        )
|
||||
|
||||
|
||||
class WorkspaceSizeExceededError(WorkspaceError):
    """Raised when a workspace grows past its configured size limit."""

    def __init__(self, project_id: str, current_size: float, max_size: float) -> None:
        info = {
            "project_id": project_id,
            "current_size_gb": current_size,
            "max_size_gb": max_size,
        }
        super().__init__(
            f"Workspace size limit exceeded for project: {project_id}",
            ErrorCode.WORKSPACE_SIZE_EXCEEDED,
            info,
        )
|
||||
|
||||
|
||||
# Git Operation Errors
|
||||
|
||||
|
||||
class GitError(GitOpsError):
    """Base class for failures of local git operations."""

    def __init__(
        self,
        message: str,
        code: ErrorCode = ErrorCode.INTERNAL_ERROR,
        details: dict[str, Any] | None = None,
    ) -> None:
        # Pure pass-through; exists to let callers catch git failures as a group.
        super().__init__(message=message, code=code, details=details)
|
||||
|
||||
|
||||
class CloneError(GitError):
    """Raised when cloning a repository fails."""

    def __init__(self, repo_url: str, reason: str) -> None:
        info = {"repo_url": repo_url, "reason": reason}
        super().__init__(
            f"Failed to clone repository: {reason}",
            ErrorCode.CLONE_FAILED,
            info,
        )
|
||||
|
||||
|
||||
class CheckoutError(GitError):
    """Raised when checking out a branch or ref fails."""

    def __init__(self, ref: str, reason: str) -> None:
        info = {"ref": ref, "reason": reason}
        super().__init__(
            f"Failed to checkout '{ref}': {reason}",
            ErrorCode.CHECKOUT_FAILED,
            info,
        )
|
||||
|
||||
|
||||
class CommitError(GitError):
    """Raised when creating a commit fails."""

    def __init__(self, reason: str) -> None:
        super().__init__(
            f"Failed to commit: {reason}",
            ErrorCode.COMMIT_FAILED,
            {"reason": reason},
        )
|
||||
|
||||
|
||||
class PushError(GitError):
    """Raised when pushing a branch to a remote fails."""

    def __init__(self, branch: str, reason: str) -> None:
        info = {"branch": branch, "reason": reason}
        super().__init__(
            f"Failed to push branch '{branch}': {reason}",
            ErrorCode.PUSH_FAILED,
            info,
        )
|
||||
|
||||
|
||||
class PullError(GitError):
    """Raised when pulling a branch from a remote fails."""

    def __init__(self, branch: str, reason: str) -> None:
        info = {"branch": branch, "reason": reason}
        super().__init__(
            f"Failed to pull branch '{branch}': {reason}",
            ErrorCode.PULL_FAILED,
            info,
        )
|
||||
|
||||
|
||||
class MergeConflictError(GitError):
    """Raised when a merge produces conflicts; carries the conflicted paths."""

    def __init__(self, conflicting_files: list[str]) -> None:
        count = len(conflicting_files)
        super().__init__(
            f"Merge conflict detected in {count} files",
            ErrorCode.MERGE_CONFLICT,
            {"conflicting_files": conflicting_files},
        )
|
||||
|
||||
|
||||
class BranchExistsError(GitError):
    """Raised when creating a branch whose name is already taken."""

    def __init__(self, branch_name: str) -> None:
        super().__init__(
            f"Branch already exists: {branch_name}",
            ErrorCode.BRANCH_EXISTS,
            {"branch": branch_name},
        )
|
||||
|
||||
|
||||
class BranchNotFoundError(GitError):
    """Raised when the requested branch does not exist."""

    def __init__(self, branch_name: str) -> None:
        super().__init__(
            f"Branch not found: {branch_name}",
            ErrorCode.BRANCH_NOT_FOUND,
            {"branch": branch_name},
        )
|
||||
|
||||
|
||||
class InvalidRefError(GitError):
    """Raised for a git reference that cannot be resolved."""

    def __init__(self, ref: str) -> None:
        super().__init__(
            f"Invalid git reference: {ref}",
            ErrorCode.INVALID_REF,
            {"ref": ref},
        )
|
||||
|
||||
|
||||
class DirtyWorkspaceError(GitError):
    """Raised when an operation requires a clean tree but changes exist."""

    def __init__(self, modified_files: list[str]) -> None:
        count = len(modified_files)
        # Cap the reported paths at 10 to keep the payload small.
        sample = modified_files[:10]
        super().__init__(
            f"Workspace has {count} uncommitted changes",
            ErrorCode.DIRTY_WORKSPACE,
            {"modified_files": sample},
        )
|
||||
|
||||
|
||||
# Provider Errors
|
||||
|
||||
|
||||
class ProviderError(GitOpsError):
    """Base class for failures talking to a hosting provider's API."""

    def __init__(
        self,
        message: str,
        code: ErrorCode = ErrorCode.PROVIDER_ERROR,
        details: dict[str, Any] | None = None,
    ) -> None:
        super().__init__(message=message, code=code, details=details)
|
||||
|
||||
|
||||
class AuthenticationError(ProviderError):
    """Raised when the provider rejects our credentials."""

    def __init__(self, provider: str, reason: str) -> None:
        info = {"provider": provider, "reason": reason}
        super().__init__(
            f"Authentication failed with {provider}: {reason}",
            ErrorCode.PROVIDER_AUTH_FAILED,
            info,
        )
|
||||
|
||||
|
||||
class ProviderNotFoundError(ProviderError):
    """Raised when the named provider is unknown or not configured."""

    def __init__(self, provider: str) -> None:
        super().__init__(
            f"Provider not found or not configured: {provider}",
            ErrorCode.PROVIDER_NOT_FOUND,
            {"provider": provider},
        )
|
||||
|
||||
|
||||
class PRError(ProviderError):
    """Base class for pull-request operation failures."""

    def __init__(
        self,
        message: str,
        code: ErrorCode = ErrorCode.PR_CREATE_FAILED,
        details: dict[str, Any] | None = None,
    ) -> None:
        super().__init__(message=message, code=code, details=details)
|
||||
|
||||
|
||||
class PRNotFoundError(PRError):
    """Raised when the numbered pull request is absent from the repo."""

    def __init__(self, pr_number: int, repo: str) -> None:
        info = {"pr_number": pr_number, "repo": repo}
        super().__init__(
            f"Pull request #{pr_number} not found in {repo}",
            ErrorCode.PR_NOT_FOUND,
            info,
        )
|
||||
|
||||
|
||||
class APIError(ProviderError):
    """Raised for a non-success HTTP response from a provider API."""

    def __init__(self, provider: str, status_code: int, message: str) -> None:
        info = {"provider": provider, "status_code": status_code, "message": message}
        super().__init__(
            f"{provider} API error ({status_code}): {message}",
            ErrorCode.API_ERROR,
            info,
        )
|
||||
|
||||
|
||||
# Credential Errors
|
||||
|
||||
|
||||
class CredentialError(GitOpsError):
    """Base class for credential lookup/validation failures."""

    def __init__(
        self,
        message: str,
        code: ErrorCode = ErrorCode.CREDENTIAL_ERROR,
        details: dict[str, Any] | None = None,
    ) -> None:
        super().__init__(message=message, code=code, details=details)
|
||||
|
||||
|
||||
class CredentialNotFoundError(CredentialError):
    """Raised when no credential of the given type matches the identifier."""

    def __init__(self, credential_type: str, identifier: str) -> None:
        info = {"type": credential_type, "identifier": identifier}
        super().__init__(
            f"{credential_type} credential not found: {identifier}",
            ErrorCode.CREDENTIAL_NOT_FOUND,
            info,
        )
|
||||
1170
mcp-servers/git-ops/git_wrapper.py
Normal file
1170
mcp-servers/git-ops/git_wrapper.py
Normal file
File diff suppressed because it is too large
Load Diff
690
mcp-servers/git-ops/models.py
Normal file
690
mcp-servers/git-ops/models.py
Normal file
@@ -0,0 +1,690 @@
|
||||
"""
|
||||
Data models for Git Operations MCP Server.
|
||||
|
||||
Defines data structures for git operations, workspace management,
|
||||
and provider interactions.
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import UTC, datetime
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class FileChangeType(str, Enum):
    """Types of file changes in git.

    With the ``str`` mixin, member values (lowercase strings) serialize
    directly in dict/JSON payloads (see the ``to_dict`` helpers below).
    """

    ADDED = "added"
    MODIFIED = "modified"
    DELETED = "deleted"
    RENAMED = "renamed"
    COPIED = "copied"
    UNTRACKED = "untracked"
    IGNORED = "ignored"
|
||||
|
||||
|
||||
class MergeStrategy(str, Enum):
    """Merge strategies for pull requests.

    Values are the lowercase strategy names expected in request payloads.
    """

    MERGE = "merge"  # Create a merge commit
    SQUASH = "squash"  # Squash and merge
    REBASE = "rebase"  # Rebase and merge
|
||||
|
||||
|
||||
class PRState(str, Enum):
    """Pull request states.

    ``MERGED`` is distinct from ``CLOSED`` (closed without merging).
    """

    OPEN = "open"
    CLOSED = "closed"
    MERGED = "merged"
|
||||
|
||||
|
||||
class ProviderType(str, Enum):
    """Supported git providers.

    Lowercase identifiers used to select a hosting-provider backend.
    """

    GITEA = "gitea"
    GITHUB = "github"
    GITLAB = "gitlab"
|
||||
|
||||
|
||||
class WorkspaceState(str, Enum):
    """Workspace lifecycle states.

    Roughly ordered by lifecycle: INITIALIZING -> READY; LOCKED while an
    operation holds the workspace; STALE after prolonged inactivity;
    DELETED once removed.
    """

    INITIALIZING = "initializing"
    READY = "ready"
    LOCKED = "locked"
    STALE = "stale"
    DELETED = "deleted"
|
||||
|
||||
|
||||
# Dataclasses for internal data structures
|
||||
|
||||
|
||||
@dataclass
class FileChange:
    """One entry from git status: a path and how it changed."""

    path: str
    change_type: FileChangeType
    old_path: str | None = None  # original path when the change is a rename
    additions: int = 0
    deletions: int = 0

    def to_dict(self) -> dict[str, Any]:
        """Serialize to a plain dict; the enum is flattened to its string value."""
        return dict(
            path=self.path,
            change_type=self.change_type.value,
            old_path=self.old_path,
            additions=self.additions,
            deletions=self.deletions,
        )
|
||||
|
||||
|
||||
@dataclass
|
||||
class BranchInfo:
|
||||
"""Information about a git branch."""
|
||||
|
||||
name: str
|
||||
is_current: bool = False
|
||||
is_remote: bool = False
|
||||
tracking_branch: str | None = None
|
||||
commit_sha: str | None = None
|
||||
commit_message: str | None = None
|
||||
ahead: int = 0
|
||||
behind: int = 0
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
"""Convert to dictionary."""
|
||||
return {
|
||||
"name": self.name,
|
||||
"is_current": self.is_current,
|
||||
"is_remote": self.is_remote,
|
||||
"tracking_branch": self.tracking_branch,
|
||||
"commit_sha": self.commit_sha,
|
||||
"commit_message": self.commit_message,
|
||||
"ahead": self.ahead,
|
||||
"behind": self.behind,
|
||||
}
|
||||
|
||||
|
||||
@dataclass
class CommitInfo:
    """Metadata for a single git commit (author, committer, parents)."""

    sha: str
    short_sha: str
    message: str
    author_name: str
    author_email: str
    authored_date: datetime
    committer_name: str
    committer_email: str
    committed_date: datetime
    parents: list[str] = field(default_factory=list)  # parent SHAs

    def to_dict(self) -> dict[str, Any]:
        """Serialize to a plain dict; dates become ISO-8601 strings."""
        return dict(
            sha=self.sha,
            short_sha=self.short_sha,
            message=self.message,
            author_name=self.author_name,
            author_email=self.author_email,
            authored_date=self.authored_date.isoformat(),
            committer_name=self.committer_name,
            committer_email=self.committer_email,
            committed_date=self.committed_date.isoformat(),
            parents=self.parents,
        )
|
||||
|
||||
|
||||
@dataclass
class DiffHunk:
    """One hunk of a unified diff: line ranges on both sides plus the text."""

    old_start: int
    old_lines: int
    new_start: int
    new_lines: int
    content: str

    def to_dict(self) -> dict[str, Any]:
        """Serialize every field to a plain dict."""
        return dict(
            old_start=self.old_start,
            old_lines=self.old_lines,
            new_start=self.new_start,
            new_lines=self.new_lines,
            content=self.content,
        )
|
||||
|
||||
|
||||
@dataclass
class FileDiff:
    """Diff for a single file: change kind, hunks, and line counts."""

    path: str
    change_type: FileChangeType
    old_path: str | None = None  # original path when the change is a rename
    hunks: list[DiffHunk] = field(default_factory=list)
    additions: int = 0
    deletions: int = 0
    is_binary: bool = False  # binary files carry no hunks worth rendering

    def to_dict(self) -> dict[str, Any]:
        """Serialize to a plain dict; hunks are serialized recursively."""
        return dict(
            path=self.path,
            change_type=self.change_type.value,
            old_path=self.old_path,
            hunks=[hunk.to_dict() for hunk in self.hunks],
            additions=self.additions,
            deletions=self.deletions,
            is_binary=self.is_binary,
        )
|
||||
|
||||
|
||||
@dataclass
class PRInfo:
    """Provider-agnostic snapshot of a pull request."""

    number: int
    title: str
    body: str
    state: PRState
    source_branch: str
    target_branch: str
    author: str
    created_at: datetime
    updated_at: datetime
    merged_at: datetime | None = None
    closed_at: datetime | None = None
    url: str | None = None
    labels: list[str] = field(default_factory=list)
    assignees: list[str] = field(default_factory=list)
    reviewers: list[str] = field(default_factory=list)
    mergeable: bool | None = None  # None when the provider hasn't computed it
    draft: bool = False

    def to_dict(self) -> dict[str, Any]:
        """Serialize to a plain dict; datetimes become ISO-8601 strings or None."""

        def iso(ts: datetime | None) -> str | None:
            return ts.isoformat() if ts else None

        return dict(
            number=self.number,
            title=self.title,
            body=self.body,
            state=self.state.value,
            source_branch=self.source_branch,
            target_branch=self.target_branch,
            author=self.author,
            created_at=self.created_at.isoformat(),
            updated_at=self.updated_at.isoformat(),
            merged_at=iso(self.merged_at),
            closed_at=iso(self.closed_at),
            url=self.url,
            labels=self.labels,
            assignees=self.assignees,
            reviewers=self.reviewers,
            mergeable=self.mergeable,
            draft=self.draft,
        )
|
||||
|
||||
|
||||
@dataclass
class WorkspaceInfo:
    """State of a project's on-disk git workspace, including lock info."""

    project_id: str
    path: str
    state: WorkspaceState
    repo_url: str | None = None
    current_branch: str | None = None
    last_accessed: datetime = field(default_factory=lambda: datetime.now(UTC))
    size_bytes: int = 0
    lock_holder: str | None = None  # who holds the lock, if anyone
    lock_expires: datetime | None = None

    def to_dict(self) -> dict[str, Any]:
        """Serialize to a plain dict; datetimes become ISO-8601 strings."""
        expires = self.lock_expires
        return dict(
            project_id=self.project_id,
            path=self.path,
            state=self.state.value,
            repo_url=self.repo_url,
            current_branch=self.current_branch,
            last_accessed=self.last_accessed.isoformat(),
            size_bytes=self.size_bytes,
            lock_holder=self.lock_holder,
            lock_expires=expires.isoformat() if expires else None,
        )
|
||||
|
||||
|
||||
# Pydantic Request/Response Models
|
||||
|
||||
|
||||
class CloneRequest(BaseModel):
    """Request to clone a repository.

    Like all requests here, it is scoped by ``project_id`` and attributed
    to ``agent_id``.
    """

    project_id: str = Field(..., description="Project ID for scoping")
    agent_id: str = Field(..., description="Agent ID making the request")
    repo_url: str = Field(..., description="Repository URL to clone")
    branch: str | None = Field(
        default=None, description="Branch to checkout after clone"
    )
    depth: int | None = Field(
        default=None, ge=1, description="Shallow clone depth (None = full clone)"
    )
|
||||
|
||||
|
||||
class CloneResult(BaseModel):
    """Result of a clone operation.

    ``error`` is populated when ``success`` is False.
    """

    success: bool = Field(..., description="Whether clone succeeded")
    project_id: str = Field(..., description="Project ID")
    workspace_path: str = Field(..., description="Path to cloned workspace")
    branch: str = Field(..., description="Current branch after clone")
    commit_sha: str = Field(..., description="HEAD commit SHA")
    error: str | None = Field(default=None, description="Error message if failed")
|
||||
|
||||
|
||||
class StatusRequest(BaseModel):
    """Request for git status.

    Untracked files are included by default; disable via ``include_untracked``.
    """

    project_id: str = Field(..., description="Project ID for scoping")
    agent_id: str = Field(..., description="Agent ID making the request")
    include_untracked: bool = Field(default=True, description="Include untracked files")
|
||||
|
||||
|
||||
class StatusResult(BaseModel):
    """Result of a status operation.

    ``staged``/``unstaged`` carry serialized file-change dicts;
    ``untracked`` is a list of bare paths.
    """

    project_id: str = Field(..., description="Project ID")
    branch: str = Field(..., description="Current branch")
    commit_sha: str = Field(..., description="HEAD commit SHA")
    is_clean: bool = Field(..., description="Whether working tree is clean")
    staged: list[dict[str, Any]] = Field(
        default_factory=list, description="Staged changes"
    )
    unstaged: list[dict[str, Any]] = Field(
        default_factory=list, description="Unstaged changes"
    )
    untracked: list[str] = Field(default_factory=list, description="Untracked files")
    ahead: int = Field(default=0, description="Commits ahead of upstream")
    behind: int = Field(default=0, description="Commits behind upstream")
|
||||
|
||||
|
||||
class BranchRequest(BaseModel):
    """Request for branch operations.

    By default the new branch is checked out immediately (``checkout=True``).
    """

    project_id: str = Field(..., description="Project ID for scoping")
    agent_id: str = Field(..., description="Agent ID making the request")
    branch_name: str = Field(..., description="Branch name")
    from_ref: str | None = Field(
        default=None, description="Reference to create branch from"
    )
    checkout: bool = Field(default=True, description="Checkout after creation")
|
||||
|
||||
|
||||
class BranchResult(BaseModel):
    """Result of a branch operation.

    ``error`` is populated when ``success`` is False.
    """

    success: bool = Field(..., description="Whether operation succeeded")
    branch: str = Field(..., description="Branch name")
    commit_sha: str | None = Field(default=None, description="HEAD commit SHA")
    is_current: bool = Field(default=False, description="Whether branch is checked out")
    error: str | None = Field(default=None, description="Error message if failed")
|
||||
|
||||
|
||||
class ListBranchesRequest(BaseModel):
    """Request to list branches.

    Remote branches are excluded unless ``include_remote`` is set.
    """

    project_id: str = Field(..., description="Project ID for scoping")
    agent_id: str = Field(..., description="Agent ID making the request")
    include_remote: bool = Field(default=False, description="Include remote branches")
|
||||
|
||||
|
||||
class ListBranchesResult(BaseModel):
    """Result of listing branches.

    Branch entries are serialized branch-info dicts.
    """

    project_id: str = Field(..., description="Project ID")
    current_branch: str = Field(..., description="Currently checked out branch")
    local_branches: list[dict[str, Any]] = Field(
        default_factory=list, description="Local branches"
    )
    remote_branches: list[dict[str, Any]] = Field(
        default_factory=list, description="Remote branches"
    )
|
||||
|
||||
|
||||
class CheckoutRequest(BaseModel):
    """Request to checkout a branch or ref.

    ``force`` discards local changes; ``create_branch`` creates ``ref`` first.
    """

    project_id: str = Field(..., description="Project ID for scoping")
    agent_id: str = Field(..., description="Agent ID making the request")
    ref: str = Field(..., description="Branch, tag, or commit to checkout")
    create_branch: bool = Field(default=False, description="Create new branch")
    force: bool = Field(default=False, description="Force checkout (discard changes)")
|
||||
|
||||
|
||||
class CheckoutResult(BaseModel):
    """Result of a checkout operation.

    ``error`` is populated when ``success`` is False.
    """

    success: bool = Field(..., description="Whether checkout succeeded")
    ref: str = Field(..., description="Checked out reference")
    commit_sha: str | None = Field(default=None, description="HEAD commit SHA")
    error: str | None = Field(default=None, description="Error message if failed")
|
||||
|
||||
|
||||
class CommitRequest(BaseModel):
    """Request to create a commit.

    When ``files`` is None, whatever is already staged is committed;
    author fields override the configured defaults.
    """

    project_id: str = Field(..., description="Project ID for scoping")
    agent_id: str = Field(..., description="Agent ID making the request")
    message: str = Field(..., description="Commit message")
    files: list[str] | None = Field(
        default=None, description="Files to commit (None = all staged)"
    )
    author_name: str | None = Field(default=None, description="Author name override")
    author_email: str | None = Field(default=None, description="Author email override")
    allow_empty: bool = Field(default=False, description="Allow empty commit")
|
||||
|
||||
|
||||
class CommitResult(BaseModel):
    """Result of a commit operation.

    Commit identifiers are None and ``error`` is populated on failure.
    """

    success: bool = Field(..., description="Whether commit succeeded")
    commit_sha: str | None = Field(default=None, description="New commit SHA")
    short_sha: str | None = Field(default=None, description="Short commit SHA")
    message: str | None = Field(default=None, description="Commit message")
    files_changed: int = Field(default=0, description="Number of files changed")
    insertions: int = Field(default=0, description="Lines added")
    deletions: int = Field(default=0, description="Lines removed")
    error: str | None = Field(default=None, description="Error message if failed")
|
||||
|
||||
|
||||
class PushRequest(BaseModel):
    """Request to push to remote.

    Defaults: current branch, ``origin`` remote, upstream tracking enabled.
    """

    project_id: str = Field(..., description="Project ID for scoping")
    agent_id: str = Field(..., description="Agent ID making the request")
    branch: str | None = Field(
        default=None, description="Branch to push (None = current)"
    )
    remote: str = Field(default="origin", description="Remote name")
    force: bool = Field(default=False, description="Force push")
    set_upstream: bool = Field(default=True, description="Set upstream tracking")
|
||||
|
||||
|
||||
class PushResult(BaseModel):
    """Result of a push operation.

    ``error`` is populated when ``success`` is False.
    """

    success: bool = Field(..., description="Whether push succeeded")
    branch: str = Field(..., description="Pushed branch")
    remote: str = Field(..., description="Remote name")
    commits_pushed: int = Field(default=0, description="Number of commits pushed")
    error: str | None = Field(default=None, description="Error message if failed")
|
||||
|
||||
|
||||
class PullRequest(BaseModel):
    """Request to pull from remote.

    Note: this is the `git pull` request model, not a hosting-provider
    "pull request" (see the PR models for that).
    """

    project_id: str = Field(..., description="Project ID for scoping")
    agent_id: str = Field(..., description="Agent ID making the request")
    branch: str | None = Field(
        default=None, description="Branch to pull (None = current)"
    )
    remote: str = Field(default="origin", description="Remote name")
    rebase: bool = Field(default=False, description="Rebase instead of merge")
|
||||
|
||||
|
||||
class PullResult(BaseModel):
    """Result of a pull operation.

    ``conflicts`` lists conflicting paths when the pull cannot complete cleanly.
    """

    success: bool = Field(..., description="Whether pull succeeded")
    branch: str = Field(..., description="Pulled branch")
    commits_received: int = Field(default=0, description="New commits received")
    fast_forward: bool = Field(default=False, description="Was fast-forward")
    conflicts: list[str] = Field(
        default_factory=list, description="Conflicting files if any"
    )
    error: str | None = Field(default=None, description="Error message if failed")
|
||||
|
||||
|
||||
class DiffRequest(BaseModel):
    """Request a diff between two references (or against the working tree)."""

    project_id: str = Field(..., description="Project ID for scoping")
    agent_id: str = Field(..., description="Agent ID making the request")
    base: str | None = Field(description="Base reference (None = working tree)", default=None)
    head: str | None = Field(description="Head reference (None = HEAD)", default=None)
    files: list[str] | None = Field(description="Specific files to diff", default=None)
    context_lines: int = Field(description="Context lines", default=3, ge=0)
|
||||
|
||||
|
||||
class DiffResult(BaseModel):
    """Per-file diffs plus aggregate change counts."""

    project_id: str = Field(..., description="Project ID")
    base: str | None = Field(description="Base reference", default=None)
    head: str | None = Field(description="Head reference", default=None)
    files: list[dict[str, Any]] = Field(description="File diffs", default_factory=list)
    total_additions: int = Field(description="Total lines added", default=0)
    total_deletions: int = Field(description="Total lines removed", default=0)
    files_changed: int = Field(description="Number of files changed", default=0)
|
||||
|
||||
|
||||
class LogRequest(BaseModel):
    """Request a page of commit history."""

    project_id: str = Field(..., description="Project ID for scoping")
    agent_id: str = Field(..., description="Agent ID making the request")
    ref: str | None = Field(description="Reference to start from", default=None)
    # limit/skip implement simple offset pagination over the history.
    limit: int = Field(description="Max commits to return", default=20, ge=1, le=100)
    skip: int = Field(description="Commits to skip", default=0, ge=0)
    path: str | None = Field(description="Filter by path", default=None)
|
||||
|
||||
|
||||
class LogResult(BaseModel):
    """A page of commit history."""

    project_id: str = Field(..., description="Project ID")
    commits: list[dict[str, Any]] = Field(description="Commit history", default_factory=list)
    total_commits: int = Field(description="Total commits in range", default=0)
|
||||
|
||||
|
||||
# PR Operations
|
||||
|
||||
|
||||
class CreatePRRequest(BaseModel):
    """Request to open a pull request on the hosting platform."""

    project_id: str = Field(..., description="Project ID for scoping")
    agent_id: str = Field(..., description="Agent ID making the request")
    title: str = Field(..., description="PR title")
    body: str = Field(description="PR description", default="")
    source_branch: str = Field(..., description="Source branch")
    target_branch: str = Field(description="Target branch", default="main")
    draft: bool = Field(description="Create as draft", default=False)
    labels: list[str] = Field(description="Labels to add", default_factory=list)
    assignees: list[str] = Field(description="Assignees", default_factory=list)
    reviewers: list[str] = Field(description="Reviewers", default_factory=list)
|
||||
|
||||
|
||||
class CreatePRResult(BaseModel):
    """Outcome of creating a pull request."""

    success: bool = Field(..., description="Whether creation succeeded")
    # Populated only on success.
    pr_number: int | None = Field(description="PR number", default=None)
    pr_url: str | None = Field(description="PR URL", default=None)
    error: str | None = Field(description="Error message if failed", default=None)
|
||||
|
||||
|
||||
class GetPRRequest(BaseModel):
    """Request a single pull request by number."""

    project_id: str = Field(..., description="Project ID for scoping")
    agent_id: str = Field(..., description="Agent ID making the request")
    pr_number: int = Field(..., description="PR number")
|
||||
|
||||
|
||||
class GetPRResult(BaseModel):
    """Outcome of fetching a single pull request."""

    success: bool = Field(..., description="Whether fetch succeeded")
    pr: dict[str, Any] | None = Field(description="PR info", default=None)
    error: str | None = Field(description="Error message if failed", default=None)
|
||||
|
||||
|
||||
class ListPRsRequest(BaseModel):
    """Request a filtered list of pull requests."""

    project_id: str = Field(..., description="Project ID for scoping")
    agent_id: str = Field(..., description="Agent ID making the request")
    state: PRState | None = Field(description="Filter by state", default=None)
    author: str | None = Field(description="Filter by author", default=None)
    limit: int = Field(description="Max PRs to return", default=20, ge=1, le=100)
|
||||
|
||||
|
||||
class ListPRsResult(BaseModel):
    """Outcome of listing pull requests."""

    success: bool = Field(..., description="Whether list succeeded")
    pull_requests: list[dict[str, Any]] = Field(description="PRs", default_factory=list)
    total_count: int = Field(description="Total matching PRs", default=0)
    error: str | None = Field(description="Error message if failed", default=None)
|
||||
|
||||
|
||||
class MergePRRequest(BaseModel):
    """Request to merge a pull request."""

    project_id: str = Field(..., description="Project ID for scoping")
    agent_id: str = Field(..., description="Agent ID making the request")
    pr_number: int = Field(..., description="PR number")
    merge_strategy: MergeStrategy = Field(
        description="Merge strategy", default=MergeStrategy.MERGE
    )
    commit_message: str | None = Field(
        description="Custom merge commit message", default=None
    )
    # Source branch cleanup is opt-out, not opt-in.
    delete_branch: bool = Field(
        description="Delete source branch after merge", default=True
    )
|
||||
|
||||
|
||||
class MergePRResult(BaseModel):
    """Outcome of merging a pull request."""

    success: bool = Field(..., description="Whether merge succeeded")
    merge_commit_sha: str | None = Field(description="Merge commit SHA", default=None)
    branch_deleted: bool = Field(description="Whether branch was deleted", default=False)
    error: str | None = Field(description="Error message if failed", default=None)
|
||||
|
||||
|
||||
class UpdatePRRequest(BaseModel):
    """Request to update a pull request; None fields are left unchanged."""

    project_id: str = Field(..., description="Project ID for scoping")
    agent_id: str = Field(..., description="Agent ID making the request")
    pr_number: int = Field(..., description="PR number")
    title: str | None = Field(description="New title", default=None)
    body: str | None = Field(description="New description", default=None)
    state: PRState | None = Field(description="New state", default=None)
    # When provided, these REPLACE the existing sets rather than append.
    labels: list[str] | None = Field(description="Replace labels", default=None)
    assignees: list[str] | None = Field(description="Replace assignees", default=None)
|
||||
|
||||
|
||||
class UpdatePRResult(BaseModel):
    """Outcome of updating a pull request."""

    success: bool = Field(..., description="Whether update succeeded")
    pr: dict[str, Any] | None = Field(description="Updated PR info", default=None)
    error: str | None = Field(description="Error message if failed", default=None)
|
||||
|
||||
|
||||
# Workspace Operations
|
||||
|
||||
|
||||
class GetWorkspaceRequest(BaseModel):
    """Request to get or create a project workspace."""

    project_id: str = Field(..., description="Project ID")
    agent_id: str = Field(..., description="Agent ID making the request")
|
||||
|
||||
|
||||
class GetWorkspaceResult(BaseModel):
    """Outcome of fetching (or creating) a workspace."""

    success: bool = Field(..., description="Whether operation succeeded")
    workspace: dict[str, Any] | None = Field(description="Workspace info", default=None)
    error: str | None = Field(description="Error message if failed", default=None)
|
||||
|
||||
|
||||
class LockWorkspaceRequest(BaseModel):
    """Request an exclusive lock on a project workspace."""

    project_id: str = Field(..., description="Project ID")
    agent_id: str = Field(..., description="Agent ID requesting lock")
    # Bounded so a crashed agent cannot hold a lock forever.
    timeout: int = Field(
        description="Lock timeout seconds", default=300, ge=10, le=3600
    )
|
||||
|
||||
|
||||
class LockWorkspaceResult(BaseModel):
    """Outcome of a workspace lock attempt."""

    success: bool = Field(..., description="Whether lock acquired")
    # On contention, identifies who currently holds the lock.
    lock_holder: str | None = Field(description="Current lock holder", default=None)
    lock_expires: str | None = Field(description="Lock expiry ISO timestamp", default=None)
    error: str | None = Field(description="Error message if failed", default=None)
|
||||
|
||||
|
||||
class UnlockWorkspaceRequest(BaseModel):
    """Request to release a workspace lock."""

    project_id: str = Field(..., description="Project ID")
    agent_id: str = Field(..., description="Agent ID releasing lock")
    force: bool = Field(description="Force unlock (admin only)", default=False)
|
||||
|
||||
|
||||
class UnlockWorkspaceResult(BaseModel):
    """Outcome of releasing a workspace lock."""

    success: bool = Field(..., description="Whether unlock succeeded")
    error: str | None = Field(description="Error message if failed", default=None)
|
||||
|
||||
|
||||
# Health and Status
|
||||
|
||||
|
||||
class HealthStatus(BaseModel):
    """Server health snapshot, including per-backend connectivity."""

    status: str = Field(..., description="Health status")
    version: str = Field(..., description="Server version")
    workspace_count: int = Field(description="Active workspaces", default=0)
    # Connectivity flags default to False so a partial probe reads as "down".
    gitea_connected: bool = Field(description="Gitea connectivity", default=False)
    github_connected: bool = Field(description="GitHub connectivity", default=False)
    gitlab_connected: bool = Field(description="GitLab connectivity", default=False)
    redis_connected: bool = Field(description="Redis connectivity", default=False)
|
||||
|
||||
|
||||
class ProviderStatus(BaseModel):
    """Connection status for a single git hosting provider."""

    provider: str = Field(..., description="Provider name")
    connected: bool = Field(..., description="Connection status")
    url: str | None = Field(description="Provider URL", default=None)
    user: str | None = Field(description="Authenticated user", default=None)
    error: str | None = Field(description="Error if not connected", default=None)
|
||||
# ============================================================================
# new file: mcp-servers/git-ops/providers/__init__.py (11 lines)
# ============================================================================
|
||||
"""
|
||||
Git provider implementations.
|
||||
|
||||
Provides adapters for different git hosting platforms (Gitea, GitHub, GitLab).
|
||||
"""
|
||||
|
||||
from .base import BaseProvider
|
||||
from .gitea import GiteaProvider
|
||||
from .github import GitHubProvider
|
||||
|
||||
__all__ = ["BaseProvider", "GiteaProvider", "GitHubProvider"]
|
||||
# ============================================================================
# new file: mcp-servers/git-ops/providers/base.py (376 lines)
# ============================================================================
|
||||
"""
|
||||
Base provider interface for git hosting platforms.
|
||||
|
||||
Defines the abstract interface that all git providers must implement.
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any
|
||||
|
||||
from models import (
|
||||
CreatePRResult,
|
||||
GetPRResult,
|
||||
ListPRsResult,
|
||||
MergePRResult,
|
||||
MergeStrategy,
|
||||
PRState,
|
||||
UpdatePRResult,
|
||||
)
|
||||
|
||||
|
||||
class BaseProvider(ABC):
    """Abstract interface for git hosting providers (Gitea, GitHub, GitLab).

    Concrete providers implement connectivity checks, repository queries,
    the pull-request lifecycle, and a handful of API-side branch, comment,
    label, and reviewer helpers.
    """

    @property
    @abstractmethod
    def name(self) -> str:
        """Return the provider name (e.g., 'gitea', 'github')."""
        ...

    @abstractmethod
    async def is_connected(self) -> bool:
        """Check if the provider is connected and authenticated."""
        ...

    @abstractmethod
    async def get_authenticated_user(self) -> str | None:
        """Get the username of the authenticated user."""
        ...

    # Repository operations

    @abstractmethod
    async def get_repo_info(self, owner: str, repo: str) -> dict[str, Any]:
        """Return repository metadata for ``owner/repo``."""
        ...

    @abstractmethod
    async def get_default_branch(self, owner: str, repo: str) -> str:
        """Return the default branch name of ``owner/repo``."""
        ...

    # Pull Request operations

    @abstractmethod
    async def create_pr(
        self,
        owner: str,
        repo: str,
        title: str,
        body: str,
        source_branch: str,
        target_branch: str,
        draft: bool = False,
        labels: list[str] | None = None,
        assignees: list[str] | None = None,
        reviewers: list[str] | None = None,
    ) -> CreatePRResult:
        """Open a pull request from ``source_branch`` into ``target_branch``.

        Optional ``labels``, ``assignees``, and ``reviewers`` are applied to
        the new PR; ``draft`` opens it as a draft where the platform supports
        drafts.

        Returns:
            CreatePRResult with PR number and URL.
        """
        ...

    @abstractmethod
    async def get_pr(self, owner: str, repo: str, pr_number: int) -> GetPRResult:
        """Fetch a single pull request by number."""
        ...

    @abstractmethod
    async def list_prs(
        self,
        owner: str,
        repo: str,
        state: PRState | None = None,
        author: str | None = None,
        limit: int = 20,
    ) -> ListPRsResult:
        """List up to ``limit`` pull requests, optionally filtered by
        ``state`` (open, closed, merged) and ``author``."""
        ...

    @abstractmethod
    async def merge_pr(
        self,
        owner: str,
        repo: str,
        pr_number: int,
        merge_strategy: MergeStrategy = MergeStrategy.MERGE,
        commit_message: str | None = None,
        delete_branch: bool = True,
    ) -> MergePRResult:
        """Merge a pull request using ``merge_strategy``.

        ``commit_message`` overrides the platform's default merge commit
        message; ``delete_branch`` removes the source branch after merging.

        Returns:
            MergePRResult with merge status.
        """
        ...

    @abstractmethod
    async def update_pr(
        self,
        owner: str,
        repo: str,
        pr_number: int,
        title: str | None = None,
        body: str | None = None,
        state: PRState | None = None,
        labels: list[str] | None = None,
        assignees: list[str] | None = None,
    ) -> UpdatePRResult:
        """Update a pull request; ``None`` fields are left untouched.

        ``labels`` and ``assignees``, when provided, replace the existing
        sets.

        Returns:
            UpdatePRResult with updated PR info.
        """
        ...

    @abstractmethod
    async def close_pr(self, owner: str, repo: str, pr_number: int) -> UpdatePRResult:
        """Close a pull request without merging it."""
        ...

    # Branch operations via API (for operations that need to bypass local git)

    @abstractmethod
    async def delete_remote_branch(self, owner: str, repo: str, branch: str) -> bool:
        """Delete ``branch`` on the remote via API; True if deleted."""
        ...

    @abstractmethod
    async def get_branch(
        self, owner: str, repo: str, branch: str
    ) -> dict[str, Any] | None:
        """Return branch info via API, or None when the branch is missing."""
        ...

    # Comment operations

    @abstractmethod
    async def add_pr_comment(
        self, owner: str, repo: str, pr_number: int, body: str
    ) -> dict[str, Any]:
        """Add a comment to a pull request and return the created comment."""
        ...

    @abstractmethod
    async def list_pr_comments(
        self, owner: str, repo: str, pr_number: int
    ) -> list[dict[str, Any]]:
        """Return all comments on a pull request."""
        ...

    # Label operations

    @abstractmethod
    async def add_labels(
        self, owner: str, repo: str, pr_number: int, labels: list[str]
    ) -> list[str]:
        """Add ``labels`` to a PR and return the resulting label list."""
        ...

    @abstractmethod
    async def remove_label(
        self, owner: str, repo: str, pr_number: int, label: str
    ) -> list[str]:
        """Remove ``label`` from a PR and return the resulting label list."""
        ...

    # Reviewer operations

    @abstractmethod
    async def request_review(
        self, owner: str, repo: str, pr_number: int, reviewers: list[str]
    ) -> list[str]:
        """Request reviews from ``reviewers`` and return those requested."""
        ...

    # Utility methods

    def parse_repo_url(self, repo_url: str) -> tuple[str, str]:
        """Extract ``(owner, repo)`` from an HTTPS or SSH repository URL.

        Args:
            repo_url: Repository URL (HTTPS or SSH)

        Returns:
            Tuple of (owner, repo)

        Raises:
            ValueError: If the URL matches neither supported form.
        """
        import re

        patterns = (
            # SSH form: git@host:owner/repo(.git)
            r"git@[^:]+:([^/]+)/([^/]+?)(?:\.git)?$",
            # HTTPS form: http(s)://host/owner/repo(.git)
            r"https?://[^/]+/([^/]+)/([^/]+?)(?:\.git)?$",
        )
        for pattern in patterns:
            match = re.match(pattern, repo_url)
            if match:
                return match.group(1), match.group(2)

        raise ValueError(f"Unable to parse repository URL: {repo_url}")
|
||||
# ============================================================================
# new file: mcp-servers/git-ops/providers/gitea.py (723 lines)
# ============================================================================
|
||||
"""
|
||||
Gitea provider implementation.
|
||||
|
||||
Implements the BaseProvider interface for Gitea API operations.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from datetime import UTC, datetime
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
|
||||
from config import Settings, get_settings
|
||||
from exceptions import (
|
||||
APIError,
|
||||
AuthenticationError,
|
||||
PRNotFoundError,
|
||||
)
|
||||
from models import (
|
||||
CreatePRResult,
|
||||
GetPRResult,
|
||||
ListPRsResult,
|
||||
MergePRResult,
|
||||
MergeStrategy,
|
||||
PRInfo,
|
||||
PRState,
|
||||
UpdatePRResult,
|
||||
)
|
||||
|
||||
from .base import BaseProvider
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class GiteaProvider(BaseProvider):
|
||||
"""
|
||||
Gitea API provider implementation.
|
||||
|
||||
Supports all PR operations, branch operations, and repository queries.
|
||||
"""
|
||||
|
||||
    def __init__(
        self,
        base_url: str | None = None,
        token: str | None = None,
        settings: Settings | None = None,
    ) -> None:
        """
        Initialize Gitea provider.

        Args:
            base_url: Gitea server URL (e.g., https://gitea.example.com);
                falls back to ``settings.gitea_base_url``.
            token: API token; falls back to ``settings.gitea_token``.
            settings: Optional settings override.
        """
        self.settings = settings or get_settings()
        # Strip any trailing slash so "/api/v1" can be appended cleanly.
        self.base_url = (base_url or self.settings.gitea_base_url).rstrip("/")
        self.token = token or self.settings.gitea_token
        # Lazily-created HTTP client (see _get_client) and cached username.
        self._client: httpx.AsyncClient | None = None
        self._user: str | None = None
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
"""Return the provider name."""
|
||||
return "gitea"
|
||||
|
||||
async def _get_client(self) -> httpx.AsyncClient:
|
||||
"""Get or create HTTP client."""
|
||||
if self._client is None:
|
||||
headers = {
|
||||
"Accept": "application/json",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
if self.token:
|
||||
headers["Authorization"] = f"token {self.token}"
|
||||
|
||||
self._client = httpx.AsyncClient(
|
||||
base_url=f"{self.base_url}/api/v1",
|
||||
headers=headers,
|
||||
timeout=30.0,
|
||||
)
|
||||
return self._client
|
||||
|
||||
async def close(self) -> None:
|
||||
"""Close the HTTP client."""
|
||||
if self._client:
|
||||
await self._client.aclose()
|
||||
self._client = None
|
||||
|
||||
async def _request(
|
||||
self,
|
||||
method: str,
|
||||
path: str,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
"""
|
||||
Make an API request.
|
||||
|
||||
Args:
|
||||
method: HTTP method
|
||||
path: API path
|
||||
**kwargs: Additional request arguments
|
||||
|
||||
Returns:
|
||||
Parsed JSON response
|
||||
|
||||
Raises:
|
||||
APIError: On API errors
|
||||
AuthenticationError: On auth failures
|
||||
"""
|
||||
client = await self._get_client()
|
||||
|
||||
try:
|
||||
response = await client.request(method, path, **kwargs)
|
||||
|
||||
if response.status_code == 401:
|
||||
raise AuthenticationError("gitea", "Invalid or expired token")
|
||||
|
||||
if response.status_code == 403:
|
||||
raise AuthenticationError(
|
||||
"gitea", "Insufficient permissions for this operation"
|
||||
)
|
||||
|
||||
if response.status_code == 404:
|
||||
return None
|
||||
|
||||
if response.status_code >= 400:
|
||||
error_msg = response.text
|
||||
try:
|
||||
error_data = response.json()
|
||||
error_msg = error_data.get("message", error_msg)
|
||||
except Exception:
|
||||
pass
|
||||
raise APIError("gitea", response.status_code, error_msg)
|
||||
|
||||
if response.status_code == 204:
|
||||
return None
|
||||
|
||||
return response.json()
|
||||
|
||||
except httpx.RequestError as e:
|
||||
raise APIError("gitea", 0, f"Request failed: {e}")
|
||||
|
||||
async def is_connected(self) -> bool:
|
||||
"""Check if connected to Gitea."""
|
||||
if not self.base_url or not self.token:
|
||||
return False
|
||||
|
||||
try:
|
||||
result = await self._request("GET", "/user")
|
||||
return result is not None
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
async def get_authenticated_user(self) -> str | None:
|
||||
"""Get the authenticated user's username."""
|
||||
if self._user:
|
||||
return self._user
|
||||
|
||||
try:
|
||||
result = await self._request("GET", "/user")
|
||||
if result:
|
||||
self._user = result.get("login") or result.get("username")
|
||||
return self._user
|
||||
except Exception:
|
||||
pass
|
||||
return None
|
||||
|
||||
# Repository operations
|
||||
|
||||
async def get_repo_info(self, owner: str, repo: str) -> dict[str, Any]:
|
||||
"""Get repository information."""
|
||||
result = await self._request("GET", f"/repos/{owner}/{repo}")
|
||||
if result is None:
|
||||
raise APIError("gitea", 404, f"Repository not found: {owner}/{repo}")
|
||||
return result
|
||||
|
||||
async def get_default_branch(self, owner: str, repo: str) -> str:
|
||||
"""Get the default branch for a repository."""
|
||||
repo_info = await self.get_repo_info(owner, repo)
|
||||
return repo_info.get("default_branch", "main")
|
||||
|
||||
# Pull Request operations
|
||||
|
||||
    async def create_pr(
        self,
        owner: str,
        repo: str,
        title: str,
        body: str,
        source_branch: str,
        target_branch: str,
        draft: bool = False,
        labels: list[str] | None = None,
        assignees: list[str] | None = None,
        reviewers: list[str] | None = None,
    ) -> CreatePRResult:
        """Create a pull request.

        Opens the PR first, then applies labels, assignees, and reviewers
        with follow-up API calls (not atomic: the PR may exist even if a
        follow-up call fails). API errors are returned as a failed
        CreatePRResult rather than raised.
        """
        try:
            data: dict[str, Any] = {
                "title": title,
                "body": body,
                "head": source_branch,
                "base": target_branch,
            }

            # Note: Gitea doesn't have draft PR support in all versions
            # Draft support was added in Gitea 1.14+
            # NOTE(review): `draft` is accepted but never sent to the API —
            # confirm whether the target Gitea version should receive it.

            result = await self._request(
                "POST",
                f"/repos/{owner}/{repo}/pulls",
                json=data,
            )

            # _request returns None on 404 (e.g., repo or branch missing).
            if result is None:
                return CreatePRResult(
                    success=False,
                    error="Failed to create pull request",
                )

            pr_number = result["number"]

            # Add labels if specified
            if labels:
                await self.add_labels(owner, repo, pr_number, labels)

            # Add assignees if specified (via issue update)
            # Gitea exposes PR metadata through the issues endpoint.
            if assignees:
                await self._request(
                    "PATCH",
                    f"/repos/{owner}/{repo}/issues/{pr_number}",
                    json={"assignees": assignees},
                )

            # Request reviewers if specified
            if reviewers:
                await self.request_review(owner, repo, pr_number, reviewers)

            return CreatePRResult(
                success=True,
                pr_number=pr_number,
                pr_url=result.get("html_url"),
            )

        except APIError as e:
            return CreatePRResult(
                success=False,
                error=str(e),
            )
|
||||
|
||||
    async def get_pr(self, owner: str, repo: str, pr_number: int) -> GetPRResult:
        """Get a pull request by number.

        Missing PRs and API errors are reported as a failed GetPRResult
        rather than raised to the caller.
        """
        try:
            result = await self._request(
                "GET",
                f"/repos/{owner}/{repo}/pulls/{pr_number}",
            )

            # _request maps 404 to None; translate that into the domain error.
            if result is None:
                raise PRNotFoundError(pr_number, f"{owner}/{repo}")

            # Normalize the raw Gitea payload into the shared PRInfo shape.
            pr_info = self._parse_pr(result)

            return GetPRResult(
                success=True,
                pr=pr_info.to_dict(),
            )

        except PRNotFoundError:
            return GetPRResult(
                success=False,
                error=f"Pull request #{pr_number} not found",
            )
        except APIError as e:
            return GetPRResult(
                success=False,
                error=str(e),
            )
|
||||
|
||||
    async def list_prs(
        self,
        owner: str,
        repo: str,
        state: PRState | None = None,
        author: str | None = None,
        limit: int = 20,
    ) -> ListPRsResult:
        """List pull requests.

        ``state`` is mapped onto Gitea's open/closed/all query values;
        ``author`` and the MERGED state are filtered client-side, so fewer
        than ``limit`` PRs may be returned even when more match.
        """
        try:
            params: dict[str, Any] = {
                "limit": limit,
            }

            if state:
                # Gitea uses different state names
                if state == PRState.OPEN:
                    params["state"] = "open"
                elif state == PRState.CLOSED or state == PRState.MERGED:
                    # Gitea has no "merged" query state; merged PRs are a
                    # subset of closed and are filtered below.
                    params["state"] = "closed"
                else:
                    params["state"] = "all"

            result = await self._request(
                "GET",
                f"/repos/{owner}/{repo}/pulls",
                params=params,
            )

            # A None result (404) is treated as "no PRs", not a failure.
            if result is None:
                return ListPRsResult(
                    success=True,
                    pull_requests=[],
                    total_count=0,
                )

            prs = []
            for pr_data in result:
                # Filter by author if specified
                if author:
                    pr_author = pr_data.get("user", {}).get("login", "")
                    if pr_author.lower() != author.lower():
                        continue

                # Filter merged PRs if looking specifically for merged
                if state == PRState.MERGED:
                    if not pr_data.get("merged"):
                        continue

                pr_info = self._parse_pr(pr_data)
                prs.append(pr_info.to_dict())

            # total_count reflects the filtered page, not the repo-wide total.
            return ListPRsResult(
                success=True,
                pull_requests=prs,
                total_count=len(prs),
            )

        except APIError as e:
            return ListPRsResult(
                success=False,
                error=str(e),
            )
|
||||
|
||||
    async def merge_pr(
        self,
        owner: str,
        repo: str,
        pr_number: int,
        merge_strategy: MergeStrategy = MergeStrategy.MERGE,
        commit_message: str | None = None,
        delete_branch: bool = True,
    ) -> MergePRResult:
        """Merge a pull request.

        Maps ``merge_strategy`` onto Gitea's merge "Do" values and splits a
        multi-line ``commit_message`` into Gitea's title/message fields.
        API errors are returned as a failed MergePRResult rather than raised.
        """
        try:
            # Map merge strategy to Gitea's "Do" values
            # NOTE(review): strategies outside these three would raise
            # KeyError below — confirm MergeStrategy has no other members.
            do_map = {
                MergeStrategy.MERGE: "merge",
                MergeStrategy.SQUASH: "squash",
                MergeStrategy.REBASE: "rebase",
            }

            data: dict[str, Any] = {
                "Do": do_map[merge_strategy],
                "delete_branch_after_merge": delete_branch,
            }

            if commit_message:
                # First line becomes the merge title; the rest (if any)
                # becomes the merge message body.
                data["MergeTitleField"] = commit_message.split("\n")[0]
                if "\n" in commit_message:
                    data["MergeMessageField"] = "\n".join(
                        commit_message.split("\n")[1:]
                    )

            result = await self._request(
                "POST",
                f"/repos/{owner}/{repo}/pulls/{pr_number}/merge",
                json=data,
            )

            if result is None:
                # Check if PR was actually merged
                # (Gitea may return an empty body even on success.)
                pr_result = await self.get_pr(owner, repo, pr_number)
                if pr_result.success and pr_result.pr:
                    if pr_result.pr.get("state") == "merged":
                        return MergePRResult(
                            success=True,
                            branch_deleted=delete_branch,
                        )

                return MergePRResult(
                    success=False,
                    error="Failed to merge pull request",
                )

            return MergePRResult(
                success=True,
                merge_commit_sha=result.get("sha"),
                branch_deleted=delete_branch,
            )

        except APIError as e:
            return MergePRResult(
                success=False,
                error=str(e),
            )
|
||||
|
||||
async def update_pr(
    self,
    owner: str,
    repo: str,
    pr_number: int,
    title: str | None = None,
    body: str | None = None,
    state: PRState | None = None,
    labels: list[str] | None = None,
    assignees: list[str] | None = None,
) -> UpdatePRResult:
    """Update a pull request's fields, labels, and assignees.

    Only non-None arguments are applied; the refreshed PR is returned.
    """
    try:
        patch: dict[str, Any] = {}
        if title is not None:
            patch["title"] = title
        if body is not None:
            patch["body"] = body
        if state == PRState.OPEN:
            patch["state"] = "open"
        elif state == PRState.CLOSED:
            patch["state"] = "closed"

        if patch:
            await self._request(
                "PATCH",
                f"/repos/{owner}/{repo}/pulls/{pr_number}",
                json=patch,
            )

        if labels is not None:
            # Labels live on the issue endpoint: clear, then re-add.
            await self._request(
                "DELETE",
                f"/repos/{owner}/{repo}/issues/{pr_number}/labels",
            )
            if labels:
                await self.add_labels(owner, repo, pr_number, labels)

        if assignees is not None:
            # Assignees are also managed through the issue endpoint.
            await self._request(
                "PATCH",
                f"/repos/{owner}/{repo}/issues/{pr_number}",
                json={"assignees": assignees},
            )

        refreshed = await self.get_pr(owner, repo, pr_number)
        return UpdatePRResult(
            success=refreshed.success,
            pr=refreshed.pr,
            error=refreshed.error,
        )

    except APIError as e:
        return UpdatePRResult(success=False, error=str(e))
|
||||
|
||||
async def close_pr(
    self,
    owner: str,
    repo: str,
    pr_number: int,
) -> UpdatePRResult:
    """Close a pull request without merging it."""
    # Closing is just an update that flips the state to CLOSED.
    return await self.update_pr(owner, repo, pr_number, state=PRState.CLOSED)
|
||||
|
||||
# Branch operations
|
||||
|
||||
async def delete_remote_branch(
    self,
    owner: str,
    repo: str,
    branch: str,
) -> bool:
    """Delete a remote branch.

    Returns:
        True when the API call succeeds, False on any APIError.
    """
    try:
        await self._request(
            "DELETE",
            f"/repos/{owner}/{repo}/branches/{branch}",
        )
    except APIError:
        return False
    return True
|
||||
|
||||
async def get_branch(
    self,
    owner: str,
    repo: str,
    branch: str,
) -> dict[str, Any] | None:
    """Fetch branch metadata, or None when the branch does not exist."""
    endpoint = f"/repos/{owner}/{repo}/branches/{branch}"
    return await self._request("GET", endpoint)
|
||||
|
||||
# Comment operations
|
||||
|
||||
async def add_pr_comment(
    self,
    owner: str,
    repo: str,
    pr_number: int,
    body: str,
) -> dict[str, Any]:
    """Post a comment on a pull request (PR comments use the issues API)."""
    created = await self._request(
        "POST",
        f"/repos/{owner}/{repo}/issues/{pr_number}/comments",
        json={"body": body},
    )
    if created:
        return created
    return {}
|
||||
|
||||
async def list_pr_comments(
    self,
    owner: str,
    repo: str,
    pr_number: int,
) -> list[dict[str, Any]]:
    """Return all comments on a pull request (empty list when none)."""
    comments = await self._request(
        "GET",
        f"/repos/{owner}/{repo}/issues/{pr_number}/comments",
    )
    if comments:
        return comments
    return []
|
||||
|
||||
# Label operations
|
||||
|
||||
async def add_labels(
    self,
    owner: str,
    repo: str,
    pr_number: int,
    labels: list[str],
) -> list[str]:
    """Add labels to a pull request, creating missing labels on the fly.

    Returns:
        The label names currently on the PR, or the requested names when
        the issue cannot be re-read.
    """
    # Gitea's label endpoint takes numeric label IDs, not names.
    resolved_ids = []
    for name in labels:
        label_id = await self._get_or_create_label(owner, repo, name)
        if label_id:
            resolved_ids.append(label_id)

    if resolved_ids:
        await self._request(
            "POST",
            f"/repos/{owner}/{repo}/issues/{pr_number}/labels",
            json={"labels": resolved_ids},
        )

    # Report the labels now attached to the issue.
    issue = await self._request(
        "GET",
        f"/repos/{owner}/{repo}/issues/{pr_number}",
    )
    if not issue:
        return labels
    return [entry["name"] for entry in issue.get("labels", [])]
|
||||
|
||||
async def remove_label(
    self,
    owner: str,
    repo: str,
    pr_number: int,
    label: str,
) -> list[str]:
    """Remove a label from a pull request.

    Args:
        owner: Repository owner.
        repo: Repository name.
        pr_number: Pull request number.
        label: Label name to remove.

    Returns:
        The label names remaining on the PR (empty when the issue
        cannot be re-read).
    """
    from urllib.parse import quote

    # Resolve the label name to its ID. The name is URL-encoded so
    # labels containing spaces or special characters still resolve.
    label_info = await self._request(
        "GET",
        f"/repos/{owner}/{repo}/labels?name={quote(label)}",
    )

    if label_info and len(label_info) > 0:
        label_id = label_info[0]["id"]
        await self._request(
            "DELETE",
            f"/repos/{owner}/{repo}/issues/{pr_number}/labels/{label_id}",
        )

    # Return remaining labels
    issue = await self._request(
        "GET",
        f"/repos/{owner}/{repo}/issues/{pr_number}",
    )
    if issue:
        return [lbl["name"] for lbl in issue.get("labels", [])]
    return []
|
||||
|
||||
async def _get_or_create_label(
    self,
    owner: str,
    repo: str,
    label_name: str,
) -> int | None:
    """Resolve a label name to its ID, creating the label when missing.

    Returns:
        The label ID, or None when creation fails.
    """
    # Case-insensitive search among the repository's existing labels.
    existing = await self._request(
        "GET",
        f"/repos/{owner}/{repo}/labels",
    )
    wanted = label_name.lower()
    for candidate in existing or []:
        if candidate["name"].lower() == wanted:
            return candidate["id"]

    # Not found: create it with a default blue color. Failures are
    # swallowed so label handling stays best-effort.
    try:
        created = await self._request(
            "POST",
            f"/repos/{owner}/{repo}/labels",
            json={
                "name": label_name,
                "color": "#3B82F6",  # Default blue
            },
        )
    except APIError:
        return None
    if created:
        return created["id"]
    return None
|
||||
|
||||
# Reviewer operations
|
||||
|
||||
async def request_review(
    self,
    owner: str,
    repo: str,
    pr_number: int,
    reviewers: list[str],
) -> list[str]:
    """Request reviews from the given users; returns the requested list."""
    endpoint = f"/repos/{owner}/{repo}/pulls/{pr_number}/requested_reviewers"
    await self._request("POST", endpoint, json={"reviewers": reviewers})
    return reviewers
|
||||
|
||||
# Helper methods
|
||||
|
||||
def _parse_pr(self, data: dict[str, Any]) -> PRInfo:
    """Parse a Gitea PR API response into a PRInfo.

    Args:
        data: Raw PR JSON from the Gitea API.

    Returns:
        A populated PRInfo.
    """
    # Parse dates (missing/invalid values fall back inside the helper).
    created_at = self._parse_datetime(data.get("created_at"))
    updated_at = self._parse_datetime(data.get("updated_at"))
    merged_at = self._parse_datetime(data.get("merged_at"))
    closed_at = self._parse_datetime(data.get("closed_at"))

    # Merged takes precedence over closed; everything else is open.
    if data.get("merged"):
        state = PRState.MERGED
    elif data.get("state") == "closed":
        state = PRState.CLOSED
    else:
        state = PRState.OPEN

    # Guard list/dict fields against explicit JSON nulls: the API may
    # send null rather than omitting the key, and `get(key, default)`
    # does not protect against that.
    labels = [lbl["name"] for lbl in data.get("labels") or []]
    assignees = [a["login"] for a in data.get("assignees") or []]
    reviewers = [r["login"] for r in data.get("requested_reviewers") or []]

    return PRInfo(
        number=data["number"],
        title=data["title"],
        # Body may be an explicit null; normalize to "" for consistency
        # with the GitHub provider.
        body=data.get("body", "") or "",
        state=state,
        source_branch=(data.get("head") or {}).get("ref", ""),
        target_branch=(data.get("base") or {}).get("ref", ""),
        author=(data.get("user") or {}).get("login", ""),
        created_at=created_at,
        updated_at=updated_at,
        merged_at=merged_at,
        closed_at=closed_at,
        url=data.get("html_url"),
        labels=labels,
        assignees=assignees,
        reviewers=reviewers,
        mergeable=data.get("mergeable"),
        draft=data.get("draft", False),
    )
|
||||
|
||||
def _parse_datetime(self, value: str | None) -> datetime:
|
||||
"""Parse datetime string from API."""
|
||||
if not value:
|
||||
return datetime.now(UTC)
|
||||
|
||||
try:
|
||||
# Handle Gitea's datetime format
|
||||
if value.endswith("Z"):
|
||||
value = value[:-1] + "+00:00"
|
||||
return datetime.fromisoformat(value)
|
||||
except ValueError:
|
||||
return datetime.now(UTC)
|
||||
675
mcp-servers/git-ops/providers/github.py
Normal file
675
mcp-servers/git-ops/providers/github.py
Normal file
@@ -0,0 +1,675 @@
|
||||
"""
|
||||
GitHub provider implementation.
|
||||
|
||||
Implements the BaseProvider interface for GitHub API operations.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from datetime import UTC, datetime
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
|
||||
from config import Settings, get_settings
|
||||
from exceptions import (
|
||||
APIError,
|
||||
AuthenticationError,
|
||||
PRNotFoundError,
|
||||
)
|
||||
from models import (
|
||||
CreatePRResult,
|
||||
GetPRResult,
|
||||
ListPRsResult,
|
||||
MergePRResult,
|
||||
MergeStrategy,
|
||||
PRInfo,
|
||||
PRState,
|
||||
UpdatePRResult,
|
||||
)
|
||||
|
||||
from .base import BaseProvider
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class GitHubProvider(BaseProvider):
    """
    GitHub API provider implementation.

    Supports all PR operations, branch operations, and repository queries.
    """

    def __init__(
        self,
        token: str | None = None,
        settings: Settings | None = None,
    ) -> None:
        """
        Initialize GitHub provider.

        Args:
            token: GitHub personal access token or fine-grained token.
            settings: Optional settings override; defaults to the
                process-wide settings.
        """
        self.settings = settings or get_settings()
        # Explicit token wins; otherwise fall back to configured token.
        self.token = token or self.settings.github_token
        # HTTP client is created lazily; cached authenticated username.
        self._client: httpx.AsyncClient | None = None
        self._user: str | None = None

    @property
    def name(self) -> str:
        """Return the provider name."""
        return "github"
|
||||
|
||||
async def _get_client(self) -> httpx.AsyncClient:
    """Return the lazily-created HTTP client for api.github.com."""
    if self._client is None:
        request_headers = {
            "Accept": "application/vnd.github+json",
            "X-GitHub-Api-Version": "2022-11-28",
        }
        # Unauthenticated requests are allowed (with lower rate limits).
        if self.token:
            request_headers["Authorization"] = f"Bearer {self.token}"

        self._client = httpx.AsyncClient(
            base_url="https://api.github.com",
            headers=request_headers,
            timeout=30.0,
        )
    return self._client

async def close(self) -> None:
    """Dispose of the HTTP client, if one was created."""
    if self._client:
        await self._client.aclose()
        self._client = None
|
||||
|
||||
async def _request(
    self,
    method: str,
    path: str,
    **kwargs: Any,
) -> Any:
    """
    Make an API request.

    Args:
        method: HTTP method
        path: API path
        **kwargs: Additional request arguments

    Returns:
        Parsed JSON response; None for 404 (not found) and 204 (no content)

    Raises:
        APIError: On API errors
        AuthenticationError: On auth failures
    """
    client = await self._get_client()

    try:
        response = await client.request(method, path, **kwargs)

        if response.status_code == 401:
            raise AuthenticationError("github", "Invalid or expired token")

        if response.status_code == 403:
            # GitHub uses 403 both for rate limiting and for permission
            # problems; disambiguate by inspecting the body.
            if "rate limit" in response.text.lower():
                raise APIError("github", 403, "GitHub API rate limit exceeded")
            raise AuthenticationError(
                "github", "Insufficient permissions for this operation"
            )

        if response.status_code == 404:
            return None

        if response.status_code >= 400:
            # Prefer GitHub's structured "message" field when present.
            error_msg = response.text
            try:
                error_data = response.json()
                error_msg = error_data.get("message", error_msg)
            except Exception:
                pass
            raise APIError("github", response.status_code, error_msg)

        if response.status_code == 204:
            return None

        return response.json()

    except httpx.RequestError as e:
        # Chain the transport error so the original cause is preserved
        # in tracebacks instead of being reported as "another exception
        # occurred during handling".
        raise APIError("github", 0, f"Request failed: {e}") from e
|
||||
|
||||
async def is_connected(self) -> bool:
    """Return True when a token is configured and /user responds."""
    if not self.token:
        return False
    try:
        return await self._request("GET", "/user") is not None
    except Exception:
        return False

async def get_authenticated_user(self) -> str | None:
    """Return (and cache) the authenticated user's login, if available."""
    if self._user:
        return self._user

    try:
        profile = await self._request("GET", "/user")
        if profile:
            self._user = profile.get("login")
            return self._user
    except Exception:
        # Best-effort lookup: any failure simply yields no username.
        pass
    return None
|
||||
|
||||
# Repository operations
|
||||
|
||||
async def get_repo_info(self, owner: str, repo: str) -> dict[str, Any]:
    """Fetch repository metadata; raise APIError when it does not exist."""
    info = await self._request("GET", f"/repos/{owner}/{repo}")
    if info is None:
        raise APIError("github", 404, f"Repository not found: {owner}/{repo}")
    return info

async def get_default_branch(self, owner: str, repo: str) -> str:
    """Return the repository's default branch, falling back to 'main'."""
    info = await self.get_repo_info(owner, repo)
    return info.get("default_branch", "main")
|
||||
|
||||
# Pull Request operations
|
||||
|
||||
async def create_pr(
    self,
    owner: str,
    repo: str,
    title: str,
    body: str,
    source_branch: str,
    target_branch: str,
    draft: bool = False,
    labels: list[str] | None = None,
    assignees: list[str] | None = None,
    reviewers: list[str] | None = None,
) -> CreatePRResult:
    """Create a pull request, then apply labels/assignees/reviewers."""
    try:
        payload: dict[str, Any] = {
            "title": title,
            "body": body,
            "head": source_branch,
            "base": target_branch,
            "draft": draft,
        }

        created = await self._request(
            "POST",
            f"/repos/{owner}/{repo}/pulls",
            json=payload,
        )
        if created is None:
            return CreatePRResult(
                success=False,
                error="Failed to create pull request",
            )

        number = created["number"]

        # Post-creation decorations are optional and applied one by one.
        if labels:
            await self.add_labels(owner, repo, number, labels)

        if assignees:
            await self._request(
                "POST",
                f"/repos/{owner}/{repo}/issues/{number}/assignees",
                json={"assignees": assignees},
            )

        if reviewers:
            await self.request_review(owner, repo, number, reviewers)

        return CreatePRResult(
            success=True,
            pr_number=number,
            pr_url=created.get("html_url"),
        )

    except APIError as e:
        return CreatePRResult(success=False, error=str(e))
|
||||
|
||||
async def get_pr(self, owner: str, repo: str, pr_number: int) -> GetPRResult:
    """Fetch a single pull request by number."""
    try:
        raw = await self._request(
            "GET",
            f"/repos/{owner}/{repo}/pulls/{pr_number}",
        )
        if raw is None:
            raise PRNotFoundError(pr_number, f"{owner}/{repo}")

        return GetPRResult(
            success=True,
            pr=self._parse_pr(raw).to_dict(),
        )

    except PRNotFoundError:
        return GetPRResult(
            success=False,
            error=f"Pull request #{pr_number} not found",
        )
    except APIError as e:
        return GetPRResult(success=False, error=str(e))
|
||||
|
||||
async def list_prs(
    self,
    owner: str,
    repo: str,
    state: PRState | None = None,
    author: str | None = None,
    limit: int = 20,
) -> ListPRsResult:
    """List pull requests, optionally filtered by state and author."""
    try:
        query: dict[str, Any] = {
            "per_page": min(limit, 100),  # GitHub max is 100
        }

        if state:
            # GitHub's API only distinguishes open/closed; merged PRs
            # are closed PRs with merged_at set (filtered client-side).
            if state == PRState.OPEN:
                query["state"] = "open"
            elif state in (PRState.CLOSED, PRState.MERGED):
                query["state"] = "closed"
            else:
                query["state"] = "all"

        raw = await self._request(
            "GET",
            f"/repos/{owner}/{repo}/pulls",
            params=query,
        )
        if raw is None:
            return ListPRsResult(
                success=True,
                pull_requests=[],
                total_count=0,
            )

        matched = []
        for item in raw:
            # Client-side author filter (case-insensitive).
            if author:
                login = item.get("user", {}).get("login", "")
                if login.lower() != author.lower():
                    continue
            # Keep only truly merged PRs when MERGED was requested.
            if state == PRState.MERGED and not item.get("merged_at"):
                continue
            matched.append(self._parse_pr(item).to_dict())

        return ListPRsResult(
            success=True,
            pull_requests=matched,
            total_count=len(matched),
        )

    except APIError as e:
        return ListPRsResult(success=False, error=str(e))
|
||||
|
||||
async def merge_pr(
    self,
    owner: str,
    repo: str,
    pr_number: int,
    merge_strategy: MergeStrategy = MergeStrategy.MERGE,
    commit_message: str | None = None,
    delete_branch: bool = True,
) -> MergePRResult:
    """Merge a pull request, optionally deleting its source branch."""
    try:
        # GitHub expects the strategy under "merge_method".
        strategy_names = {
            MergeStrategy.MERGE: "merge",
            MergeStrategy.SQUASH: "squash",
            MergeStrategy.REBASE: "rebase",
        }

        payload: dict[str, Any] = {
            "merge_method": strategy_names[merge_strategy],
        }

        if commit_message:
            # First line becomes commit_title; the rest commit_message.
            first_line, sep, remainder = commit_message.partition("\n")
            payload["commit_title"] = first_line
            if sep:
                payload["commit_message"] = remainder

        merged = await self._request(
            "PUT",
            f"/repos/{owner}/{repo}/pulls/{pr_number}/merge",
            json=payload,
        )
        if merged is None:
            return MergePRResult(
                success=False,
                error="Failed to merge pull request",
            )

        branch_deleted = False
        if delete_branch and merged.get("merged"):
            # The merge response doesn't echo the branch name, so
            # re-fetch the PR to discover it before deleting.
            pr_result = await self.get_pr(owner, repo, pr_number)
            if pr_result.success and pr_result.pr:
                head_branch = pr_result.pr.get("source_branch")
                if head_branch:
                    branch_deleted = await self.delete_remote_branch(
                        owner, repo, head_branch
                    )

        return MergePRResult(
            success=True,
            merge_commit_sha=merged.get("sha"),
            branch_deleted=branch_deleted,
        )

    except APIError as e:
        return MergePRResult(success=False, error=str(e))
|
||||
|
||||
async def update_pr(
    self,
    owner: str,
    repo: str,
    pr_number: int,
    title: str | None = None,
    body: str | None = None,
    state: PRState | None = None,
    labels: list[str] | None = None,
    assignees: list[str] | None = None,
) -> UpdatePRResult:
    """Update a pull request's fields, labels, and assignees.

    Only non-None arguments are applied; the refreshed PR is returned.
    """
    try:
        patch: dict[str, Any] = {}
        if title is not None:
            patch["title"] = title
        if body is not None:
            patch["body"] = body
        if state == PRState.OPEN:
            patch["state"] = "open"
        elif state == PRState.CLOSED:
            patch["state"] = "closed"

        if patch:
            await self._request(
                "PATCH",
                f"/repos/{owner}/{repo}/pulls/{pr_number}",
                json=patch,
            )

        if labels is not None:
            # PUT replaces the entire label set in one call.
            await self._request(
                "PUT",
                f"/repos/{owner}/{repo}/issues/{pr_number}/labels",
                json={"labels": labels},
            )

        if assignees is not None:
            # Clear current assignees, then assign the new set.
            await self._request(
                "DELETE",
                f"/repos/{owner}/{repo}/issues/{pr_number}/assignees",
                json={"assignees": []},
            )
            if assignees:
                await self._request(
                    "POST",
                    f"/repos/{owner}/{repo}/issues/{pr_number}/assignees",
                    json={"assignees": assignees},
                )

        refreshed = await self.get_pr(owner, repo, pr_number)
        return UpdatePRResult(
            success=refreshed.success,
            pr=refreshed.pr,
            error=refreshed.error,
        )

    except APIError as e:
        return UpdatePRResult(success=False, error=str(e))
|
||||
|
||||
async def close_pr(
    self,
    owner: str,
    repo: str,
    pr_number: int,
) -> UpdatePRResult:
    """Close a pull request without merging it."""
    # Closing is just an update that flips the state to CLOSED.
    return await self.update_pr(owner, repo, pr_number, state=PRState.CLOSED)
|
||||
|
||||
# Branch operations
|
||||
|
||||
async def delete_remote_branch(
    self,
    owner: str,
    repo: str,
    branch: str,
) -> bool:
    """Delete a remote branch.

    Returns:
        True when the API call succeeds, False on any APIError.
    """
    try:
        await self._request(
            "DELETE",
            f"/repos/{owner}/{repo}/git/refs/heads/{branch}",
        )
    except APIError:
        return False
    return True

async def get_branch(
    self,
    owner: str,
    repo: str,
    branch: str,
) -> dict[str, Any] | None:
    """Fetch branch metadata, or None when the branch does not exist."""
    endpoint = f"/repos/{owner}/{repo}/branches/{branch}"
    return await self._request("GET", endpoint)
|
||||
|
||||
# Comment operations
|
||||
|
||||
async def add_pr_comment(
    self,
    owner: str,
    repo: str,
    pr_number: int,
    body: str,
) -> dict[str, Any]:
    """Post a comment on a pull request (PR comments use the issues API)."""
    created = await self._request(
        "POST",
        f"/repos/{owner}/{repo}/issues/{pr_number}/comments",
        json={"body": body},
    )
    if created:
        return created
    return {}

async def list_pr_comments(
    self,
    owner: str,
    repo: str,
    pr_number: int,
) -> list[dict[str, Any]]:
    """Return all comments on a pull request (empty list when none)."""
    comments = await self._request(
        "GET",
        f"/repos/{owner}/{repo}/issues/{pr_number}/comments",
    )
    if comments:
        return comments
    return []
|
||||
|
||||
# Label operations
|
||||
|
||||
async def add_labels(
    self,
    owner: str,
    repo: str,
    pr_number: int,
    labels: list[str],
) -> list[str]:
    """Add labels to a pull request.

    GitHub creates labels automatically if they don't exist (unlike Gitea).

    Returns:
        The label names reported back by the API, or the requested
        names when the response is empty.
    """
    result = await self._request(
        "POST",
        f"/repos/{owner}/{repo}/issues/{pr_number}/labels",
        json={"labels": labels},
    )

    if result:
        return [lbl["name"] for lbl in result]
    return labels

async def remove_label(
    self,
    owner: str,
    repo: str,
    pr_number: int,
    label: str,
) -> list[str]:
    """Remove a label from a pull request.

    Returns:
        The label names remaining on the PR (empty when the issue
        cannot be re-read).
    """
    from urllib.parse import quote

    # URL-encode the label name: labels may contain spaces, slashes, or
    # other characters that would otherwise corrupt the request path.
    await self._request(
        "DELETE",
        f"/repos/{owner}/{repo}/issues/{pr_number}/labels/{quote(label)}",
    )

    # Return remaining labels
    issue = await self._request(
        "GET",
        f"/repos/{owner}/{repo}/issues/{pr_number}",
    )
    if issue:
        return [lbl["name"] for lbl in issue.get("labels", [])]
    return []
|
||||
|
||||
# Reviewer operations
|
||||
|
||||
async def request_review(
    self,
    owner: str,
    repo: str,
    pr_number: int,
    reviewers: list[str],
) -> list[str]:
    """Request reviews from the given users; returns the requested list."""
    endpoint = f"/repos/{owner}/{repo}/pulls/{pr_number}/requested_reviewers"
    await self._request("POST", endpoint, json={"reviewers": reviewers})
    return reviewers
|
||||
|
||||
# Helper methods
|
||||
|
||||
def _parse_pr(self, data: dict[str, Any]) -> PRInfo:
    """Parse a GitHub PR API response into a PRInfo.

    Args:
        data: Raw PR JSON from the GitHub API.

    Returns:
        A populated PRInfo.
    """
    # Parse dates (missing/invalid values fall back inside the helper).
    created_at = self._parse_datetime(data.get("created_at"))
    updated_at = self._parse_datetime(data.get("updated_at"))
    merged_at = self._parse_datetime(data.get("merged_at"))
    closed_at = self._parse_datetime(data.get("closed_at"))

    # GitHub signals "merged" via a non-null merged_at timestamp.
    if data.get("merged_at"):
        state = PRState.MERGED
    elif data.get("state") == "closed":
        state = PRState.CLOSED
    else:
        state = PRState.OPEN

    # Guard list/dict fields against explicit JSON nulls: the API may
    # send null rather than omitting the key, and `get(key, default)`
    # does not protect against that.
    labels = [lbl["name"] for lbl in data.get("labels") or []]
    assignees = [a["login"] for a in data.get("assignees") or []]
    reviewers = [r["login"] for r in data.get("requested_reviewers") or []]

    return PRInfo(
        number=data["number"],
        title=data["title"],
        # Body is frequently an explicit null on GitHub; normalize to "".
        body=data.get("body", "") or "",
        state=state,
        source_branch=(data.get("head") or {}).get("ref", ""),
        target_branch=(data.get("base") or {}).get("ref", ""),
        author=(data.get("user") or {}).get("login", ""),
        created_at=created_at,
        updated_at=updated_at,
        merged_at=merged_at,
        closed_at=closed_at,
        url=data.get("html_url"),
        labels=labels,
        assignees=assignees,
        reviewers=reviewers,
        mergeable=data.get("mergeable"),
        draft=data.get("draft", False),
    )
|
||||
|
||||
def _parse_datetime(self, value: str | None) -> datetime:
|
||||
"""Parse datetime string from API."""
|
||||
if not value:
|
||||
return datetime.now(UTC)
|
||||
|
||||
try:
|
||||
# GitHub uses ISO 8601 format with Z suffix
|
||||
if value.endswith("Z"):
|
||||
value = value[:-1] + "+00:00"
|
||||
return datetime.fromisoformat(value)
|
||||
except ValueError:
|
||||
return datetime.now(UTC)
|
||||
120
mcp-servers/git-ops/pyproject.toml
Normal file
120
mcp-servers/git-ops/pyproject.toml
Normal file
@@ -0,0 +1,120 @@
|
||||
[project]
|
||||
name = "syndarix-mcp-git-ops"
|
||||
version = "0.1.0"
|
||||
description = "Syndarix Git Operations MCP Server - Repository management, branching, commits, and PR workflows"
|
||||
requires-python = ">=3.12"
|
||||
dependencies = [
|
||||
"fastmcp>=2.0.0",
|
||||
"gitpython>=3.1.0",
|
||||
"httpx>=0.27.0",
|
||||
"redis>=5.0.0",
|
||||
"pydantic>=2.0.0",
|
||||
"pydantic-settings>=2.0.0",
|
||||
"uvicorn>=0.30.0",
|
||||
"fastapi>=0.115.0",
|
||||
"filelock>=3.15.0",
|
||||
"aiofiles>=24.1.0",
|
||||
]
|
||||
|
||||
[project.optional-dependencies]
|
||||
dev = [
|
||||
"pytest>=8.0.0",
|
||||
"pytest-asyncio>=0.24.0",
|
||||
"pytest-cov>=5.0.0",
|
||||
"fakeredis>=2.25.0",
|
||||
"ruff>=0.8.0",
|
||||
"mypy>=1.11.0",
|
||||
"respx>=0.21.0",
|
||||
]
|
||||
|
||||
[project.scripts]
|
||||
git-ops = "server:main"
|
||||
|
||||
[build-system]
|
||||
requires = ["hatchling"]
|
||||
build-backend = "hatchling.build"
|
||||
|
||||
[tool.hatch.build.targets.wheel]
|
||||
packages = ["."]
|
||||
exclude = ["tests/", "*.md", "Dockerfile"]
|
||||
|
||||
[tool.hatch.build.targets.sdist]
|
||||
include = ["*.py", "pyproject.toml"]
|
||||
|
||||
[tool.ruff]
|
||||
target-version = "py312"
|
||||
line-length = 88
|
||||
|
||||
[tool.ruff.lint]
|
||||
select = [
|
||||
"E", # pycodestyle errors
|
||||
"W", # pycodestyle warnings
|
||||
"F", # pyflakes
|
||||
"I", # isort
|
||||
"B", # flake8-bugbear
|
||||
"C4", # flake8-comprehensions
|
||||
"UP", # pyupgrade
|
||||
"ARG", # flake8-unused-arguments
|
||||
"SIM", # flake8-simplify
|
||||
"S", # flake8-bandit (security)
|
||||
]
|
||||
ignore = [
|
||||
"E501", # line too long (handled by formatter)
|
||||
"B008", # do not perform function calls in argument defaults
|
||||
"B904", # raise from in except (too noisy)
|
||||
"S104", # possible binding to all interfaces
|
||||
"S110", # try-except-pass (intentional for optional operations)
|
||||
"S603", # subprocess without shell=True (safe usage in git wrapper)
|
||||
"S607", # starting a process with a partial path (git CLI)
|
||||
"ARG002", # unused method arguments (for API compatibility)
|
||||
"SIM102", # nested if statements (sometimes more readable)
|
||||
"SIM105", # contextlib.suppress (sometimes more readable)
|
||||
"SIM108", # ternary operator (sometimes more readable)
|
||||
"SIM118", # dict.keys() (explicit is fine)
|
||||
]
|
||||
|
||||
[tool.ruff.lint.isort]
|
||||
known-first-party = ["config", "models", "exceptions", "git_wrapper", "workspace", "providers"]
|
||||
|
||||
[tool.ruff.lint.per-file-ignores]
|
||||
"tests/**/*.py" = ["S101", "ARG001", "S105", "S106", "S108", "F841", "B007"]
|
||||
|
||||
[tool.pytest.ini_options]
|
||||
asyncio_mode = "auto"
|
||||
asyncio_default_fixture_loop_scope = "function"
|
||||
testpaths = ["tests"]
|
||||
addopts = "-v --tb=short"
|
||||
filterwarnings = [
|
||||
"ignore::DeprecationWarning",
|
||||
]
|
||||
|
||||
[tool.coverage.run]
|
||||
source = ["."]
|
||||
omit = ["tests/*", "conftest.py"]
|
||||
branch = true
|
||||
|
||||
[tool.coverage.report]
|
||||
exclude_lines = [
|
||||
"pragma: no cover",
|
||||
"def __repr__",
|
||||
"raise NotImplementedError",
|
||||
"if TYPE_CHECKING:",
|
||||
"if __name__ == .__main__.:",
|
||||
]
|
||||
fail_under = 78
|
||||
show_missing = true
|
||||
|
||||
[tool.mypy]
|
||||
python_version = "3.12"
|
||||
warn_return_any = false
|
||||
warn_unused_ignores = false
|
||||
disallow_untyped_defs = true
|
||||
ignore_missing_imports = true
|
||||
plugins = ["pydantic.mypy"]
|
||||
files = ["server.py", "config.py", "models.py", "exceptions.py", "git_wrapper.py", "workspace.py", "providers/"]
|
||||
exclude = ["tests/"]
|
||||
|
||||
[[tool.mypy.overrides]]
|
||||
module = "tests.*"
|
||||
disallow_untyped_defs = false
|
||||
ignore_errors = true
|
||||
1674
mcp-servers/git-ops/server.py
Normal file
1674
mcp-servers/git-ops/server.py
Normal file
File diff suppressed because it is too large
Load Diff
1
mcp-servers/git-ops/tests/__init__.py
Normal file
1
mcp-servers/git-ops/tests/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""Tests for Git Operations MCP Server."""
|
||||
297
mcp-servers/git-ops/tests/conftest.py
Normal file
297
mcp-servers/git-ops/tests/conftest.py
Normal file
@@ -0,0 +1,297 @@
|
||||
"""
|
||||
Test configuration and fixtures for Git Operations MCP Server.
|
||||
"""
|
||||
|
||||
import os
|
||||
import shutil
|
||||
import tempfile
|
||||
from collections.abc import AsyncIterator, Iterator
|
||||
from pathlib import Path
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
from git import Repo as GitRepo
|
||||
|
||||
# Set test environment
|
||||
os.environ["IS_TEST"] = "true"
|
||||
os.environ["GIT_OPS_WORKSPACE_BASE_PATH"] = "/tmp/test-workspaces"
|
||||
os.environ["GIT_OPS_GITEA_BASE_URL"] = "https://gitea.test.com"
|
||||
os.environ["GIT_OPS_GITEA_TOKEN"] = "test-token"
|
||||
|
||||
|
||||
@pytest.fixture(scope="session", autouse=True)
def reset_settings_session():
    """Start and end the test session with freshly-reset settings."""
    from config import reset_settings as _reset

    _reset()
    yield
    _reset()
|
||||
|
||||
|
||||
@pytest.fixture
def reset_settings():
    """Clear cached settings around a single test that opts into this fixture."""
    from config import reset_settings as _reset

    _reset()
    yield
    _reset()
|
||||
|
||||
|
||||
@pytest.fixture
def test_settings():
    """Build a Settings instance with explicit, deterministic test values.

    Constructed directly rather than from the environment so these fields
    do not depend on ambient configuration.
    """
    from config import Settings

    return Settings(
        workspace_base_path=Path("/tmp/test-workspaces"),
        gitea_base_url="https://gitea.test.com",
        gitea_token="test-token",
        github_token="github-test-token",
        git_author_name="Test Agent",
        git_author_email="test@syndarix.ai",
        enable_force_push=False,  # force-push disabled by default in tests
        debug=True,
    )
|
||||
|
||||
|
||||
@pytest.fixture
def temp_dir() -> Iterator[Path]:
    """Yield a fresh temporary directory, removed again after the test."""
    path = Path(tempfile.mkdtemp())
    yield path
    if path.exists():
        shutil.rmtree(path)
|
||||
|
||||
|
||||
@pytest.fixture
def temp_workspace(temp_dir: Path) -> Path:
    """Return a 'workspace' subdirectory created inside the temp directory."""
    path = temp_dir / "workspace"
    path.mkdir(parents=True, exist_ok=True)
    return path
|
||||
|
||||
|
||||
@pytest.fixture
def git_repo(temp_workspace: Path) -> GitRepo:
    """Create an initialized git repository (branch `main`, one commit)."""
    # Initialize with main branch (Git 2.28+)
    repo = GitRepo.init(temp_workspace, initial_branch="main")

    # Configure identity locally so commits work without global git config.
    with repo.config_writer() as cw:
        cw.set_value("user", "name", "Test User")
        cw.set_value("user", "email", "test@example.com")

    # Create initial commit so HEAD exists and branches can be created.
    test_file = temp_workspace / "README.md"
    test_file.write_text("# Test Repository\n")
    repo.index.add(["README.md"])
    repo.index.commit("Initial commit")

    return repo
|
||||
|
||||
|
||||
@pytest.fixture
def git_repo_with_remote(git_repo: GitRepo, temp_dir: Path) -> tuple[GitRepo, GitRepo]:
    """Create a git repository with a 'remote' (bare repo).

    Returns:
        (working repo, bare remote repo); the working repo's `origin`
        points at the bare repo and `main` tracks `origin/main`.
    """
    # Create bare repo as remote
    remote_path = temp_dir / "remote.git"
    remote_repo = GitRepo.init(remote_path, bare=True)

    # Add remote to main repo
    git_repo.create_remote("origin", str(remote_path))

    # Push initial commit
    git_repo.remotes.origin.push("main:main")

    # Set up tracking
    git_repo.heads.main.set_tracking_branch(git_repo.remotes.origin.refs.main)

    return git_repo, remote_repo
|
||||
|
||||
|
||||
@pytest.fixture
def workspace_manager(temp_dir: Path, test_settings):
    """Create a WorkspaceManager rooted in a per-test directory."""
    from workspace import WorkspaceManager

    # Redirect the base path so parallel test runs never collide.
    test_settings.workspace_base_path = temp_dir / "workspaces"
    return WorkspaceManager(test_settings)
|
||||
|
||||
|
||||
@pytest.fixture
def git_wrapper(temp_workspace: Path, test_settings):
    """Create a GitWrapper for the temp workspace (no repository initialized)."""
    from git_wrapper import GitWrapper

    return GitWrapper(temp_workspace, test_settings)
|
||||
|
||||
|
||||
@pytest.fixture
def git_wrapper_with_repo(git_repo: GitRepo, test_settings):
    """Create a GitWrapper over the already-initialized `git_repo` fixture."""
    from git_wrapper import GitWrapper

    return GitWrapper(Path(git_repo.working_dir), test_settings)
|
||||
|
||||
|
||||
@pytest.fixture
def mock_gitea_provider():
    """A stand-in Gitea provider exposing the calls the server makes."""
    mock = AsyncMock()
    # configure_mock lets us set `name` too (a special Mock constructor arg).
    mock.configure_mock(
        name="gitea",
        is_connected=AsyncMock(return_value=True),
        get_authenticated_user=AsyncMock(return_value="test-user"),
        parse_repo_url=MagicMock(return_value=("owner", "repo")),
    )
    return mock
|
||||
|
||||
|
||||
@pytest.fixture
def mock_httpx_client():
    """Create a mock httpx client whose HTTP verbs all return a canned 200.

    The redundant local `from unittest.mock import AsyncMock` was removed:
    AsyncMock is already imported at module level.
    """
    mock_response = AsyncMock()
    mock_response.status_code = 200
    # httpx's Response.json() is synchronous, hence MagicMock not AsyncMock.
    mock_response.json = MagicMock(return_value={})
    mock_response.text = ""

    mock_client = AsyncMock()
    # Every verb the providers use resolves to the same canned response.
    for verb in ("request", "get", "post", "patch", "delete"):
        setattr(mock_client, verb, AsyncMock(return_value=mock_response))

    return mock_client
|
||||
|
||||
|
||||
@pytest.fixture
async def gitea_provider(test_settings, mock_httpx_client):
    """Create a GiteaProvider with a mocked HTTP client.

    The provider's private `_client` is swapped for a mock so no network
    traffic occurs; `close()` still runs on teardown.
    """
    from providers.gitea import GiteaProvider

    provider = GiteaProvider(
        base_url=test_settings.gitea_base_url,
        token=test_settings.gitea_token,
        settings=test_settings,
    )
    provider._client = mock_httpx_client

    yield provider

    await provider.close()
|
||||
|
||||
|
||||
@pytest.fixture
def sample_pr_data():
    """Sample PR data from Gitea API.

    NOTE(review): field names mirror Gitea's pull-request payload as the
    provider consumes it — verify against the Gitea API docs if it changes.
    """
    return {
        "number": 42,
        "title": "Test PR",
        "body": "This is a test pull request",
        "state": "open",
        "head": {"ref": "feature-branch"},  # source branch
        "base": {"ref": "main"},  # target branch
        "user": {"login": "test-user"},
        "created_at": "2024-01-15T10:00:00Z",
        "updated_at": "2024-01-15T12:00:00Z",
        "merged_at": None,
        "closed_at": None,
        "html_url": "https://gitea.test.com/owner/repo/pull/42",
        "labels": [{"name": "enhancement"}],
        "assignees": [{"login": "assignee1"}],
        "requested_reviewers": [{"login": "reviewer1"}],
        "mergeable": True,
        "draft": False,
    }
|
||||
|
||||
|
||||
@pytest.fixture
def sample_commit_data():
    """Sample commit data with distinct author and committer identities."""
    return {
        "sha": "abc123def456",
        "short_sha": "abc123d",
        "message": "Test commit message",
        "author": {
            "name": "Test Author",
            "email": "author@test.com",
            "date": "2024-01-15T10:00:00Z",
        },
        "committer": {
            "name": "Test Committer",
            "email": "committer@test.com",
            "date": "2024-01-15T10:00:00Z",
        },
    }
|
||||
|
||||
|
||||
@pytest.fixture
def mock_fastapi_app():
    """Create a test FastAPI app.

    NOTE(review): despite the name, this returns a TestClient wrapping a
    minimal app with one /health route — not the app object itself.
    """
    from fastapi import FastAPI
    from fastapi.testclient import TestClient

    app = FastAPI()

    @app.get("/health")
    def health():
        return {"status": "healthy"}

    return TestClient(app)
|
||||
|
||||
|
||||
# Async fixtures
|
||||
|
||||
|
||||
@pytest.fixture
async def async_workspace_manager(temp_dir: Path, test_settings) -> AsyncIterator:
    """Async fixture yielding a WorkspaceManager rooted in a per-test path."""
    from workspace import WorkspaceManager

    test_settings.workspace_base_path = temp_dir / "workspaces"
    manager = WorkspaceManager(test_settings)
    yield manager
|
||||
|
||||
|
||||
# Test data fixtures
|
||||
|
||||
|
||||
@pytest.fixture
def valid_project_id() -> str:
    """A project ID that passes the server's ID validation rules."""
    return "test-project-123"
|
||||
|
||||
|
||||
@pytest.fixture
def valid_agent_id() -> str:
    """An agent ID that passes the server's ID validation rules."""
    return "agent-456"
|
||||
|
||||
|
||||
@pytest.fixture
def invalid_ids() -> list[str]:
    """IDs that must be rejected by validation, one per failure mode."""
    return [
        "",  # Empty
        " ",  # Whitespace only
        "a" * 200,  # Too long
        "test@invalid",  # Invalid character
        "test!invalid",  # Invalid character
        "../path/traversal",  # Path traversal attempt
    ]
|
||||
|
||||
|
||||
@pytest.fixture
def sample_repo_url() -> str:
    """Sample HTTPS repository URL."""
    return "https://gitea.test.com/owner/repo.git"
|
||||
|
||||
|
||||
@pytest.fixture
def sample_ssh_repo_url() -> str:
    """Sample SSH (scp-style) repository URL."""
    return "git@gitea.test.com:owner/repo.git"
|
||||
440
mcp-servers/git-ops/tests/test_api_endpoints.py
Normal file
440
mcp-servers/git-ops/tests/test_api_endpoints.py
Normal file
@@ -0,0 +1,440 @@
|
||||
"""
|
||||
Tests for FastAPI endpoints.
|
||||
|
||||
Tests health check and MCP JSON-RPC endpoints.
|
||||
"""
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
class TestHealthEndpoint:
    """Tests for the /health endpoint.

    Every test patches the server's module-level singletons (settings,
    workspace manager, providers) and asserts on the JSON that GET /health
    returns. The shared boilerplate lives in `_get_health`.
    """

    @staticmethod
    async def _get_health(workspace=None, gitea=None, github=None):
        """Call GET /health with the given singletons patched into `server`.

        Returns:
            (status_code, parsed JSON body)
        """
        from httpx import ASGITransport, AsyncClient

        from server import app

        with (
            patch("server._settings", MagicMock()),
            patch("server._workspace_manager", workspace),
            patch("server._gitea_provider", gitea),
            patch("server._github_provider", github),
        ):
            async with AsyncClient(
                transport=ASGITransport(app=app), base_url="http://test"
            ) as client:
                response = await client.get("/health")
        return response.status_code, response.json()

    @pytest.mark.asyncio
    async def test_health_no_providers(self):
        """Test health check when no providers configured."""
        status, data = await self._get_health()

        assert status == 200
        assert data["status"] in ["healthy", "degraded"]
        assert data["service"] == "git-ops"

    @pytest.mark.asyncio
    async def test_health_with_gitea_connected(self):
        """Test health check with Gitea provider connected."""
        mock_gitea = AsyncMock()
        mock_gitea.is_connected = AsyncMock(return_value=True)
        mock_gitea.get_authenticated_user = AsyncMock(return_value="test-user")

        status, data = await self._get_health(gitea=mock_gitea)

        assert status == 200
        assert "gitea" in data["dependencies"]

    @pytest.mark.asyncio
    async def test_health_with_gitea_not_connected(self):
        """Test health check when Gitea is not connected."""
        mock_gitea = AsyncMock()
        mock_gitea.is_connected = AsyncMock(return_value=False)

        status, data = await self._get_health(gitea=mock_gitea)

        assert status == 200
        assert data["status"] == "degraded"

    @pytest.mark.asyncio
    async def test_health_with_gitea_error(self):
        """Test health check when Gitea throws error."""
        mock_gitea = AsyncMock()
        mock_gitea.is_connected = AsyncMock(side_effect=Exception("Connection failed"))

        status, data = await self._get_health(gitea=mock_gitea)

        assert status == 200
        assert data["status"] == "degraded"

    @pytest.mark.asyncio
    async def test_health_with_github_connected(self):
        """Test health check with GitHub provider connected."""
        mock_github = AsyncMock()
        mock_github.is_connected = AsyncMock(return_value=True)
        mock_github.get_authenticated_user = AsyncMock(return_value="github-user")

        status, data = await self._get_health(github=mock_github)

        assert status == 200
        assert "github" in data["dependencies"]

    @pytest.mark.asyncio
    async def test_health_with_github_not_connected(self):
        """Test health check when GitHub is not connected."""
        mock_github = AsyncMock()
        mock_github.is_connected = AsyncMock(return_value=False)

        status, data = await self._get_health(github=mock_github)

        assert status == 200
        assert data["status"] == "degraded"

    @pytest.mark.asyncio
    async def test_health_with_github_error(self):
        """Test health check when GitHub throws error."""
        mock_github = AsyncMock()
        mock_github.is_connected = AsyncMock(side_effect=Exception("Auth failed"))

        status, data = await self._get_health(github=mock_github)

        assert status == 200
        assert data["status"] == "degraded"

    @pytest.mark.asyncio
    async def test_health_with_workspace_manager(self):
        """Test health check with workspace manager."""
        from pathlib import Path

        mock_manager = AsyncMock()
        mock_manager.base_path = Path("/tmp/workspaces")
        mock_manager.list_workspaces = AsyncMock(return_value=[])

        status, data = await self._get_health(workspace=mock_manager)

        assert status == 200
        assert "workspace" in data["dependencies"]

    @pytest.mark.asyncio
    async def test_health_workspace_error(self):
        """Test health check when workspace manager throws error."""
        from pathlib import Path

        mock_manager = AsyncMock()
        mock_manager.base_path = Path("/tmp/workspaces")
        mock_manager.list_workspaces = AsyncMock(side_effect=Exception("Disk full"))

        status, data = await self._get_health(workspace=mock_manager)

        assert status == 200
        assert data["status"] == "degraded"
|
||||
|
||||
|
||||
class TestMCPToolsEndpoint:
    """Tests for the MCP tools listing endpoint."""

    @pytest.mark.asyncio
    async def test_list_mcp_tools(self):
        """GET /mcp/tools returns a payload with a 'tools' key."""
        from httpx import ASGITransport, AsyncClient

        from server import app

        with patch("server._settings", MagicMock()):
            transport = ASGITransport(app=app)
            async with AsyncClient(transport=transport, base_url="http://test") as client:
                response = await client.get("/mcp/tools")

        assert response.status_code == 200
        payload = response.json()
        assert "tools" in payload
|
||||
|
||||
|
||||
class TestMCPRPCEndpoint:
    """Tests for the MCP JSON-RPC endpoint.

    Each test POSTs a malformed request to /mcp and checks the HTTP status
    plus the JSON-RPC error code; the client setup is shared in `_post_mcp`.
    """

    @staticmethod
    async def _post_mcp(**request_kwargs):
        """POST to /mcp with the given httpx kwargs; return the response."""
        from httpx import ASGITransport, AsyncClient

        from server import app

        async with AsyncClient(
            transport=ASGITransport(app=app), base_url="http://test"
        ) as client:
            return await client.post("/mcp", **request_kwargs)

    @pytest.mark.asyncio
    async def test_mcp_rpc_invalid_json(self):
        """Test RPC with invalid JSON."""
        response = await self._post_mcp(
            content="not valid json",
            headers={"Content-Type": "application/json"},
        )

        assert response.status_code == 400
        data = response.json()
        # -32700 is the JSON-RPC parse-error code.
        assert data["error"]["code"] == -32700

    @pytest.mark.asyncio
    async def test_mcp_rpc_invalid_jsonrpc(self):
        """Test RPC with invalid jsonrpc version."""
        response = await self._post_mcp(
            json={"jsonrpc": "1.0", "method": "test", "id": 1}
        )

        assert response.status_code == 400
        data = response.json()
        # -32600 is the JSON-RPC invalid-request code.
        assert data["error"]["code"] == -32600

    @pytest.mark.asyncio
    async def test_mcp_rpc_missing_method(self):
        """Test RPC with missing method."""
        response = await self._post_mcp(json={"jsonrpc": "2.0", "id": 1})

        assert response.status_code == 400
        data = response.json()
        assert data["error"]["code"] == -32600

    @pytest.mark.asyncio
    async def test_mcp_rpc_method_not_found(self):
        """Test RPC with unknown method."""
        response = await self._post_mcp(
            json={
                "jsonrpc": "2.0",
                "method": "unknown_method",
                "params": {},
                "id": 1,
            }
        )

        assert response.status_code == 404
        data = response.json()
        # -32601 is the JSON-RPC method-not-found code.
        assert data["error"]["code"] == -32601
|
||||
|
||||
|
||||
class TestTypeSchemaConversion:
    """Tests for converting Python types into JSON-schema fragments."""

    def test_python_type_to_json_schema_str(self):
        """`str` maps to a JSON string schema."""
        from server import _python_type_to_json_schema

        schema = _python_type_to_json_schema(str)
        assert schema == {"type": "string"}

    def test_python_type_to_json_schema_int(self):
        """`int` maps to a JSON integer schema."""
        from server import _python_type_to_json_schema

        schema = _python_type_to_json_schema(int)
        assert schema == {"type": "integer"}

    def test_python_type_to_json_schema_float(self):
        """`float` maps to a JSON number schema."""
        from server import _python_type_to_json_schema

        schema = _python_type_to_json_schema(float)
        assert schema == {"type": "number"}

    def test_python_type_to_json_schema_bool(self):
        """`bool` maps to a JSON boolean schema."""
        from server import _python_type_to_json_schema

        schema = _python_type_to_json_schema(bool)
        assert schema == {"type": "boolean"}

    def test_python_type_to_json_schema_none(self):
        """`NoneType` maps to a JSON null schema."""
        from server import _python_type_to_json_schema

        schema = _python_type_to_json_schema(type(None))
        assert schema == {"type": "null"}

    def test_python_type_to_json_schema_list(self):
        """`list[str]` maps to an array schema."""
        from server import _python_type_to_json_schema

        schema = _python_type_to_json_schema(list[str])
        assert schema["type"] == "array"

    def test_python_type_to_json_schema_dict(self):
        """`dict[str, int]` maps to an object schema."""
        from server import _python_type_to_json_schema

        schema = _python_type_to_json_schema(dict[str, int])
        assert schema == {"type": "object"}

    def test_python_type_to_json_schema_optional(self):
        """`str | None` still yields a schema that carries a `type` key."""
        from server import _python_type_to_json_schema

        # The function returns object type for complex union types
        schema = _python_type_to_json_schema(str | None)
        assert "type" in schema
|
||||
|
||||
|
||||
class TestToolSchema:
    """Tests for tool schema extraction and registration."""

    def test_get_tool_schema_simple(self):
        """Test getting schema from simple function."""
        from server import _get_tool_schema

        def simple_func(name: str, count: int) -> str:
            return f"{name}: {count}"

        result = _get_tool_schema(simple_func)
        assert "properties" in result
        assert "name" in result["properties"]
        assert "count" in result["properties"]

    def test_register_and_get_tool(self):
        """Test registering a tool.

        The registry is module-global, so cleanup runs in a `finally` —
        the original `del` only ran when both asserts passed, leaking
        the entry into later tests on failure.
        """
        from server import _register_tool, _tool_registry

        async def test_tool(x: str) -> str:
            """A test tool."""
            return x

        _register_tool("test_tool", test_tool, "Test description")
        try:
            assert "test_tool" in _tool_registry
            assert _tool_registry["test_tool"]["description"] == "Test description"
        finally:
            # Clean up even if an assertion failed.
            del _tool_registry["test_tool"]
|
||||
943
mcp-servers/git-ops/tests/test_git_wrapper.py
Normal file
943
mcp-servers/git-ops/tests/test_git_wrapper.py
Normal file
@@ -0,0 +1,943 @@
|
||||
"""
|
||||
Tests for the git_wrapper module.
|
||||
"""
|
||||
|
||||
from pathlib import Path
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from git import GitCommandError
|
||||
|
||||
from exceptions import (
|
||||
BranchExistsError,
|
||||
BranchNotFoundError,
|
||||
CheckoutError,
|
||||
CloneError,
|
||||
CommitError,
|
||||
GitError,
|
||||
PullError,
|
||||
PushError,
|
||||
)
|
||||
from git_wrapper import GitWrapper, run_in_executor
|
||||
from models import FileChangeType
|
||||
|
||||
|
||||
class TestGitWrapperInit:
    """Tests for GitWrapper initialization and lazy repo access."""

    def test_init_with_valid_path(self, temp_workspace, test_settings):
        """Test initialization with a valid path."""
        wrapper = GitWrapper(temp_workspace, test_settings)
        assert wrapper.workspace_path == temp_workspace
        assert wrapper.settings == test_settings

    def test_repo_property_raises_on_non_git(self, temp_workspace, test_settings):
        """Test that accessing repo on non-git dir raises error."""
        # Construction itself succeeds; the error surfaces on `.repo` access.
        wrapper = GitWrapper(temp_workspace, test_settings)
        with pytest.raises(GitError, match="Not a git repository"):
            _ = wrapper.repo

    def test_repo_property_works_on_git_dir(self, git_repo, test_settings):
        """Test that repo property works for git directory."""
        wrapper = GitWrapper(Path(git_repo.working_dir), test_settings)
        assert wrapper.repo is not None
        assert wrapper.repo.head is not None
|
||||
|
||||
|
||||
class TestGitWrapperStatus:
    """Tests for git status operations (clean, untracked, modified, staged)."""

    @pytest.mark.asyncio
    async def test_status_clean_repo(self, git_wrapper_with_repo):
        """Test status on a clean repository."""
        result = await git_wrapper_with_repo.status()

        assert result.branch == "main"
        assert result.is_clean is True
        assert len(result.staged) == 0
        assert len(result.unstaged) == 0
        assert len(result.untracked) == 0

    @pytest.mark.asyncio
    async def test_status_with_untracked(self, git_wrapper_with_repo, git_repo):
        """Test status with untracked files."""
        # Create untracked file
        untracked_file = Path(git_repo.working_dir) / "untracked.txt"
        untracked_file.write_text("untracked content")

        result = await git_wrapper_with_repo.status()

        assert result.is_clean is False
        assert "untracked.txt" in result.untracked

    @pytest.mark.asyncio
    async def test_status_with_modified(self, git_wrapper_with_repo, git_repo):
        """Test status with modified files."""
        # Modify existing file (README.md was committed by the git_repo fixture)
        readme = Path(git_repo.working_dir) / "README.md"
        readme.write_text("# Modified content\n")

        result = await git_wrapper_with_repo.status()

        assert result.is_clean is False
        assert len(result.unstaged) > 0

    @pytest.mark.asyncio
    async def test_status_with_staged(self, git_wrapper_with_repo, git_repo):
        """Test status with staged changes."""
        # Create and stage a file
        new_file = Path(git_repo.working_dir) / "staged.txt"
        new_file.write_text("staged content")
        git_repo.index.add(["staged.txt"])

        result = await git_wrapper_with_repo.status()

        assert result.is_clean is False
        assert len(result.staged) > 0

    @pytest.mark.asyncio
    async def test_status_exclude_untracked(self, git_wrapper_with_repo, git_repo):
        """Test status without untracked files."""
        untracked_file = Path(git_repo.working_dir) / "untracked.txt"
        untracked_file.write_text("untracked")

        result = await git_wrapper_with_repo.status(include_untracked=False)

        assert len(result.untracked) == 0
|
||||
|
||||
|
||||
class TestGitWrapperBranch:
    """Tests for branch create/delete/list operations."""

    @pytest.mark.asyncio
    async def test_create_branch(self, git_wrapper_with_repo):
        """Test creating a new branch."""
        result = await git_wrapper_with_repo.create_branch("feature-test")

        assert result.success is True
        assert result.branch == "feature-test"
        # create_branch checks out the new branch by default.
        assert result.is_current is True

    @pytest.mark.asyncio
    async def test_create_branch_without_checkout(self, git_wrapper_with_repo):
        """Test creating branch without checkout."""
        result = await git_wrapper_with_repo.create_branch(
            "feature-no-checkout", checkout=False
        )

        assert result.success is True
        assert result.branch == "feature-no-checkout"
        assert result.is_current is False

    @pytest.mark.asyncio
    async def test_create_branch_exists_error(self, git_wrapper_with_repo):
        """Test error when branch already exists."""
        await git_wrapper_with_repo.create_branch("existing-branch", checkout=False)

        with pytest.raises(BranchExistsError):
            await git_wrapper_with_repo.create_branch("existing-branch")

    @pytest.mark.asyncio
    async def test_delete_branch(self, git_wrapper_with_repo):
        """Test deleting a branch."""
        # Create branch first
        await git_wrapper_with_repo.create_branch("to-delete", checkout=False)

        # Delete it
        result = await git_wrapper_with_repo.delete_branch("to-delete")

        assert result.success is True
        assert result.branch == "to-delete"

    @pytest.mark.asyncio
    async def test_delete_branch_not_found(self, git_wrapper_with_repo):
        """Test error when deleting non-existent branch."""
        with pytest.raises(BranchNotFoundError):
            await git_wrapper_with_repo.delete_branch("nonexistent")

    @pytest.mark.asyncio
    async def test_delete_current_branch_error(self, git_wrapper_with_repo):
        """Test error when deleting current branch."""
        with pytest.raises(GitError, match="Cannot delete current branch"):
            await git_wrapper_with_repo.delete_branch("main")

    @pytest.mark.asyncio
    async def test_list_branches(self, git_wrapper_with_repo):
        """Test listing branches."""
        # Create some branches
        await git_wrapper_with_repo.create_branch("branch-a", checkout=False)
        await git_wrapper_with_repo.create_branch("branch-b", checkout=False)

        result = await git_wrapper_with_repo.list_branches()

        assert result.current_branch == "main"
        branch_names = [b["name"] for b in result.local_branches]
        assert "main" in branch_names
        assert "branch-a" in branch_names
        assert "branch-b" in branch_names
|
||||
|
||||
|
||||
class TestGitWrapperCheckout:
    """Tests for checkout operations."""

    @pytest.mark.asyncio
    async def test_checkout_existing_branch(self, git_wrapper_with_repo):
        """Test checkout of existing branch."""
        # Create branch first
        await git_wrapper_with_repo.create_branch("test-branch", checkout=False)

        result = await git_wrapper_with_repo.checkout("test-branch")

        assert result.success is True
        assert result.ref == "test-branch"

    @pytest.mark.asyncio
    async def test_checkout_create_new(self, git_wrapper_with_repo):
        """Test checkout with branch creation."""
        result = await git_wrapper_with_repo.checkout("new-branch", create_branch=True)

        assert result.success is True
        assert result.ref == "new-branch"

    @pytest.mark.asyncio
    async def test_checkout_nonexistent_error(self, git_wrapper_with_repo):
        """Test error when checking out non-existent ref."""
        with pytest.raises(CheckoutError):
            await git_wrapper_with_repo.checkout("nonexistent-branch")
|
||||
|
||||
|
||||
class TestGitWrapperCommit:
    """Tests for commit operations (staged, auto-stage, empty, custom author)."""

    @pytest.mark.asyncio
    async def test_commit_staged_changes(self, git_wrapper_with_repo, git_repo):
        """Test committing staged changes."""
        # Create and stage a file
        new_file = Path(git_repo.working_dir) / "newfile.txt"
        new_file.write_text("new content")
        git_repo.index.add(["newfile.txt"])

        result = await git_wrapper_with_repo.commit("Add new file")

        assert result.success is True
        assert result.message == "Add new file"
        assert result.files_changed == 1

    @pytest.mark.asyncio
    async def test_commit_all_changes(self, git_wrapper_with_repo, git_repo):
        """Test committing all changes (auto-stage)."""
        # Create a file without staging — commit() is expected to stage it.
        new_file = Path(git_repo.working_dir) / "unstaged.txt"
        new_file.write_text("content")

        result = await git_wrapper_with_repo.commit("Commit unstaged")

        assert result.success is True

    @pytest.mark.asyncio
    async def test_commit_nothing_to_commit(self, git_wrapper_with_repo):
        """Test error when nothing to commit."""
        with pytest.raises(CommitError, match="Nothing to commit"):
            await git_wrapper_with_repo.commit("Empty commit")

    @pytest.mark.asyncio
    async def test_commit_with_author(self, git_wrapper_with_repo, git_repo):
        """Test commit with custom author."""
        new_file = Path(git_repo.working_dir) / "authored.txt"
        new_file.write_text("authored content")

        result = await git_wrapper_with_repo.commit(
            "Custom author commit",
            author_name="Custom Author",
            author_email="custom@test.com",
        )

        assert result.success is True
|
||||
|
||||
|
||||
class TestGitWrapperDiff:
|
||||
"""Tests for diff operations."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_diff_no_changes(self, git_wrapper_with_repo):
|
||||
"""Test diff with no changes."""
|
||||
result = await git_wrapper_with_repo.diff()
|
||||
|
||||
assert result.files_changed == 0
|
||||
assert result.total_additions == 0
|
||||
assert result.total_deletions == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_diff_with_changes(self, git_wrapper_with_repo, git_repo):
|
||||
"""Test diff with modified files."""
|
||||
# Modify a file
|
||||
readme = Path(git_repo.working_dir) / "README.md"
|
||||
readme.write_text("# Modified\nNew line\n")
|
||||
|
||||
result = await git_wrapper_with_repo.diff()
|
||||
|
||||
assert result.files_changed > 0
|
||||
|
||||
|
||||
class TestGitWrapperLog:
|
||||
"""Tests for log operations."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_log_basic(self, git_wrapper_with_repo):
|
||||
"""Test basic log."""
|
||||
result = await git_wrapper_with_repo.log()
|
||||
|
||||
assert result.total_commits > 0
|
||||
assert len(result.commits) > 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_log_with_limit(self, git_wrapper_with_repo, git_repo):
|
||||
"""Test log with limit."""
|
||||
# Create more commits
|
||||
for i in range(5):
|
||||
file_path = Path(git_repo.working_dir) / f"file{i}.txt"
|
||||
file_path.write_text(f"content {i}")
|
||||
git_repo.index.add([f"file{i}.txt"])
|
||||
git_repo.index.commit(f"Commit {i}")
|
||||
|
||||
result = await git_wrapper_with_repo.log(limit=3)
|
||||
|
||||
assert len(result.commits) == 3
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_log_commit_info(self, git_wrapper_with_repo):
|
||||
"""Test that log returns proper commit info."""
|
||||
result = await git_wrapper_with_repo.log(limit=1)
|
||||
|
||||
commit = result.commits[0]
|
||||
assert "sha" in commit
|
||||
assert "message" in commit
|
||||
assert "author_name" in commit
|
||||
assert "author_email" in commit
|
||||
|
||||
|
||||
class TestGitWrapperUtilities:
|
||||
"""Tests for utility methods."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_is_valid_ref_true(self, git_wrapper_with_repo):
|
||||
"""Test valid ref detection."""
|
||||
is_valid = await git_wrapper_with_repo.is_valid_ref("main")
|
||||
assert is_valid is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_is_valid_ref_false(self, git_wrapper_with_repo):
|
||||
"""Test invalid ref detection."""
|
||||
is_valid = await git_wrapper_with_repo.is_valid_ref("nonexistent")
|
||||
assert is_valid is False
|
||||
|
||||
def test_diff_to_change_type(self, git_wrapper_with_repo):
|
||||
"""Test change type conversion."""
|
||||
wrapper = git_wrapper_with_repo
|
||||
|
||||
assert wrapper._diff_to_change_type("A") == FileChangeType.ADDED
|
||||
assert wrapper._diff_to_change_type("M") == FileChangeType.MODIFIED
|
||||
assert wrapper._diff_to_change_type("D") == FileChangeType.DELETED
|
||||
assert wrapper._diff_to_change_type("R") == FileChangeType.RENAMED
|
||||
|
||||
|
||||
class TestGitWrapperStage:
|
||||
"""Tests for staging operations."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stage_specific_files(self, git_wrapper_with_repo, git_repo):
|
||||
"""Test staging specific files."""
|
||||
# Create files
|
||||
file1 = Path(git_repo.working_dir) / "file1.txt"
|
||||
file2 = Path(git_repo.working_dir) / "file2.txt"
|
||||
file1.write_text("content 1")
|
||||
file2.write_text("content 2")
|
||||
|
||||
count = await git_wrapper_with_repo.stage(["file1.txt"])
|
||||
|
||||
assert count == 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stage_all(self, git_wrapper_with_repo, git_repo):
|
||||
"""Test staging all files."""
|
||||
file1 = Path(git_repo.working_dir) / "all1.txt"
|
||||
file2 = Path(git_repo.working_dir) / "all2.txt"
|
||||
file1.write_text("content 1")
|
||||
file2.write_text("content 2")
|
||||
|
||||
count = await git_wrapper_with_repo.stage()
|
||||
|
||||
assert count >= 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_unstage_files(self, git_wrapper_with_repo, git_repo):
|
||||
"""Test unstaging files."""
|
||||
# Create and stage file
|
||||
file1 = Path(git_repo.working_dir) / "unstage.txt"
|
||||
file1.write_text("to unstage")
|
||||
git_repo.index.add(["unstage.txt"])
|
||||
|
||||
count = await git_wrapper_with_repo.unstage()
|
||||
|
||||
assert count >= 1
|
||||
|
||||
|
||||
class TestGitWrapperReset:
|
||||
"""Tests for reset operations."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_reset_soft(self, git_wrapper_with_repo, git_repo):
|
||||
"""Test soft reset."""
|
||||
# Create a commit to reset
|
||||
file1 = Path(git_repo.working_dir) / "reset_soft.txt"
|
||||
file1.write_text("content")
|
||||
git_repo.index.add(["reset_soft.txt"])
|
||||
git_repo.index.commit("Commit to reset")
|
||||
|
||||
result = await git_wrapper_with_repo.reset("HEAD~1", mode="soft")
|
||||
|
||||
assert result is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_reset_mixed(self, git_wrapper_with_repo, git_repo):
|
||||
"""Test mixed reset (default)."""
|
||||
file1 = Path(git_repo.working_dir) / "reset_mixed.txt"
|
||||
file1.write_text("content")
|
||||
git_repo.index.add(["reset_mixed.txt"])
|
||||
git_repo.index.commit("Commit to reset")
|
||||
|
||||
result = await git_wrapper_with_repo.reset("HEAD~1", mode="mixed")
|
||||
|
||||
assert result is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_reset_invalid_mode(self, git_wrapper_with_repo):
|
||||
"""Test error on invalid reset mode."""
|
||||
with pytest.raises(GitError, match="Invalid reset mode"):
|
||||
await git_wrapper_with_repo.reset("HEAD", mode="invalid")
|
||||
|
||||
|
||||
class TestGitWrapperStash:
|
||||
"""Tests for stash operations."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stash_changes(self, git_wrapper_with_repo, git_repo):
|
||||
"""Test stashing changes."""
|
||||
# Make changes
|
||||
readme = Path(git_repo.working_dir) / "README.md"
|
||||
readme.write_text("Modified for stash")
|
||||
|
||||
result = await git_wrapper_with_repo.stash("Test stash")
|
||||
|
||||
# Result should be stash ref or None if nothing to stash
|
||||
# (depends on whether changes were already staged)
|
||||
assert result is None or result.startswith("stash@")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stash_nothing(self, git_wrapper_with_repo):
|
||||
"""Test stash with no changes."""
|
||||
result = await git_wrapper_with_repo.stash()
|
||||
|
||||
assert result is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stash_pop(self, git_wrapper_with_repo, git_repo):
|
||||
"""Test popping a stash."""
|
||||
# Make changes and stash them
|
||||
readme = Path(git_repo.working_dir) / "README.md"
|
||||
original_content = readme.read_text()
|
||||
readme.write_text("Modified for stash pop test")
|
||||
git_repo.index.add(["README.md"])
|
||||
|
||||
stash_ref = await git_wrapper_with_repo.stash("Test stash for pop")
|
||||
|
||||
if stash_ref:
|
||||
# Pop the stash
|
||||
result = await git_wrapper_with_repo.stash_pop()
|
||||
assert result is True
|
||||
|
||||
|
||||
class TestGitWrapperRepoProperty:
|
||||
"""Tests for repo property edge cases."""
|
||||
|
||||
def test_repo_property_path_not_exists(self, test_settings):
|
||||
"""Test that accessing repo on non-existent path raises error."""
|
||||
wrapper = GitWrapper(
|
||||
Path("/nonexistent/path/that/does/not/exist"), test_settings
|
||||
)
|
||||
with pytest.raises(GitError, match="Path does not exist"):
|
||||
_ = wrapper.repo
|
||||
|
||||
def test_refresh_repo(self, git_wrapper_with_repo):
|
||||
"""Test _refresh_repo clears cached repo."""
|
||||
# Access repo to cache it
|
||||
_ = git_wrapper_with_repo.repo
|
||||
assert git_wrapper_with_repo._repo is not None
|
||||
|
||||
# Refresh should clear it
|
||||
git_wrapper_with_repo._refresh_repo()
|
||||
assert git_wrapper_with_repo._repo is None
|
||||
|
||||
|
||||
class TestGitWrapperBranchAdvanced:
|
||||
"""Advanced tests for branch operations."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_branch_from_ref(self, git_wrapper_with_repo, git_repo):
|
||||
"""Test creating branch from specific ref."""
|
||||
# Get current HEAD SHA
|
||||
head_sha = git_repo.head.commit.hexsha
|
||||
|
||||
result = await git_wrapper_with_repo.create_branch(
|
||||
"feature-from-ref",
|
||||
from_ref=head_sha,
|
||||
checkout=False,
|
||||
)
|
||||
|
||||
assert result.success is True
|
||||
assert result.branch == "feature-from-ref"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_branch_force(self, git_wrapper_with_repo, git_repo):
|
||||
"""Test force deleting a branch."""
|
||||
# Create branch and add unmerged commit
|
||||
await git_wrapper_with_repo.create_branch("unmerged-branch", checkout=True)
|
||||
new_file = Path(git_repo.working_dir) / "unmerged.txt"
|
||||
new_file.write_text("unmerged content")
|
||||
git_repo.index.add(["unmerged.txt"])
|
||||
git_repo.index.commit("Unmerged commit")
|
||||
|
||||
# Switch back to main
|
||||
await git_wrapper_with_repo.checkout("main")
|
||||
|
||||
# Force delete
|
||||
result = await git_wrapper_with_repo.delete_branch(
|
||||
"unmerged-branch", force=True
|
||||
)
|
||||
|
||||
assert result.success is True
|
||||
|
||||
|
||||
class TestGitWrapperListBranchesRemote:
|
||||
"""Tests for listing remote branches."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_branches_with_remote(self, git_wrapper_with_repo):
|
||||
"""Test listing branches including remote."""
|
||||
# Even without remotes, this should work
|
||||
result = await git_wrapper_with_repo.list_branches(include_remote=True)
|
||||
|
||||
assert result.current_branch == "main"
|
||||
# Remote branches list should be empty for local repo
|
||||
assert len(result.remote_branches) == 0
|
||||
|
||||
|
||||
class TestGitWrapperCheckoutAdvanced:
|
||||
"""Advanced tests for checkout operations."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_checkout_create_existing_error(self, git_wrapper_with_repo):
|
||||
"""Test error when creating branch that already exists."""
|
||||
with pytest.raises(BranchExistsError):
|
||||
await git_wrapper_with_repo.checkout("main", create_branch=True)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_checkout_force(self, git_wrapper_with_repo, git_repo):
|
||||
"""Test force checkout discards local changes."""
|
||||
# Create branch
|
||||
await git_wrapper_with_repo.create_branch("force-test", checkout=False)
|
||||
|
||||
# Make local changes
|
||||
readme = Path(git_repo.working_dir) / "README.md"
|
||||
readme.write_text("local changes")
|
||||
|
||||
# Force checkout should work
|
||||
result = await git_wrapper_with_repo.checkout("force-test", force=True)
|
||||
|
||||
assert result.success is True
|
||||
|
||||
|
||||
class TestGitWrapperCommitAdvanced:
|
||||
"""Advanced tests for commit operations."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_commit_specific_files(self, git_wrapper_with_repo, git_repo):
|
||||
"""Test committing specific files only."""
|
||||
# Create multiple files
|
||||
file1 = Path(git_repo.working_dir) / "commit_specific1.txt"
|
||||
file2 = Path(git_repo.working_dir) / "commit_specific2.txt"
|
||||
file1.write_text("content 1")
|
||||
file2.write_text("content 2")
|
||||
|
||||
result = await git_wrapper_with_repo.commit(
|
||||
"Commit specific file",
|
||||
files=["commit_specific1.txt"],
|
||||
)
|
||||
|
||||
assert result.success is True
|
||||
assert result.files_changed == 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_commit_with_partial_author(self, git_wrapper_with_repo, git_repo):
|
||||
"""Test commit with only author name."""
|
||||
new_file = Path(git_repo.working_dir) / "partial_author.txt"
|
||||
new_file.write_text("content")
|
||||
|
||||
result = await git_wrapper_with_repo.commit(
|
||||
"Partial author commit",
|
||||
author_name="Test Author",
|
||||
)
|
||||
|
||||
assert result.success is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_commit_allow_empty(self, git_wrapper_with_repo):
|
||||
"""Test allowing empty commits."""
|
||||
result = await git_wrapper_with_repo.commit(
|
||||
"Empty commit allowed",
|
||||
allow_empty=True,
|
||||
)
|
||||
|
||||
assert result.success is True
|
||||
|
||||
|
||||
class TestGitWrapperUnstageAdvanced:
|
||||
"""Advanced tests for unstaging operations."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_unstage_specific_files(self, git_wrapper_with_repo, git_repo):
|
||||
"""Test unstaging specific files."""
|
||||
# Create and stage files
|
||||
file1 = Path(git_repo.working_dir) / "unstage1.txt"
|
||||
file2 = Path(git_repo.working_dir) / "unstage2.txt"
|
||||
file1.write_text("content 1")
|
||||
file2.write_text("content 2")
|
||||
git_repo.index.add(["unstage1.txt", "unstage2.txt"])
|
||||
|
||||
count = await git_wrapper_with_repo.unstage(["unstage1.txt"])
|
||||
|
||||
assert count == 1
|
||||
|
||||
|
||||
class TestGitWrapperResetAdvanced:
|
||||
"""Advanced tests for reset operations."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_reset_hard(self, git_wrapper_with_repo, git_repo):
|
||||
"""Test hard reset."""
|
||||
# Create a commit
|
||||
file1 = Path(git_repo.working_dir) / "reset_hard.txt"
|
||||
file1.write_text("content")
|
||||
git_repo.index.add(["reset_hard.txt"])
|
||||
git_repo.index.commit("Commit for hard reset")
|
||||
|
||||
result = await git_wrapper_with_repo.reset("HEAD~1", mode="hard")
|
||||
|
||||
assert result is True
|
||||
# File should be gone after hard reset
|
||||
assert not file1.exists()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_reset_specific_files(self, git_wrapper_with_repo, git_repo):
|
||||
"""Test resetting specific files."""
|
||||
# Create and stage a file
|
||||
file1 = Path(git_repo.working_dir) / "reset_file.txt"
|
||||
file1.write_text("content")
|
||||
git_repo.index.add(["reset_file.txt"])
|
||||
|
||||
result = await git_wrapper_with_repo.reset("HEAD", files=["reset_file.txt"])
|
||||
|
||||
assert result is True
|
||||
|
||||
|
||||
class TestGitWrapperDiffAdvanced:
|
||||
"""Advanced tests for diff operations."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_diff_between_refs(self, git_wrapper_with_repo, git_repo):
|
||||
"""Test diff between two refs."""
|
||||
# Create initial commit
|
||||
file1 = Path(git_repo.working_dir) / "diff_ref.txt"
|
||||
file1.write_text("initial")
|
||||
git_repo.index.add(["diff_ref.txt"])
|
||||
commit1 = git_repo.index.commit("First commit for diff")
|
||||
|
||||
# Create second commit
|
||||
file1.write_text("modified")
|
||||
git_repo.index.add(["diff_ref.txt"])
|
||||
commit2 = git_repo.index.commit("Second commit for diff")
|
||||
|
||||
result = await git_wrapper_with_repo.diff(
|
||||
base=commit1.hexsha,
|
||||
head=commit2.hexsha,
|
||||
)
|
||||
|
||||
assert result.files_changed > 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_diff_specific_files(self, git_wrapper_with_repo, git_repo):
|
||||
"""Test diff for specific files only."""
|
||||
# Create files
|
||||
file1 = Path(git_repo.working_dir) / "diff_specific1.txt"
|
||||
file2 = Path(git_repo.working_dir) / "diff_specific2.txt"
|
||||
file1.write_text("content 1")
|
||||
file2.write_text("content 2")
|
||||
|
||||
result = await git_wrapper_with_repo.diff(files=["diff_specific1.txt"])
|
||||
|
||||
# Should only show changes for specified file
|
||||
for f in result.files:
|
||||
assert "diff_specific2.txt" not in f.get("path", "")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_diff_base_only(self, git_wrapper_with_repo, git_repo):
|
||||
"""Test diff with base ref only (vs HEAD)."""
|
||||
# Create commit
|
||||
file1 = Path(git_repo.working_dir) / "diff_base.txt"
|
||||
file1.write_text("content")
|
||||
git_repo.index.add(["diff_base.txt"])
|
||||
commit = git_repo.index.commit("Commit for diff base test")
|
||||
|
||||
# Get parent commit
|
||||
parent = commit.parents[0] if commit.parents else commit
|
||||
|
||||
result = await git_wrapper_with_repo.diff(base=parent.hexsha)
|
||||
|
||||
assert isinstance(result.files_changed, int)
|
||||
|
||||
|
||||
class TestGitWrapperLogAdvanced:
|
||||
"""Advanced tests for log operations."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_log_with_ref(self, git_wrapper_with_repo, git_repo):
|
||||
"""Test log starting from specific ref."""
|
||||
# Create branch with commits
|
||||
await git_wrapper_with_repo.create_branch("log-test", checkout=True)
|
||||
file1 = Path(git_repo.working_dir) / "log_ref.txt"
|
||||
file1.write_text("content")
|
||||
git_repo.index.add(["log_ref.txt"])
|
||||
git_repo.index.commit("Commit on log-test branch")
|
||||
|
||||
result = await git_wrapper_with_repo.log(ref="log-test", limit=5)
|
||||
|
||||
assert result.total_commits > 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_log_with_path(self, git_wrapper_with_repo, git_repo):
|
||||
"""Test log filtered by path."""
|
||||
# Create file and commit
|
||||
file1 = Path(git_repo.working_dir) / "log_path.txt"
|
||||
file1.write_text("content")
|
||||
git_repo.index.add(["log_path.txt"])
|
||||
git_repo.index.commit("Commit for path log")
|
||||
|
||||
result = await git_wrapper_with_repo.log(path="log_path.txt")
|
||||
|
||||
assert result.total_commits >= 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_log_with_skip(self, git_wrapper_with_repo, git_repo):
|
||||
"""Test log with skip parameter."""
|
||||
# Create multiple commits
|
||||
for i in range(3):
|
||||
file_path = Path(git_repo.working_dir) / f"skip_test{i}.txt"
|
||||
file_path.write_text(f"content {i}")
|
||||
git_repo.index.add([f"skip_test{i}.txt"])
|
||||
git_repo.index.commit(f"Skip test commit {i}")
|
||||
|
||||
result = await git_wrapper_with_repo.log(skip=1, limit=2)
|
||||
|
||||
# Should have skipped first commit
|
||||
assert len(result.commits) <= 2
|
||||
|
||||
|
||||
class TestGitWrapperRemoteUrl:
|
||||
"""Tests for remote URL operations."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_remote_url_nonexistent(self, git_wrapper_with_repo):
|
||||
"""Test getting URL for non-existent remote."""
|
||||
url = await git_wrapper_with_repo.get_remote_url("nonexistent")
|
||||
|
||||
assert url is None
|
||||
|
||||
|
||||
class TestGitWrapperConfig:
|
||||
"""Tests for git config operations."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_set_and_get_config(self, git_wrapper_with_repo):
|
||||
"""Test setting and getting config value."""
|
||||
await git_wrapper_with_repo.set_config("test.key", "test_value")
|
||||
|
||||
value = await git_wrapper_with_repo.get_config("test.key")
|
||||
|
||||
assert value == "test_value"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_config_nonexistent(self, git_wrapper_with_repo):
|
||||
"""Test getting non-existent config value."""
|
||||
value = await git_wrapper_with_repo.get_config("nonexistent.key")
|
||||
|
||||
assert value is None
|
||||
|
||||
|
||||
class TestGitWrapperClone:
|
||||
"""Tests for clone operations."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_clone_success(self, temp_workspace, test_settings):
|
||||
"""Test successful clone."""
|
||||
wrapper = GitWrapper(temp_workspace, test_settings)
|
||||
|
||||
# Mock the clone operation
|
||||
with patch("git_wrapper.GitRepo") as mock_repo_class:
|
||||
mock_repo = MagicMock()
|
||||
mock_repo.active_branch.name = "main"
|
||||
mock_repo.head.commit.hexsha = "abc123"
|
||||
mock_repo_class.clone_from.return_value = mock_repo
|
||||
|
||||
result = await wrapper.clone("https://github.com/test/repo.git")
|
||||
|
||||
assert result.success is True
|
||||
assert result.branch == "main"
|
||||
assert result.commit_sha == "abc123"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_clone_with_auth_token(self, temp_workspace, test_settings):
|
||||
"""Test clone with auth token."""
|
||||
wrapper = GitWrapper(temp_workspace, test_settings)
|
||||
|
||||
with patch("git_wrapper.GitRepo") as mock_repo_class:
|
||||
mock_repo = MagicMock()
|
||||
mock_repo.active_branch.name = "main"
|
||||
mock_repo.head.commit.hexsha = "abc123"
|
||||
mock_repo_class.clone_from.return_value = mock_repo
|
||||
|
||||
result = await wrapper.clone(
|
||||
"https://github.com/test/repo.git",
|
||||
auth_token="test-token",
|
||||
)
|
||||
|
||||
assert result.success is True
|
||||
# Verify token was injected in URL
|
||||
call_args = mock_repo_class.clone_from.call_args
|
||||
assert "test-token@" in call_args.kwargs["url"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_clone_with_branch_and_depth(self, temp_workspace, test_settings):
|
||||
"""Test clone with branch and depth parameters."""
|
||||
wrapper = GitWrapper(temp_workspace, test_settings)
|
||||
|
||||
with patch("git_wrapper.GitRepo") as mock_repo_class:
|
||||
mock_repo = MagicMock()
|
||||
mock_repo.active_branch.name = "develop"
|
||||
mock_repo.head.commit.hexsha = "def456"
|
||||
mock_repo_class.clone_from.return_value = mock_repo
|
||||
|
||||
result = await wrapper.clone(
|
||||
"https://github.com/test/repo.git",
|
||||
branch="develop",
|
||||
depth=1,
|
||||
)
|
||||
|
||||
assert result.success is True
|
||||
call_args = mock_repo_class.clone_from.call_args
|
||||
assert call_args.kwargs["branch"] == "develop"
|
||||
assert call_args.kwargs["depth"] == 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_clone_failure(self, temp_workspace, test_settings):
|
||||
"""Test clone failure raises CloneError."""
|
||||
wrapper = GitWrapper(temp_workspace, test_settings)
|
||||
|
||||
with patch("git_wrapper.GitRepo") as mock_repo_class:
|
||||
mock_repo_class.clone_from.side_effect = GitCommandError(
|
||||
"git clone", 128, stderr="Authentication failed"
|
||||
)
|
||||
|
||||
with pytest.raises(CloneError):
|
||||
await wrapper.clone("https://github.com/test/repo.git")
|
||||
|
||||
|
||||
class TestGitWrapperPush:
|
||||
"""Tests for push operations."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_push_force_disabled(self, git_wrapper_with_repo, git_repo):
|
||||
"""Test force push is disabled by default."""
|
||||
git_repo.create_remote("origin", "https://github.com/test/repo.git")
|
||||
|
||||
with pytest.raises(PushError, match="Force push is disabled"):
|
||||
await git_wrapper_with_repo.push(force=True)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_push_remote_not_found(self, git_wrapper_with_repo):
|
||||
"""Test push to non-existent remote."""
|
||||
with pytest.raises(PushError, match="Remote not found"):
|
||||
await git_wrapper_with_repo.push(remote="nonexistent")
|
||||
|
||||
|
||||
class TestGitWrapperPull:
|
||||
"""Tests for pull operations."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_pull_remote_not_found(self, git_wrapper_with_repo):
|
||||
"""Test pull from non-existent remote."""
|
||||
with pytest.raises(PullError, match="Remote not found"):
|
||||
await git_wrapper_with_repo.pull(remote="nonexistent")
|
||||
|
||||
|
||||
class TestGitWrapperFetch:
|
||||
"""Tests for fetch operations."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fetch_remote_not_found(self, git_wrapper_with_repo):
|
||||
"""Test fetch from non-existent remote."""
|
||||
with pytest.raises(GitError, match="Remote not found"):
|
||||
await git_wrapper_with_repo.fetch(remote="nonexistent")
|
||||
|
||||
|
||||
class TestGitWrapperDiffHeadOnly:
|
||||
"""Tests for diff with head ref only."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_diff_head_only(self, git_wrapper_with_repo, git_repo):
|
||||
"""Test diff with head ref only (working tree vs ref)."""
|
||||
# Make some changes
|
||||
readme = Path(git_repo.working_dir) / "README.md"
|
||||
readme.write_text("modified content")
|
||||
|
||||
# This tests the head-only branch (base=None, head=specified)
|
||||
result = await git_wrapper_with_repo.diff(head="HEAD")
|
||||
|
||||
assert isinstance(result.files_changed, int)
|
||||
|
||||
|
||||
class TestGitWrapperRemoteWithUrl:
|
||||
"""Tests for getting remote URL when remote exists."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_remote_url_exists(self, git_wrapper_with_repo, git_repo):
|
||||
"""Test getting URL for existing remote."""
|
||||
git_repo.create_remote("origin", "https://github.com/test/repo.git")
|
||||
|
||||
url = await git_wrapper_with_repo.get_remote_url("origin")
|
||||
|
||||
assert url == "https://github.com/test/repo.git"
|
||||
|
||||
|
||||
class TestRunInExecutor:
|
||||
"""Tests for run_in_executor utility."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_in_executor(self):
|
||||
"""Test running function in executor."""
|
||||
|
||||
def blocking_func(x, y):
|
||||
return x + y
|
||||
|
||||
result = await run_in_executor(blocking_func, 1, 2)
|
||||
|
||||
assert result == 3
|
||||
620
mcp-servers/git-ops/tests/test_github_provider.py
Normal file
620
mcp-servers/git-ops/tests/test_github_provider.py
Normal file
@@ -0,0 +1,620 @@
|
||||
"""
|
||||
Tests for GitHub provider implementation.
|
||||
"""
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from exceptions import APIError, AuthenticationError
|
||||
from models import MergeStrategy, PRState
|
||||
from providers.github import GitHubProvider
|
||||
|
||||
|
||||
class TestGitHubProviderBasics:
|
||||
"""Tests for GitHubProvider basic operations."""
|
||||
|
||||
def test_provider_name(self):
|
||||
"""Test provider name is github."""
|
||||
provider = GitHubProvider(token="test-token")
|
||||
assert provider.name == "github"
|
||||
|
||||
def test_parse_repo_url_https(self):
|
||||
"""Test parsing HTTPS repo URL."""
|
||||
provider = GitHubProvider(token="test-token")
|
||||
|
||||
owner, repo = provider.parse_repo_url("https://github.com/owner/repo.git")
|
||||
|
||||
assert owner == "owner"
|
||||
assert repo == "repo"
|
||||
|
||||
def test_parse_repo_url_https_no_git(self):
|
||||
"""Test parsing HTTPS URL without .git suffix."""
|
||||
provider = GitHubProvider(token="test-token")
|
||||
|
||||
owner, repo = provider.parse_repo_url("https://github.com/owner/repo")
|
||||
|
||||
assert owner == "owner"
|
||||
assert repo == "repo"
|
||||
|
||||
def test_parse_repo_url_ssh(self):
|
||||
"""Test parsing SSH repo URL."""
|
||||
provider = GitHubProvider(token="test-token")
|
||||
|
||||
owner, repo = provider.parse_repo_url("git@github.com:owner/repo.git")
|
||||
|
||||
assert owner == "owner"
|
||||
assert repo == "repo"
|
||||
|
||||
def test_parse_repo_url_invalid(self):
|
||||
"""Test error on invalid URL."""
|
||||
provider = GitHubProvider(token="test-token")
|
||||
|
||||
with pytest.raises(ValueError, match="Unable to parse"):
|
||||
provider.parse_repo_url("invalid-url")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_github_httpx_client():
|
||||
"""Create a mock httpx client for GitHub provider tests."""
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
mock_response = AsyncMock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.json = MagicMock(return_value={})
|
||||
mock_response.text = ""
|
||||
|
||||
mock_client = AsyncMock()
|
||||
mock_client.request = AsyncMock(return_value=mock_response)
|
||||
mock_client.get = AsyncMock(return_value=mock_response)
|
||||
mock_client.post = AsyncMock(return_value=mock_response)
|
||||
mock_client.patch = AsyncMock(return_value=mock_response)
|
||||
mock_client.put = AsyncMock(return_value=mock_response)
|
||||
mock_client.delete = AsyncMock(return_value=mock_response)
|
||||
|
||||
return mock_client
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def github_provider(test_settings, mock_github_httpx_client):
|
||||
"""Create a GitHubProvider with mocked HTTP client."""
|
||||
provider = GitHubProvider(
|
||||
token=test_settings.github_token,
|
||||
settings=test_settings,
|
||||
)
|
||||
provider._client = mock_github_httpx_client
|
||||
|
||||
yield provider
|
||||
|
||||
await provider.close()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def github_pr_data():
|
||||
"""Sample PR data from GitHub API."""
|
||||
return {
|
||||
"number": 42,
|
||||
"title": "Test PR",
|
||||
"body": "This is a test pull request",
|
||||
"state": "open",
|
||||
"head": {"ref": "feature-branch"},
|
||||
"base": {"ref": "main"},
|
||||
"user": {"login": "test-user"},
|
||||
"created_at": "2024-01-15T10:00:00Z",
|
||||
"updated_at": "2024-01-15T12:00:00Z",
|
||||
"merged_at": None,
|
||||
"closed_at": None,
|
||||
"html_url": "https://github.com/owner/repo/pull/42",
|
||||
"labels": [{"name": "enhancement"}],
|
||||
"assignees": [{"login": "assignee1"}],
|
||||
"requested_reviewers": [{"login": "reviewer1"}],
|
||||
"mergeable": True,
|
||||
"draft": False,
|
||||
}
|
||||
|
||||
|
||||
class TestGitHubProviderConnection:
|
||||
"""Tests for GitHub provider connection."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_is_connected(self, github_provider, mock_github_httpx_client):
|
||||
"""Test connection check."""
|
||||
mock_github_httpx_client.request.return_value.json = MagicMock(
|
||||
return_value={"login": "test-user"}
|
||||
)
|
||||
|
||||
result = await github_provider.is_connected()
|
||||
|
||||
assert result is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_is_connected_no_token(self, test_settings):
|
||||
"""Test connection fails without token."""
|
||||
provider = GitHubProvider(
|
||||
token="",
|
||||
settings=test_settings,
|
||||
)
|
||||
|
||||
result = await provider.is_connected()
|
||||
assert result is False
|
||||
|
||||
await provider.close()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_authenticated_user(
|
||||
self, github_provider, mock_github_httpx_client
|
||||
):
|
||||
"""Test getting authenticated user."""
|
||||
mock_github_httpx_client.request.return_value.json = MagicMock(
|
||||
return_value={"login": "test-user"}
|
||||
)
|
||||
|
||||
user = await github_provider.get_authenticated_user()
|
||||
|
||||
assert user == "test-user"
|
||||
|
||||
|
||||
class TestGitHubProviderRepoOperations:
    """Repository lookups against the GitHub provider."""

    @pytest.mark.asyncio
    async def test_get_repo_info(self, github_provider, mock_github_httpx_client):
        """get_repo_info() passes the API repository payload through."""
        repo_payload = {
            "name": "repo",
            "full_name": "owner/repo",
            "default_branch": "main",
        }
        mocked_response = mock_github_httpx_client.request.return_value
        mocked_response.json = MagicMock(return_value=repo_payload)

        info = await github_provider.get_repo_info("owner", "repo")

        assert info["name"] == "repo"
        assert info["default_branch"] == "main"

    @pytest.mark.asyncio
    async def test_get_default_branch(self, github_provider, mock_github_httpx_client):
        """get_default_branch() extracts the default_branch field."""
        mocked_response = mock_github_httpx_client.request.return_value
        mocked_response.json = MagicMock(return_value={"default_branch": "develop"})

        assert await github_provider.get_default_branch("owner", "repo") == "develop"
|
||||
|
||||
class TestGitHubPROperations:
    """Tests for GitHub PR operations (create / get / list / merge / update / close).

    Several tests feed ``MagicMock(side_effect=[...])`` into ``json``: the mock
    yields one element per HTTP call, so each list must match the provider's
    request order exactly.
    """

    @pytest.mark.asyncio
    async def test_create_pr(self, github_provider, mock_github_httpx_client):
        """Creating a PR surfaces the number and URL from the API response."""
        mock_github_httpx_client.request.return_value.json = MagicMock(
            return_value={
                "number": 42,
                "html_url": "https://github.com/owner/repo/pull/42",
            }
        )

        result = await github_provider.create_pr(
            owner="owner",
            repo="repo",
            title="Test PR",
            body="Test body",
            source_branch="feature",
            target_branch="main",
        )

        assert result.success is True
        assert result.pr_number == 42
        assert result.pr_url == "https://github.com/owner/repo/pull/42"

    @pytest.mark.asyncio
    async def test_create_pr_with_draft(
        self, github_provider, mock_github_httpx_client
    ):
        """Creating a draft PR (draft=True) still reports success."""
        mock_github_httpx_client.request.return_value.json = MagicMock(
            return_value={
                "number": 43,
                "html_url": "https://github.com/owner/repo/pull/43",
            }
        )

        result = await github_provider.create_pr(
            owner="owner",
            repo="repo",
            title="Draft PR",
            body="Draft body",
            source_branch="feature",
            target_branch="main",
            draft=True,
        )

        assert result.success is True
        assert result.pr_number == 43

    @pytest.mark.asyncio
    async def test_create_pr_with_options(
        self, github_provider, mock_github_httpx_client
    ):
        """Creating a PR with labels, assignees and reviewers succeeds.

        Four API calls are expected; the responses below are consumed in order.
        """
        mock_responses = [
            {
                "number": 44,
                "html_url": "https://github.com/owner/repo/pull/44",
            },  # Create PR
            [{"name": "enhancement"}],  # POST add labels
            {},  # POST add assignees
            {},  # POST request reviewers
        ]
        mock_github_httpx_client.request.return_value.json = MagicMock(
            side_effect=mock_responses
        )

        result = await github_provider.create_pr(
            owner="owner",
            repo="repo",
            title="Test PR",
            body="Test body",
            source_branch="feature",
            target_branch="main",
            labels=["enhancement"],
            assignees=["user1"],
            reviewers=["reviewer1"],
        )

        assert result.success is True

    @pytest.mark.asyncio
    async def test_get_pr(
        self, github_provider, mock_github_httpx_client, github_pr_data
    ):
        """get_pr returns the PR fields from the fixture payload."""
        mock_github_httpx_client.request.return_value.json = MagicMock(
            return_value=github_pr_data
        )

        result = await github_provider.get_pr("owner", "repo", 42)

        assert result.success is True
        assert result.pr["number"] == 42
        assert result.pr["title"] == "Test PR"

    @pytest.mark.asyncio
    async def test_get_pr_not_found(self, github_provider, mock_github_httpx_client):
        """A 404 response yields a non-success result rather than raising."""
        mock_github_httpx_client.request.return_value.status_code = 404
        mock_github_httpx_client.request.return_value.json = MagicMock(
            return_value=None
        )

        result = await github_provider.get_pr("owner", "repo", 999)

        assert result.success is False

    @pytest.mark.asyncio
    async def test_list_prs(
        self, github_provider, mock_github_httpx_client, github_pr_data
    ):
        """Listing returns one entry per PR in the API response."""
        mock_github_httpx_client.request.return_value.json = MagicMock(
            return_value=[github_pr_data, github_pr_data]
        )

        result = await github_provider.list_prs("owner", "repo")

        assert result.success is True
        assert len(result.pull_requests) == 2

    @pytest.mark.asyncio
    async def test_list_prs_with_state_filter(
        self, github_provider, mock_github_httpx_client, github_pr_data
    ):
        """Listing with a state filter still succeeds."""
        mock_github_httpx_client.request.return_value.json = MagicMock(
            return_value=[github_pr_data]
        )

        result = await github_provider.list_prs("owner", "repo", state=PRState.OPEN)

        assert result.success is True

    @pytest.mark.asyncio
    async def test_merge_pr(
        self, github_provider, mock_github_httpx_client, github_pr_data
    ):
        """Merging (squash) reports the merge commit SHA."""
        # Merge returns sha, then get_pr returns the PR data, then delete branch
        mock_responses = [
            {"sha": "merge-commit-sha", "merged": True},  # PUT merge
            github_pr_data,  # GET PR for branch info
            None,  # DELETE branch
        ]
        mock_github_httpx_client.request.return_value.json = MagicMock(
            side_effect=mock_responses
        )

        result = await github_provider.merge_pr(
            "owner",
            "repo",
            42,
            merge_strategy=MergeStrategy.SQUASH,
        )

        assert result.success is True
        assert result.merge_commit_sha == "merge-commit-sha"

    @pytest.mark.asyncio
    async def test_merge_pr_rebase(
        self, github_provider, mock_github_httpx_client, github_pr_data
    ):
        """Merging with the rebase strategy follows the same call sequence."""
        mock_responses = [
            {"sha": "rebase-commit-sha", "merged": True},  # PUT merge
            github_pr_data,  # GET PR for branch info
            None,  # DELETE branch
        ]
        mock_github_httpx_client.request.return_value.json = MagicMock(
            side_effect=mock_responses
        )

        result = await github_provider.merge_pr(
            "owner",
            "repo",
            42,
            merge_strategy=MergeStrategy.REBASE,
        )

        assert result.success is True

    @pytest.mark.asyncio
    async def test_update_pr(
        self, github_provider, mock_github_httpx_client, github_pr_data
    ):
        """Updating title/body of a PR succeeds."""
        mock_github_httpx_client.request.return_value.json = MagicMock(
            return_value=github_pr_data
        )

        result = await github_provider.update_pr(
            "owner",
            "repo",
            42,
            title="Updated Title",
            body="Updated body",
        )

        assert result.success is True

    @pytest.mark.asyncio
    async def test_close_pr(
        self, github_provider, mock_github_httpx_client, github_pr_data
    ):
        """Closing a PR succeeds when the API reflects the closed state."""
        # Mutating the fixture dict here; presumably the fixture is
        # function-scoped so this does not leak into other tests — TODO confirm.
        github_pr_data["state"] = "closed"
        mock_github_httpx_client.request.return_value.json = MagicMock(
            return_value=github_pr_data
        )

        result = await github_provider.close_pr("owner", "repo", 42)

        assert result.success is True
|
||||
|
||||
class TestGitHubBranchOperations:
    """Branch lookup and deletion against the GitHub provider."""

    @pytest.mark.asyncio
    async def test_get_branch(self, github_provider, mock_github_httpx_client):
        """get_branch() returns the branch payload from the API."""
        branch_payload = {
            "name": "main",
            "commit": {"sha": "abc123"},
        }
        mocked_response = mock_github_httpx_client.request.return_value
        mocked_response.json = MagicMock(return_value=branch_payload)

        branch = await github_provider.get_branch("owner", "repo", "main")

        assert branch["name"] == "main"

    @pytest.mark.asyncio
    async def test_delete_remote_branch(
        self, github_provider, mock_github_httpx_client
    ):
        """A 204 response is reported as a successful deletion."""
        mock_github_httpx_client.request.return_value.status_code = 204

        deleted = await github_provider.delete_remote_branch(
            "owner", "repo", "old-branch"
        )

        assert deleted is True
||||
|
||||
|
||||
class TestGitHubCommentOperations:
    """PR comment creation and listing on the GitHub provider."""

    @pytest.mark.asyncio
    async def test_add_pr_comment(self, github_provider, mock_github_httpx_client):
        """add_pr_comment() returns the created comment payload."""
        mocked_response = mock_github_httpx_client.request.return_value
        mocked_response.json = MagicMock(return_value={"id": 1, "body": "Test comment"})

        comment = await github_provider.add_pr_comment(
            "owner", "repo", 42, "Test comment"
        )

        assert comment["body"] == "Test comment"

    @pytest.mark.asyncio
    async def test_list_pr_comments(self, github_provider, mock_github_httpx_client):
        """list_pr_comments() returns every comment from the API response."""
        comment_payloads = [
            {"id": 1, "body": "Comment 1"},
            {"id": 2, "body": "Comment 2"},
        ]
        mocked_response = mock_github_httpx_client.request.return_value
        mocked_response.json = MagicMock(return_value=comment_payloads)

        comments = await github_provider.list_pr_comments("owner", "repo", 42)

        assert len(comments) == 2
|
||||
|
||||
class TestGitHubLabelOperations:
    """Tests for GitHub label operations.

    ``side_effect`` lists yield one value per HTTP call, so their order must
    match the provider's request order.
    """

    @pytest.mark.asyncio
    async def test_add_labels(self, github_provider, mock_github_httpx_client):
        """Adding labels returns a collection containing the new label names."""
        mock_github_httpx_client.request.return_value.json = MagicMock(
            return_value=[{"name": "bug"}, {"name": "urgent"}]
        )

        result = await github_provider.add_labels(
            "owner", "repo", 42, ["bug", "urgent"]
        )

        assert "bug" in result
        assert "urgent" in result

    @pytest.mark.asyncio
    async def test_remove_label(self, github_provider, mock_github_httpx_client):
        """Removing a label returns the remaining label list."""
        mock_responses = [
            None,  # DELETE label
            {"labels": []},  # GET issue
        ]
        mock_github_httpx_client.request.return_value.json = MagicMock(
            side_effect=mock_responses
        )

        result = await github_provider.remove_label("owner", "repo", 42, "bug")

        assert isinstance(result, list)
||||
|
||||
|
||||
class TestGitHubReviewerOperations:
    """Review-request behaviour of the GitHub provider."""

    @pytest.mark.asyncio
    async def test_request_review(self, github_provider, mock_github_httpx_client):
        """request_review() echoes back the list of requested reviewers."""
        mocked_response = mock_github_httpx_client.request.return_value
        mocked_response.json = MagicMock(return_value={})

        reviewers = await github_provider.request_review(
            "owner", "repo", 42, ["reviewer1", "reviewer2"]
        )

        assert reviewers == ["reviewer1", "reviewer2"]
||||
|
||||
|
||||
class TestGitHubErrorHandling:
    """HTTP status codes are mapped to the provider's exception hierarchy."""

    @pytest.mark.asyncio
    async def test_authentication_error(
        self, github_provider, mock_github_httpx_client
    ):
        """401 raises AuthenticationError."""
        mock_github_httpx_client.request.return_value.status_code = 401

        with pytest.raises(AuthenticationError):
            await github_provider._request("GET", "/user")

    @pytest.mark.asyncio
    async def test_permission_denied(self, github_provider, mock_github_httpx_client):
        """A plain 403 raises AuthenticationError with a permissions message."""
        mock_github_httpx_client.request.return_value.status_code = 403
        mock_github_httpx_client.request.return_value.text = "Permission denied"

        with pytest.raises(AuthenticationError, match="Insufficient permissions"):
            await github_provider._request("GET", "/protected")

    @pytest.mark.asyncio
    async def test_rate_limit_error(self, github_provider, mock_github_httpx_client):
        """A 403 whose body mentions rate limiting raises APIError instead.

        Same status code as the permissions case above — the body text is what
        distinguishes the two outcomes.
        """
        mock_github_httpx_client.request.return_value.status_code = 403
        mock_github_httpx_client.request.return_value.text = "API rate limit exceeded"

        with pytest.raises(APIError, match="rate limit"):
            await github_provider._request("GET", "/user")

    @pytest.mark.asyncio
    async def test_api_error(self, github_provider, mock_github_httpx_client):
        """5xx responses raise APIError."""
        mock_github_httpx_client.request.return_value.status_code = 500
        mock_github_httpx_client.request.return_value.text = "Internal Server Error"
        mock_github_httpx_client.request.return_value.json = MagicMock(
            return_value={"message": "Server error"}
        )

        with pytest.raises(APIError):
            await github_provider._request("GET", "/error")
||||
|
||||
|
||||
class TestGitHubPRParsing:
    """Parsing of raw GitHub PR payloads into the provider's PR model.

    Tests mutate the ``github_pr_data`` fixture dict in place; presumably the
    fixture is function-scoped so mutations don't leak — TODO confirm.
    """

    def test_parse_pr_open(self, github_provider, github_pr_data):
        """The fixture payload parses into an OPEN PR with matching fields."""
        pr_info = github_provider._parse_pr(github_pr_data)

        assert pr_info.number == 42
        assert pr_info.state == PRState.OPEN
        assert pr_info.title == "Test PR"
        assert pr_info.source_branch == "feature-branch"
        assert pr_info.target_branch == "main"

    def test_parse_pr_merged(self, github_provider, github_pr_data):
        """A non-null merged_at maps the state to MERGED."""
        github_pr_data["merged_at"] = "2024-01-16T10:00:00Z"

        pr_info = github_provider._parse_pr(github_pr_data)

        assert pr_info.state == PRState.MERGED

    def test_parse_pr_closed(self, github_provider, github_pr_data):
        """state == "closed" with closed_at set maps to CLOSED."""
        github_pr_data["state"] = "closed"
        github_pr_data["closed_at"] = "2024-01-16T10:00:00Z"

        pr_info = github_provider._parse_pr(github_pr_data)

        assert pr_info.state == PRState.CLOSED

    def test_parse_pr_draft(self, github_provider, github_pr_data):
        """The draft flag is carried through."""
        github_pr_data["draft"] = True

        pr_info = github_provider._parse_pr(github_pr_data)

        assert pr_info.draft is True

    def test_parse_datetime_iso(self, github_provider):
        """An ISO-8601 string with trailing Z parses into its components."""
        dt = github_provider._parse_datetime("2024-01-15T10:30:00Z")

        assert dt.year == 2024
        assert dt.month == 1
        assert dt.day == 15

    def test_parse_datetime_none(self, github_provider):
        """None falls back to a timezone-aware datetime rather than None."""
        dt = github_provider._parse_datetime(None)

        assert dt is not None
        assert dt.tzinfo is not None

    def test_parse_pr_with_null_body(self, github_provider, github_pr_data):
        """A null body is normalised to the empty string."""
        github_pr_data["body"] = None

        pr_info = github_provider._parse_pr(github_pr_data)

        assert pr_info.body == ""
||||
mcp-servers/git-ops/tests/test_providers.py — new file, 486 lines added
@@ -0,0 +1,486 @@
|
||||
"""
|
||||
Tests for git provider implementations.
|
||||
"""
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from exceptions import APIError, AuthenticationError
|
||||
from models import MergeStrategy, PRState
|
||||
from providers.gitea import GiteaProvider
|
||||
|
||||
|
||||
class TestBaseProvider:
    """Tests for the BaseProvider repo-URL parsing interface.

    ``parse_repo_url`` is exercised through a concrete ``GiteaProvider``; the
    behaviour under test is the shared base-class parsing, not Gitea specifics.
    """

    def test_parse_repo_url_https(self):
        """An HTTPS clone URL with a .git suffix yields (owner, repo)."""
        # Fixed: dropped the unused `mock_gitea_provider` fixture parameter and
        # its stale comment — a real provider is constructed here, matching the
        # sibling tests below.
        provider = GiteaProvider(base_url="https://gitea.test.com", token="test-token")

        owner, repo = provider.parse_repo_url("https://gitea.test.com/owner/repo.git")

        assert owner == "owner"
        assert repo == "repo"

    def test_parse_repo_url_https_no_git(self):
        """An HTTPS URL without the .git suffix parses the same way."""
        provider = GiteaProvider(base_url="https://gitea.test.com", token="test-token")

        owner, repo = provider.parse_repo_url("https://gitea.test.com/owner/repo")

        assert owner == "owner"
        assert repo == "repo"

    def test_parse_repo_url_ssh(self):
        """An SSH clone URL (git@host:owner/repo.git) is supported too."""
        provider = GiteaProvider(base_url="https://gitea.test.com", token="test-token")

        owner, repo = provider.parse_repo_url("git@gitea.test.com:owner/repo.git")

        assert owner == "owner"
        assert repo == "repo"

    def test_parse_repo_url_invalid(self):
        """An unparseable URL raises ValueError."""
        provider = GiteaProvider(base_url="https://gitea.test.com", token="test-token")

        with pytest.raises(ValueError, match="Unable to parse"):
            provider.parse_repo_url("invalid-url")
|
||||
|
||||
class TestGiteaProvider:
    """Connection and repository lookups against the Gitea provider."""

    @pytest.mark.asyncio
    async def test_is_connected(self, gitea_provider, mock_httpx_client):
        """is_connected() reports True when the API returns a user payload."""
        mock_httpx_client.request.return_value.json = MagicMock(
            return_value={"login": "test-user"}
        )

        result = await gitea_provider.is_connected()

        assert result is True

    @pytest.mark.asyncio
    async def test_is_connected_no_token(self, test_settings):
        """A provider built with an empty token reports not connected."""
        provider = GiteaProvider(
            base_url="https://gitea.test.com",
            token="",
            settings=test_settings,
        )

        result = await provider.is_connected()
        assert result is False

        # Release the provider's HTTP client explicitly.
        await provider.close()

    @pytest.mark.asyncio
    async def test_get_authenticated_user(self, gitea_provider, mock_httpx_client):
        """get_authenticated_user() returns the login from the API response."""
        mock_httpx_client.request.return_value.json = MagicMock(
            return_value={"login": "test-user"}
        )

        user = await gitea_provider.get_authenticated_user()

        assert user == "test-user"

    @pytest.mark.asyncio
    async def test_get_repo_info(self, gitea_provider, mock_httpx_client):
        """get_repo_info() passes the API repository payload through."""
        mock_httpx_client.request.return_value.json = MagicMock(
            return_value={
                "name": "repo",
                "full_name": "owner/repo",
                "default_branch": "main",
            }
        )

        result = await gitea_provider.get_repo_info("owner", "repo")

        assert result["name"] == "repo"
        assert result["default_branch"] == "main"

    @pytest.mark.asyncio
    async def test_get_default_branch(self, gitea_provider, mock_httpx_client):
        """get_default_branch() extracts the default_branch field."""
        mock_httpx_client.request.return_value.json = MagicMock(
            return_value={"default_branch": "develop"}
        )

        branch = await gitea_provider.get_default_branch("owner", "repo")

        assert branch == "develop"
|
||||
|
||||
class TestGiteaPROperations:
    """Tests for Gitea PR operations.

    ``MagicMock(side_effect=[...])`` yields one element per HTTP call, so each
    response list below must match the provider's request order exactly.
    """

    @pytest.mark.asyncio
    async def test_create_pr(self, gitea_provider, mock_httpx_client):
        """Creating a PR surfaces the number and URL from the API response."""
        mock_httpx_client.request.return_value.json = MagicMock(
            return_value={
                "number": 42,
                "html_url": "https://gitea.test.com/owner/repo/pull/42",
            }
        )

        result = await gitea_provider.create_pr(
            owner="owner",
            repo="repo",
            title="Test PR",
            body="Test body",
            source_branch="feature",
            target_branch="main",
        )

        assert result.success is True
        assert result.pr_number == 42
        assert result.pr_url == "https://gitea.test.com/owner/repo/pull/42"

    @pytest.mark.asyncio
    async def test_create_pr_with_options(self, gitea_provider, mock_httpx_client):
        """Creating a PR with labels, assignees and reviewers succeeds."""
        # Use side_effect for multiple API calls:
        # 1. POST create PR
        # 2. GET labels (for "enhancement") - in add_labels -> _get_or_create_label
        # 3. POST add labels to PR - in add_labels
        # 4. GET issue to return labels - in add_labels
        # 5. PATCH add assignees
        # 6. POST request reviewers
        mock_responses = [
            {
                "number": 43,
                "html_url": "https://gitea.test.com/owner/repo/pull/43",
            },  # Create PR
            [{"id": 1, "name": "enhancement"}],  # GET labels (found)
            {},  # POST add labels to PR
            {"labels": [{"name": "enhancement"}]},  # GET issue to return current labels
            {},  # PATCH add assignees
            {},  # POST request reviewers
        ]
        mock_httpx_client.request.return_value.json = MagicMock(
            side_effect=mock_responses
        )

        result = await gitea_provider.create_pr(
            owner="owner",
            repo="repo",
            title="Test PR",
            body="Test body",
            source_branch="feature",
            target_branch="main",
            labels=["enhancement"],
            assignees=["user1"],
            reviewers=["reviewer1"],
        )

        assert result.success is True

    @pytest.mark.asyncio
    async def test_get_pr(self, gitea_provider, mock_httpx_client, sample_pr_data):
        """get_pr returns the PR fields from the fixture payload."""
        mock_httpx_client.request.return_value.json = MagicMock(
            return_value=sample_pr_data
        )

        result = await gitea_provider.get_pr("owner", "repo", 42)

        assert result.success is True
        assert result.pr["number"] == 42
        assert result.pr["title"] == "Test PR"

    @pytest.mark.asyncio
    async def test_get_pr_not_found(self, gitea_provider, mock_httpx_client):
        """A 404 response yields a non-success result rather than raising."""
        mock_httpx_client.request.return_value.status_code = 404
        mock_httpx_client.request.return_value.json = MagicMock(return_value=None)

        result = await gitea_provider.get_pr("owner", "repo", 999)

        assert result.success is False

    @pytest.mark.asyncio
    async def test_list_prs(self, gitea_provider, mock_httpx_client, sample_pr_data):
        """Listing returns one entry per PR in the API response."""
        mock_httpx_client.request.return_value.json = MagicMock(
            return_value=[sample_pr_data, sample_pr_data]
        )

        result = await gitea_provider.list_prs("owner", "repo")

        assert result.success is True
        assert len(result.pull_requests) == 2

    @pytest.mark.asyncio
    async def test_list_prs_with_state_filter(
        self, gitea_provider, mock_httpx_client, sample_pr_data
    ):
        """Listing with a state filter still succeeds."""
        mock_httpx_client.request.return_value.json = MagicMock(
            return_value=[sample_pr_data]
        )

        result = await gitea_provider.list_prs("owner", "repo", state=PRState.OPEN)

        assert result.success is True

    @pytest.mark.asyncio
    async def test_merge_pr(self, gitea_provider, mock_httpx_client):
        """Merging (squash) reports the merge commit SHA."""
        # First call returns merge result
        mock_httpx_client.request.return_value.json = MagicMock(
            return_value={"sha": "merge-commit-sha"}
        )

        result = await gitea_provider.merge_pr(
            "owner",
            "repo",
            42,
            merge_strategy=MergeStrategy.SQUASH,
        )

        assert result.success is True
        assert result.merge_commit_sha == "merge-commit-sha"

    @pytest.mark.asyncio
    async def test_update_pr(self, gitea_provider, mock_httpx_client, sample_pr_data):
        """Updating title/body of a PR succeeds."""
        mock_httpx_client.request.return_value.json = MagicMock(
            return_value=sample_pr_data
        )

        result = await gitea_provider.update_pr(
            "owner",
            "repo",
            42,
            title="Updated Title",
            body="Updated body",
        )

        assert result.success is True

    @pytest.mark.asyncio
    async def test_close_pr(self, gitea_provider, mock_httpx_client, sample_pr_data):
        """Closing a PR succeeds when the API reflects the closed state."""
        # Mutating the fixture dict here; presumably the fixture is
        # function-scoped so this does not leak into other tests — TODO confirm.
        sample_pr_data["state"] = "closed"
        mock_httpx_client.request.return_value.json = MagicMock(
            return_value=sample_pr_data
        )

        result = await gitea_provider.close_pr("owner", "repo", 42)

        assert result.success is True
||||
|
||||
|
||||
class TestGiteaBranchOperations:
    """Branch lookup and deletion against the Gitea provider."""

    @pytest.mark.asyncio
    async def test_get_branch(self, gitea_provider, mock_httpx_client):
        """get_branch() returns the branch payload from the API."""
        branch_payload = {
            "name": "main",
            "commit": {"sha": "abc123"},
        }
        mocked_response = mock_httpx_client.request.return_value
        mocked_response.json = MagicMock(return_value=branch_payload)

        branch = await gitea_provider.get_branch("owner", "repo", "main")

        assert branch["name"] == "main"

    @pytest.mark.asyncio
    async def test_delete_remote_branch(self, gitea_provider, mock_httpx_client):
        """A 204 response is reported as a successful deletion."""
        mock_httpx_client.request.return_value.status_code = 204

        deleted = await gitea_provider.delete_remote_branch(
            "owner", "repo", "old-branch"
        )

        assert deleted is True
||||
|
||||
|
||||
class TestGiteaCommentOperations:
    """PR comment creation and listing on the Gitea provider."""

    @pytest.mark.asyncio
    async def test_add_pr_comment(self, gitea_provider, mock_httpx_client):
        """add_pr_comment() returns the created comment payload."""
        mocked_response = mock_httpx_client.request.return_value
        mocked_response.json = MagicMock(return_value={"id": 1, "body": "Test comment"})

        comment = await gitea_provider.add_pr_comment(
            "owner", "repo", 42, "Test comment"
        )

        assert comment["body"] == "Test comment"

    @pytest.mark.asyncio
    async def test_list_pr_comments(self, gitea_provider, mock_httpx_client):
        """list_pr_comments() returns every comment from the API response."""
        comment_payloads = [
            {"id": 1, "body": "Comment 1"},
            {"id": 2, "body": "Comment 2"},
        ]
        mocked_response = mock_httpx_client.request.return_value
        mocked_response.json = MagicMock(return_value=comment_payloads)

        comments = await gitea_provider.list_pr_comments("owner", "repo", 42)

        assert len(comments) == 2
||||
|
||||
|
||||
class TestGiteaLabelOperations:
    """Tests for Gitea label operations.

    Gitea labels are referenced by ID, so adding/removing a label first
    resolves (or creates) the label before touching the PR — hence the longer
    ``side_effect`` sequences below, consumed one element per HTTP call.
    """

    @pytest.mark.asyncio
    async def test_add_labels(self, gitea_provider, mock_httpx_client):
        """Adding two labels (both missing) creates each, then attaches them."""
        # Use side_effect to return different values for different calls
        # 1. GET labels (for "bug") - returns existing labels
        # 2. POST to create "bug" label
        # 3. GET labels (for "urgent")
        # 4. POST to create "urgent" label
        # 5. POST labels to PR
        # 6. GET issue to return final labels
        mock_responses = [
            [{"id": 1, "name": "existing"}],  # GET labels (bug not found)
            {"id": 2, "name": "bug"},  # POST create bug
            [
                {"id": 1, "name": "existing"},
                {"id": 2, "name": "bug"},
            ],  # GET labels (urgent not found)
            {"id": 3, "name": "urgent"},  # POST create urgent
            {},  # POST add labels to PR
            {"labels": [{"name": "bug"}, {"name": "urgent"}]},  # GET issue
        ]
        mock_httpx_client.request.return_value.json = MagicMock(
            side_effect=mock_responses
        )

        result = await gitea_provider.add_labels("owner", "repo", 42, ["bug", "urgent"])

        # Should return updated label list
        assert isinstance(result, list)

    @pytest.mark.asyncio
    async def test_remove_label(self, gitea_provider, mock_httpx_client):
        """Removing a label resolves its ID, deletes it, and returns the rest."""
        # Use side_effect for multiple calls
        # 1. GET labels to find the label ID
        # 2. DELETE the label from the PR
        # 3. GET issue to return remaining labels
        mock_responses = [
            [{"id": 1, "name": "bug"}],  # GET labels
            {},  # DELETE label
            {"labels": []},  # GET issue
        ]
        mock_httpx_client.request.return_value.json = MagicMock(
            side_effect=mock_responses
        )

        result = await gitea_provider.remove_label("owner", "repo", 42, "bug")

        assert isinstance(result, list)
||||
|
||||
|
||||
class TestGiteaReviewerOperations:
    """Review-request behaviour of the Gitea provider."""

    @pytest.mark.asyncio
    async def test_request_review(self, gitea_provider, mock_httpx_client):
        """request_review() echoes back the list of requested reviewers."""
        mocked_response = mock_httpx_client.request.return_value
        mocked_response.json = MagicMock(return_value={})

        reviewers = await gitea_provider.request_review(
            "owner", "repo", 42, ["reviewer1", "reviewer2"]
        )

        assert reviewers == ["reviewer1", "reviewer2"]
||||
|
||||
|
||||
class TestGiteaErrorHandling:
    """HTTP status codes are mapped to the provider's exception hierarchy."""

    @pytest.mark.asyncio
    async def test_authentication_error(self, gitea_provider, mock_httpx_client):
        """401 raises AuthenticationError."""
        mock_httpx_client.request.return_value.status_code = 401

        with pytest.raises(AuthenticationError):
            await gitea_provider._request("GET", "/user")

    @pytest.mark.asyncio
    async def test_permission_denied(self, gitea_provider, mock_httpx_client):
        """403 raises AuthenticationError with a permissions message."""
        mock_httpx_client.request.return_value.status_code = 403

        with pytest.raises(AuthenticationError, match="Insufficient permissions"):
            await gitea_provider._request("GET", "/protected")

    @pytest.mark.asyncio
    async def test_api_error(self, gitea_provider, mock_httpx_client):
        """5xx responses raise APIError."""
        mock_httpx_client.request.return_value.status_code = 500
        mock_httpx_client.request.return_value.text = "Internal Server Error"
        mock_httpx_client.request.return_value.json = MagicMock(
            return_value={"message": "Server error"}
        )

        with pytest.raises(APIError):
            await gitea_provider._request("GET", "/error")
||||
|
||||
|
||||
class TestGiteaPRParsing:
    """Parsing of raw Gitea PR payloads into the provider's PR model.

    Tests mutate the ``sample_pr_data`` fixture dict in place; presumably the
    fixture is function-scoped so mutations don't leak — TODO confirm.
    """

    def test_parse_pr_open(self, gitea_provider, sample_pr_data):
        """The fixture payload parses into an OPEN PR with matching fields."""
        pr_info = gitea_provider._parse_pr(sample_pr_data)

        assert pr_info.number == 42
        assert pr_info.state == PRState.OPEN
        assert pr_info.title == "Test PR"
        assert pr_info.source_branch == "feature-branch"
        assert pr_info.target_branch == "main"

    def test_parse_pr_merged(self, gitea_provider, sample_pr_data):
        """merged=True with merged_at set maps the state to MERGED."""
        sample_pr_data["merged"] = True
        sample_pr_data["merged_at"] = "2024-01-16T10:00:00Z"

        pr_info = gitea_provider._parse_pr(sample_pr_data)

        assert pr_info.state == PRState.MERGED

    def test_parse_pr_closed(self, gitea_provider, sample_pr_data):
        """state == "closed" with closed_at set maps to CLOSED."""
        sample_pr_data["state"] = "closed"
        sample_pr_data["closed_at"] = "2024-01-16T10:00:00Z"

        pr_info = gitea_provider._parse_pr(sample_pr_data)

        assert pr_info.state == PRState.CLOSED

    def test_parse_datetime_iso(self, gitea_provider):
        """An ISO-8601 string with trailing Z parses into its components."""
        dt = gitea_provider._parse_datetime("2024-01-15T10:30:00Z")

        assert dt.year == 2024
        assert dt.month == 1
        assert dt.day == 15

    def test_parse_datetime_none(self, gitea_provider):
        """None falls back to a timezone-aware datetime rather than None."""
        dt = gitea_provider._parse_datetime(None)

        assert dt is not None
        assert dt.tzinfo is not None
||||
522
mcp-servers/git-ops/tests/test_server.py
Normal file
522
mcp-servers/git-ops/tests/test_server.py
Normal file
@@ -0,0 +1,522 @@
|
||||
"""
|
||||
Tests for the MCP server and tools.
|
||||
"""
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from exceptions import ErrorCode
|
||||
|
||||
|
||||
class TestInputValidation:
    """Validation helpers for IDs, branch names and repository URLs."""

    def test_validate_id_valid(self):
        """Well-formed IDs produce no error message."""
        from server import _validate_id

        for candidate, field in (
            ("test-123", "project_id"),
            ("my_project", "project_id"),
            ("Agent-001", "agent_id"),
        ):
            assert _validate_id(candidate, field) is None

    def test_validate_id_empty(self):
        """An empty ID is rejected with a 'required' message."""
        from server import _validate_id

        message = _validate_id("", "project_id")
        assert message is not None
        assert "required" in message.lower()

    def test_validate_id_too_long(self):
        """IDs beyond 128 characters are rejected with the length hint."""
        from server import _validate_id

        message = _validate_id("a" * 200, "project_id")
        assert message is not None
        assert "1-128" in message

    def test_validate_id_invalid_chars(self):
        """Punctuation and whitespace are not allowed in IDs."""
        from server import _validate_id

        for candidate in ("test@invalid", "test!project", "test project"):
            assert _validate_id(candidate, "project_id") is not None

    def test_validate_branch_valid(self):
        """Typical git branch names pass validation."""
        from server import _validate_branch

        for branch in ("main", "feature/new-thing", "release-1.0.0", "hotfix.urgent"):
            assert _validate_branch(branch) is None

    def test_validate_branch_invalid(self):
        """Empty or absurdly long branch names are rejected."""
        from server import _validate_branch

        assert _validate_branch("") is not None
        assert _validate_branch("a" * 300) is not None

    def test_validate_url_valid(self):
        """HTTPS and SSH repository URLs pass validation."""
        from server import _validate_url

        for url in (
            "https://github.com/owner/repo.git",
            "https://gitea.example.com/owner/repo",
            "git@github.com:owner/repo.git",
        ):
            assert _validate_url(url) is None

    def test_validate_url_invalid(self):
        """Empty, malformed, or non-git scheme URLs are rejected."""
        from server import _validate_url

        for url in ("", "not-a-url", "ftp://invalid.com/repo"):
            assert _validate_url(url) is not None
|
||||
|
||||
|
||||
class TestHealthCheck:
    """Shape and content of the health check endpoint."""

    @pytest.mark.asyncio
    async def test_health_check_structure(self):
        """The payload always carries the standard top-level keys."""
        from server import health_check

        with (
            patch("server._gitea_provider", None),
            patch("server._workspace_manager", None),
        ):
            payload = await health_check()

        for key in ("status", "service", "version", "timestamp", "dependencies"):
            assert key in payload

    @pytest.mark.asyncio
    async def test_health_check_no_providers(self):
        """Without a Gitea provider the dependency reads 'not configured'."""
        from server import health_check

        with (
            patch("server._gitea_provider", None),
            patch("server._workspace_manager", None),
        ):
            payload = await health_check()

        assert payload["dependencies"]["gitea"] == "not configured"
|
||||
|
||||
|
||||
class TestToolRegistry:
    """Registration and schema shape of the MCP tool registry."""

    def test_tool_registry_populated(self):
        """Core git tools are present in the registry."""
        from server import _tool_registry

        assert _tool_registry
        for tool in ("clone_repository", "git_status", "create_branch", "commit"):
            assert tool in _tool_registry

    def test_tool_schema_structure(self):
        """Every registered tool exposes func, description and an object schema."""
        from server import _tool_registry

        for entry in _tool_registry.values():
            assert "func" in entry
            assert "description" in entry
            assert "schema" in entry
            schema = entry["schema"]
            assert schema["type"] == "object"
            assert "properties" in schema
|
||||
|
||||
|
||||
class TestCloneRepository:
    """Input validation paths of the clone_repository tool."""

    @pytest.mark.asyncio
    async def test_clone_invalid_project_id(self):
        """A project ID with forbidden characters is rejected up front."""
        from server import clone_repository

        # FastMCP wraps the coroutine; the raw callable lives on .fn
        outcome = await clone_repository.fn(
            project_id="invalid@id",
            agent_id="agent-1",
            repo_url="https://github.com/owner/repo.git",
        )

        assert outcome["success"] is False
        assert "project_id" in outcome["error"].lower()

    @pytest.mark.asyncio
    async def test_clone_invalid_repo_url(self):
        """A malformed repository URL is rejected up front."""
        from server import clone_repository

        outcome = await clone_repository.fn(
            project_id="valid-project",
            agent_id="agent-1",
            repo_url="not-a-valid-url",
        )

        assert outcome["success"] is False
        assert "url" in outcome["error"].lower()
|
||||
|
||||
|
||||
class TestGitStatus:
    """Error paths of the git_status tool."""

    @pytest.mark.asyncio
    async def test_status_workspace_not_found(self):
        """Status against a missing workspace reports WORKSPACE_NOT_FOUND."""
        from server import git_status

        manager = AsyncMock()
        manager.get_workspace = AsyncMock(return_value=None)

        with patch("server._workspace_manager", manager):
            outcome = await git_status.fn(project_id="nonexistent", agent_id="agent-1")

        assert outcome["success"] is False
        assert outcome["code"] == ErrorCode.WORKSPACE_NOT_FOUND.value
|
||||
|
||||
|
||||
class TestBranchOperations:
    """Validation and error paths of the branch tools."""

    @pytest.mark.asyncio
    async def test_create_branch_invalid_name(self):
        """An empty branch name is rejected."""
        from server import create_branch

        outcome = await create_branch.fn(
            project_id="test-project",
            agent_id="agent-1",
            branch_name="",  # rejected by validation
        )

        assert outcome["success"] is False

    @pytest.mark.asyncio
    async def test_list_branches_workspace_not_found(self):
        """Listing branches of a missing workspace fails."""
        from server import list_branches

        manager = AsyncMock()
        manager.get_workspace = AsyncMock(return_value=None)

        with patch("server._workspace_manager", manager):
            outcome = await list_branches.fn(
                project_id="nonexistent", agent_id="agent-1"
            )

        assert outcome["success"] is False

    @pytest.mark.asyncio
    async def test_checkout_invalid_project(self):
        """Checkout with an ill-formed project ID fails."""
        from server import checkout

        outcome = await checkout.fn(
            project_id="inv@lid", agent_id="agent-1", ref="main"
        )

        assert outcome["success"] is False
|
||||
|
||||
|
||||
class TestCommitOperations:
    """Validation paths of the commit tool."""

    @pytest.mark.asyncio
    async def test_commit_invalid_project(self):
        """Commit with an ill-formed project ID fails."""
        from server import commit

        outcome = await commit.fn(
            project_id="inv@lid",
            agent_id="agent-1",
            message="Test commit",
        )

        assert outcome["success"] is False
|
||||
|
||||
|
||||
class TestPushPullOperations:
    """Error paths of the push and pull tools."""

    @pytest.mark.asyncio
    async def test_push_workspace_not_found(self):
        """Pushing from a missing workspace fails."""
        from server import push

        manager = AsyncMock()
        manager.get_workspace = AsyncMock(return_value=None)

        with patch("server._workspace_manager", manager):
            outcome = await push.fn(project_id="nonexistent", agent_id="agent-1")

        assert outcome["success"] is False

    @pytest.mark.asyncio
    async def test_pull_workspace_not_found(self):
        """Pulling into a missing workspace fails."""
        from server import pull

        manager = AsyncMock()
        manager.get_workspace = AsyncMock(return_value=None)

        with patch("server._workspace_manager", manager):
            outcome = await pull.fn(project_id="nonexistent", agent_id="agent-1")

        assert outcome["success"] is False
|
||||
|
||||
|
||||
class TestDiffLogOperations:
    """Error paths of the diff and log tools."""

    @pytest.mark.asyncio
    async def test_diff_workspace_not_found(self):
        """Diff against a missing workspace fails."""
        from server import diff

        manager = AsyncMock()
        manager.get_workspace = AsyncMock(return_value=None)

        with patch("server._workspace_manager", manager):
            outcome = await diff.fn(project_id="nonexistent", agent_id="agent-1")

        assert outcome["success"] is False

    @pytest.mark.asyncio
    async def test_log_workspace_not_found(self):
        """Log against a missing workspace fails."""
        from server import log

        manager = AsyncMock()
        manager.get_workspace = AsyncMock(return_value=None)

        with patch("server._workspace_manager", manager):
            outcome = await log.fn(project_id="nonexistent", agent_id="agent-1")

        assert outcome["success"] is False
|
||||
|
||||
|
||||
class TestPROperations:
    """Validation and precondition paths of the pull-request tools."""

    @pytest.mark.asyncio
    async def test_create_pr_no_repo_url(self):
        """Creating a PR requires the workspace to carry a repo URL."""
        from models import WorkspaceInfo, WorkspaceState
        from server import create_pull_request

        workspace = WorkspaceInfo(
            project_id="test-project",
            path="/tmp/test",
            state=WorkspaceState.READY,
            repo_url=None,  # precondition deliberately violated
        )
        manager = AsyncMock()
        manager.get_workspace = AsyncMock(return_value=workspace)

        with patch("server._workspace_manager", manager):
            outcome = await create_pull_request.fn(
                project_id="test-project",
                agent_id="agent-1",
                title="Test PR",
                source_branch="feature",
                target_branch="main",
            )

        assert outcome["success"] is False
        assert "repository URL" in outcome["error"]

    @pytest.mark.asyncio
    async def test_list_prs_invalid_state(self):
        """An unknown state filter is rejected."""
        from models import WorkspaceInfo, WorkspaceState
        from server import list_pull_requests

        workspace = WorkspaceInfo(
            project_id="test-project",
            path="/tmp/test",
            state=WorkspaceState.READY,
            repo_url="https://gitea.test.com/owner/repo.git",
        )
        manager = AsyncMock()
        manager.get_workspace = AsyncMock(return_value=workspace)
        provider = AsyncMock()
        provider.parse_repo_url = MagicMock(return_value=("owner", "repo"))

        with (
            patch("server._workspace_manager", manager),
            patch("server._get_provider_for_url", return_value=provider),
        ):
            outcome = await list_pull_requests.fn(
                project_id="test-project",
                agent_id="agent-1",
                state="invalid-state",
            )

        assert outcome["success"] is False
        assert "Invalid state" in outcome["error"]

    @pytest.mark.asyncio
    async def test_merge_pr_invalid_strategy(self):
        """An unknown merge strategy is rejected."""
        from models import WorkspaceInfo, WorkspaceState
        from server import merge_pull_request

        workspace = WorkspaceInfo(
            project_id="test-project",
            path="/tmp/test",
            state=WorkspaceState.READY,
            repo_url="https://gitea.test.com/owner/repo.git",
        )
        manager = AsyncMock()
        manager.get_workspace = AsyncMock(return_value=workspace)
        provider = AsyncMock()
        provider.parse_repo_url = MagicMock(return_value=("owner", "repo"))

        with (
            patch("server._workspace_manager", manager),
            patch("server._get_provider_for_url", return_value=provider),
        ):
            outcome = await merge_pull_request.fn(
                project_id="test-project",
                agent_id="agent-1",
                pr_number=42,
                merge_strategy="invalid-strategy",
            )

        assert outcome["success"] is False
        assert "Invalid strategy" in outcome["error"]
|
||||
|
||||
|
||||
class TestWorkspaceOperations:
    """Workspace query, lock and unlock tools."""

    @pytest.mark.asyncio
    async def test_get_workspace_not_found(self):
        """Querying a missing workspace fails."""
        from server import get_workspace

        manager = AsyncMock()
        manager.get_workspace = AsyncMock(return_value=None)

        with patch("server._workspace_manager", manager):
            outcome = await get_workspace.fn(
                project_id="nonexistent", agent_id="agent-1"
            )

        assert outcome["success"] is False

    @pytest.mark.asyncio
    async def test_lock_workspace_success(self):
        """A successful lock reports the holder back to the caller."""
        from datetime import UTC, datetime, timedelta

        from models import WorkspaceInfo, WorkspaceState
        from server import lock_workspace

        locked = WorkspaceInfo(
            project_id="test-project",
            path="/tmp/test",
            state=WorkspaceState.LOCKED,
            lock_holder="agent-1",
            lock_expires=datetime.now(UTC) + timedelta(seconds=300),
        )
        manager = AsyncMock()
        manager.lock_workspace = AsyncMock(return_value=True)
        manager.get_workspace = AsyncMock(return_value=locked)

        with patch("server._workspace_manager", manager):
            outcome = await lock_workspace.fn(
                project_id="test-project",
                agent_id="agent-1",
                timeout=300,
            )

        assert outcome["success"] is True
        assert outcome["lock_holder"] == "agent-1"

    @pytest.mark.asyncio
    async def test_unlock_workspace_success(self):
        """A successful unlock reports success."""
        from server import unlock_workspace

        manager = AsyncMock()
        manager.unlock_workspace = AsyncMock(return_value=True)

        with patch("server._workspace_manager", manager):
            outcome = await unlock_workspace.fn(
                project_id="test-project", agent_id="agent-1"
            )

        assert outcome["success"] is True
|
||||
|
||||
|
||||
class TestJSONRPCEndpoint:
    """Python-type to JSON-schema conversion used by the JSON-RPC endpoint."""

    def test_python_type_to_json_schema_str(self):
        """str maps to a JSON string schema."""
        from server import _python_type_to_json_schema

        assert _python_type_to_json_schema(str)["type"] == "string"

    def test_python_type_to_json_schema_int(self):
        """int maps to a JSON integer schema."""
        from server import _python_type_to_json_schema

        assert _python_type_to_json_schema(int)["type"] == "integer"

    def test_python_type_to_json_schema_bool(self):
        """bool maps to a JSON boolean schema."""
        from server import _python_type_to_json_schema

        assert _python_type_to_json_schema(bool)["type"] == "boolean"

    def test_python_type_to_json_schema_list(self):
        """list[str] maps to an array schema with string items."""
        from server import _python_type_to_json_schema

        schema = _python_type_to_json_schema(list[str])
        assert schema["type"] == "array"
        assert schema["items"]["type"] == "string"
|
||||
1170
mcp-servers/git-ops/tests/test_server_tools.py
Normal file
1170
mcp-servers/git-ops/tests/test_server_tools.py
Normal file
File diff suppressed because it is too large
Load Diff
358
mcp-servers/git-ops/tests/test_workspace.py
Normal file
358
mcp-servers/git-ops/tests/test_workspace.py
Normal file
@@ -0,0 +1,358 @@
|
||||
"""
|
||||
Tests for the workspace management module.
|
||||
"""
|
||||
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from exceptions import WorkspaceLockedError, WorkspaceNotFoundError
|
||||
from models import WorkspaceState
|
||||
from workspace import FileLockManager, WorkspaceLock
|
||||
|
||||
|
||||
class TestWorkspaceManager:
    """CRUD behaviour of WorkspaceManager."""

    @pytest.mark.asyncio
    async def test_create_workspace(self, workspace_manager, valid_project_id):
        """A fresh workspace starts INITIALIZING with a real directory."""
        created = await workspace_manager.create_workspace(valid_project_id)

        assert created.project_id == valid_project_id
        assert created.state == WorkspaceState.INITIALIZING
        assert Path(created.path).exists()

    @pytest.mark.asyncio
    async def test_create_workspace_with_repo_url(
        self, workspace_manager, valid_project_id, sample_repo_url
    ):
        """The repo URL passed at creation is stored on the workspace."""
        created = await workspace_manager.create_workspace(
            valid_project_id, repo_url=sample_repo_url
        )

        assert created.repo_url == sample_repo_url

    @pytest.mark.asyncio
    async def test_get_workspace(self, workspace_manager, valid_project_id):
        """A created workspace can be looked up again by project ID."""
        await workspace_manager.create_workspace(valid_project_id)

        found = await workspace_manager.get_workspace(valid_project_id)

        assert found is not None
        assert found.project_id == valid_project_id

    @pytest.mark.asyncio
    async def test_get_workspace_not_found(self, workspace_manager):
        """Looking up an unknown project yields None."""
        assert await workspace_manager.get_workspace("nonexistent") is None

    @pytest.mark.asyncio
    async def test_delete_workspace(self, workspace_manager, valid_project_id):
        """Deletion removes both the record and the directory."""
        created = await workspace_manager.create_workspace(valid_project_id)
        directory = Path(created.path)
        assert directory.exists()

        assert await workspace_manager.delete_workspace(valid_project_id) is True
        assert not directory.exists()

    @pytest.mark.asyncio
    async def test_delete_nonexistent_workspace(self, workspace_manager):
        """Deleting an unknown workspace is an idempotent no-op."""
        assert await workspace_manager.delete_workspace("nonexistent") is True

    @pytest.mark.asyncio
    async def test_list_workspaces(self, workspace_manager):
        """All created workspaces show up in the listing."""
        for name in ("project-1", "project-2", "project-3"):
            await workspace_manager.create_workspace(name)

        listed = await workspace_manager.list_workspaces()

        assert len(listed) >= 3
        known = {w.project_id for w in listed}
        assert {"project-1", "project-2", "project-3"} <= known
|
||||
|
||||
|
||||
class TestWorkspaceLocking:
    """Lock/unlock semantics of WorkspaceManager."""

    @pytest.mark.asyncio
    async def test_lock_workspace(
        self, workspace_manager, valid_project_id, valid_agent_id
    ):
        """Locking marks the workspace LOCKED and records the holder."""
        await workspace_manager.create_workspace(valid_project_id)

        acquired = await workspace_manager.lock_workspace(
            valid_project_id, valid_agent_id, timeout=60
        )
        assert acquired is True

        locked = await workspace_manager.get_workspace(valid_project_id)
        assert locked.state == WorkspaceState.LOCKED
        assert locked.lock_holder == valid_agent_id

    @pytest.mark.asyncio
    async def test_lock_already_locked(self, workspace_manager, valid_project_id):
        """A second agent cannot steal an active lock."""
        await workspace_manager.create_workspace(valid_project_id)
        await workspace_manager.lock_workspace(valid_project_id, "agent-1", timeout=60)

        with pytest.raises(WorkspaceLockedError):
            await workspace_manager.lock_workspace(
                valid_project_id, "agent-2", timeout=60
            )

    @pytest.mark.asyncio
    async def test_lock_same_holder(
        self, workspace_manager, valid_project_id, valid_agent_id
    ):
        """The current holder may re-lock to extend the lease."""
        await workspace_manager.create_workspace(valid_project_id)
        await workspace_manager.lock_workspace(
            valid_project_id, valid_agent_id, timeout=60
        )

        extended = await workspace_manager.lock_workspace(
            valid_project_id, valid_agent_id, timeout=120
        )
        assert extended is True

    @pytest.mark.asyncio
    async def test_unlock_workspace(
        self, workspace_manager, valid_project_id, valid_agent_id
    ):
        """Unlocking returns the workspace to READY with no holder."""
        await workspace_manager.create_workspace(valid_project_id)
        await workspace_manager.lock_workspace(valid_project_id, valid_agent_id)

        released = await workspace_manager.unlock_workspace(
            valid_project_id, valid_agent_id
        )
        assert released is True

        unlocked = await workspace_manager.get_workspace(valid_project_id)
        assert unlocked.state == WorkspaceState.READY
        assert unlocked.lock_holder is None

    @pytest.mark.asyncio
    async def test_unlock_wrong_holder(self, workspace_manager, valid_project_id):
        """Only the holder may unlock without force."""
        await workspace_manager.create_workspace(valid_project_id)
        await workspace_manager.lock_workspace(valid_project_id, "agent-1")

        with pytest.raises(WorkspaceLockedError):
            await workspace_manager.unlock_workspace(valid_project_id, "agent-2")

    @pytest.mark.asyncio
    async def test_force_unlock(self, workspace_manager, valid_project_id):
        """force=True bypasses the holder check."""
        await workspace_manager.create_workspace(valid_project_id)
        await workspace_manager.lock_workspace(valid_project_id, "agent-1")

        released = await workspace_manager.unlock_workspace(
            valid_project_id, "admin", force=True
        )
        assert released is True

    @pytest.mark.asyncio
    async def test_lock_nonexistent_workspace(self, workspace_manager, valid_agent_id):
        """Locking an unknown workspace raises WorkspaceNotFoundError."""
        with pytest.raises(WorkspaceNotFoundError):
            await workspace_manager.lock_workspace("nonexistent", valid_agent_id)
|
||||
|
||||
|
||||
class TestWorkspaceLockContextManager:
    """Tests for the WorkspaceLock async context manager."""

    @pytest.mark.asyncio
    async def test_lock_context_manager(
        self, workspace_manager, valid_project_id, valid_agent_id
    ):
        """Entering the context locks the workspace; exiting releases it."""
        await workspace_manager.create_workspace(valid_project_id)

        # No `as` target: the previous `as lock` bound a variable that was
        # never used.
        async with WorkspaceLock(
            workspace_manager, valid_project_id, valid_agent_id
        ):
            workspace = await workspace_manager.get_workspace(valid_project_id)
            assert workspace.state == WorkspaceState.LOCKED

        # After exiting the context the lock must be released.
        workspace = await workspace_manager.get_workspace(valid_project_id)
        assert workspace.lock_holder is None

    @pytest.mark.asyncio
    async def test_lock_context_manager_error(
        self, workspace_manager, valid_project_id, valid_agent_id
    ):
        """The lock is released even when the body raises."""
        await workspace_manager.create_workspace(valid_project_id)

        # pytest.raises (instead of try/except ValueError: pass) also fails
        # the test if the exception is unexpectedly swallowed by the lock.
        with pytest.raises(ValueError, match="Test error"):
            async with WorkspaceLock(
                workspace_manager, valid_project_id, valid_agent_id
            ):
                raise ValueError("Test error")

        workspace = await workspace_manager.get_workspace(valid_project_id)
        assert workspace.lock_holder is None
|
||||
|
||||
|
||||
class TestWorkspaceMetadata:
    """Metadata bookkeeping: access times and current branch."""

    @pytest.mark.asyncio
    async def test_touch_workspace(self, workspace_manager, valid_project_id):
        """touch_workspace never moves last_accessed backwards."""
        created = await workspace_manager.create_workspace(valid_project_id)
        before = created.last_accessed

        await workspace_manager.touch_workspace(valid_project_id)

        refreshed = await workspace_manager.get_workspace(valid_project_id)
        assert refreshed.last_accessed >= before

    @pytest.mark.asyncio
    async def test_update_workspace_branch(self, workspace_manager, valid_project_id):
        """The tracked current branch can be updated."""
        await workspace_manager.create_workspace(valid_project_id)

        await workspace_manager.update_workspace_branch(
            valid_project_id, "feature-branch"
        )

        refreshed = await workspace_manager.get_workspace(valid_project_id)
        assert refreshed.current_branch == "feature-branch"
|
||||
|
||||
|
||||
class TestWorkspaceSize:
    """Size accounting for workspaces."""

    @pytest.mark.asyncio
    async def test_check_size_within_limit(self, workspace_manager, valid_project_id):
        """A freshly created (empty) workspace is within its size limit."""
        await workspace_manager.create_workspace(valid_project_id)

        # An empty workspace must not trip the limit.
        assert await workspace_manager.check_size_limit(valid_project_id) is True

    @pytest.mark.asyncio
    async def test_get_total_size(self, workspace_manager, valid_project_id):
        """Total size reflects files written into a workspace."""
        created = await workspace_manager.create_workspace(valid_project_id)

        (Path(created.path) / "content.txt").write_text("x" * 1000)

        assert await workspace_manager.get_total_size() >= 1000
|
||||
|
||||
|
||||
class TestFileLockManager:
    """File-based lock primitives."""

    def test_acquire_lock(self, temp_dir):
        """A free key can be acquired."""
        locks = FileLockManager(temp_dir / "locks")

        assert locks.acquire("test-key") is True

        locks.release("test-key")  # leave no lock behind

    def test_release_lock(self, temp_dir):
        """A held key can be released."""
        locks = FileLockManager(temp_dir / "locks")
        locks.acquire("test-key")

        assert locks.release("test-key") is True

    def test_is_locked(self, temp_dir):
        """is_locked tracks the acquire/release lifecycle."""
        locks = FileLockManager(temp_dir / "locks")

        assert locks.is_locked("test-key") is False

        locks.acquire("test-key")
        assert locks.is_locked("test-key") is True

        locks.release("test-key")

    def test_release_nonexistent_lock(self, temp_dir):
        """Releasing an unknown key is safe and reports False."""
        locks = FileLockManager(temp_dir / "locks")

        # Must not raise.
        assert locks.release("nonexistent") is False
|
||||
|
||||
|
||||
class TestWorkspaceCleanup:
    """Stale-workspace cleanup and guarded deletion."""

    @pytest.mark.asyncio
    async def test_cleanup_stale_workspaces(self, workspace_manager, test_settings):
        """Workspaces idle past the threshold are swept up."""
        await workspace_manager.create_workspace("stale-project")

        # Backdate the metadata so the workspace counts as stale.
        await workspace_manager._update_metadata(
            "stale-project",
            last_accessed=(datetime.now(UTC) - timedelta(days=30)).isoformat(),
        )

        assert await workspace_manager.cleanup_stale_workspaces() >= 1

    @pytest.mark.asyncio
    async def test_delete_locked_workspace_blocked(
        self, workspace_manager, valid_project_id, valid_agent_id
    ):
        """Deleting a locked workspace without force raises."""
        await workspace_manager.create_workspace(valid_project_id)
        await workspace_manager.lock_workspace(valid_project_id, valid_agent_id)

        with pytest.raises(WorkspaceLockedError):
            await workspace_manager.delete_workspace(valid_project_id)

    @pytest.mark.asyncio
    async def test_delete_locked_workspace_force(
        self, workspace_manager, valid_project_id, valid_agent_id
    ):
        """force=True deletes even a locked workspace."""
        await workspace_manager.create_workspace(valid_project_id)
        await workspace_manager.lock_workspace(valid_project_id, valid_agent_id)

        deleted = await workspace_manager.delete_workspace(valid_project_id, force=True)
        assert deleted is True
|
||||
1853
mcp-servers/git-ops/uv.lock
generated
Normal file
1853
mcp-servers/git-ops/uv.lock
generated
Normal file
File diff suppressed because it is too large
Load Diff
614
mcp-servers/git-ops/workspace.py
Normal file
614
mcp-servers/git-ops/workspace.py
Normal file
@@ -0,0 +1,614 @@
|
||||
"""
|
||||
Workspace management for Git Operations MCP Server.
|
||||
|
||||
Handles isolated workspaces for each project, including creation,
|
||||
locking, cleanup, and size management.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import shutil
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import aiofiles # type: ignore[import-untyped]
|
||||
from filelock import FileLock, Timeout
|
||||
|
||||
from config import Settings, get_settings
|
||||
from exceptions import (
|
||||
WorkspaceLockedError,
|
||||
WorkspaceNotFoundError,
|
||||
WorkspaceSizeExceededError,
|
||||
)
|
||||
from models import WorkspaceInfo, WorkspaceState
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Metadata file name
|
||||
WORKSPACE_METADATA_FILE = ".syndarix-workspace.json"
|
||||
|
||||
|
||||
class WorkspaceManager:
|
||||
"""
|
||||
Manages git workspaces for projects.
|
||||
|
||||
Each project gets an isolated workspace directory for git operations.
|
||||
Supports distributed locking via Redis or local file locks.
|
||||
"""
|
||||
|
||||
def __init__(self, settings: Settings | None = None) -> None:
|
||||
"""
|
||||
Initialize WorkspaceManager.
|
||||
|
||||
Args:
|
||||
settings: Optional settings override
|
||||
"""
|
||||
self.settings = settings or get_settings()
|
||||
self.base_path = self.settings.workspace_base_path
|
||||
self._ensure_base_path()
|
||||
|
||||
def _ensure_base_path(self) -> None:
|
||||
"""Ensure the base workspace directory exists."""
|
||||
self.base_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
def _get_workspace_path(self, project_id: str) -> Path:
|
||||
"""Get the path for a project workspace with path traversal protection."""
|
||||
# Sanitize project ID for filesystem
|
||||
safe_id = "".join(c if c.isalnum() or c in "-_" else "_" for c in project_id)
|
||||
|
||||
# Reject reserved names
|
||||
reserved_names = {".", "..", "con", "prn", "aux", "nul"}
|
||||
if safe_id.lower() in reserved_names:
|
||||
raise ValueError(f"Invalid project ID: reserved name '{project_id}'")
|
||||
|
||||
# Construct path and verify it's within base_path (prevent path traversal)
|
||||
workspace_path = (self.base_path / safe_id).resolve()
|
||||
base_resolved = self.base_path.resolve()
|
||||
|
||||
if not workspace_path.is_relative_to(base_resolved):
|
||||
raise ValueError(
|
||||
f"Invalid project ID: path traversal detected '{project_id}'"
|
||||
)
|
||||
|
||||
return workspace_path
|
||||
|
||||
def _get_lock_path(self, project_id: str) -> Path:
|
||||
"""Get the lock file path for a workspace."""
|
||||
return self._get_workspace_path(project_id) / ".lock"
|
||||
|
||||
def _get_metadata_path(self, project_id: str) -> Path:
|
||||
"""Get the metadata file path for a workspace."""
|
||||
return self._get_workspace_path(project_id) / WORKSPACE_METADATA_FILE
|
||||
|
||||
async def get_workspace(self, project_id: str) -> WorkspaceInfo | None:
|
||||
"""
|
||||
Get workspace info for a project.
|
||||
|
||||
Args:
|
||||
project_id: Project identifier
|
||||
|
||||
Returns:
|
||||
WorkspaceInfo or None if not found
|
||||
"""
|
||||
workspace_path = self._get_workspace_path(project_id)
|
||||
|
||||
if not workspace_path.exists():
|
||||
return None
|
||||
|
||||
# Load metadata
|
||||
metadata = await self._load_metadata(project_id)
|
||||
|
||||
# Calculate size
|
||||
size_bytes = await self._calculate_size(workspace_path)
|
||||
|
||||
# Check lock status
|
||||
lock_holder = None
|
||||
lock_expires = None
|
||||
if metadata:
|
||||
lock_holder = metadata.get("lock_holder")
|
||||
if metadata.get("lock_expires"):
|
||||
lock_expires = datetime.fromisoformat(metadata["lock_expires"])
|
||||
# Clear expired locks
|
||||
if lock_expires < datetime.now(UTC):
|
||||
lock_holder = None
|
||||
lock_expires = None
|
||||
|
||||
# Determine state
|
||||
state = WorkspaceState.READY
|
||||
if lock_holder:
|
||||
state = WorkspaceState.LOCKED
|
||||
|
||||
# Check if stale
|
||||
last_accessed = datetime.now(UTC)
|
||||
if metadata and metadata.get("last_accessed"):
|
||||
last_accessed = datetime.fromisoformat(metadata["last_accessed"])
|
||||
stale_threshold = datetime.now(UTC) - timedelta(
|
||||
days=self.settings.workspace_stale_days
|
||||
)
|
||||
if last_accessed < stale_threshold:
|
||||
state = WorkspaceState.STALE
|
||||
|
||||
return WorkspaceInfo(
|
||||
project_id=project_id,
|
||||
path=str(workspace_path),
|
||||
state=state,
|
||||
repo_url=metadata.get("repo_url") if metadata else None,
|
||||
current_branch=metadata.get("current_branch") if metadata else None,
|
||||
last_accessed=last_accessed,
|
||||
size_bytes=size_bytes,
|
||||
lock_holder=lock_holder,
|
||||
lock_expires=lock_expires,
|
||||
)
|
||||
|
||||
async def create_workspace(
|
||||
self,
|
||||
project_id: str,
|
||||
repo_url: str | None = None,
|
||||
) -> WorkspaceInfo:
|
||||
"""
|
||||
Create or get a workspace for a project.
|
||||
|
||||
Args:
|
||||
project_id: Project identifier
|
||||
repo_url: Optional repository URL
|
||||
|
||||
Returns:
|
||||
WorkspaceInfo for the workspace
|
||||
"""
|
||||
workspace_path = self._get_workspace_path(project_id)
|
||||
|
||||
if workspace_path.exists():
|
||||
# Workspace already exists, update metadata
|
||||
await self._update_metadata(project_id, repo_url=repo_url)
|
||||
workspace = await self.get_workspace(project_id)
|
||||
if workspace:
|
||||
return workspace
|
||||
|
||||
# Create workspace directory
|
||||
workspace_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Create initial metadata
|
||||
metadata = {
|
||||
"project_id": project_id,
|
||||
"repo_url": repo_url,
|
||||
"created_at": datetime.now(UTC).isoformat(),
|
||||
"last_accessed": datetime.now(UTC).isoformat(),
|
||||
}
|
||||
await self._save_metadata(project_id, metadata)
|
||||
|
||||
return WorkspaceInfo(
|
||||
project_id=project_id,
|
||||
path=str(workspace_path),
|
||||
state=WorkspaceState.INITIALIZING,
|
||||
repo_url=repo_url,
|
||||
last_accessed=datetime.now(UTC),
|
||||
size_bytes=0,
|
||||
)
|
||||
|
||||
async def delete_workspace(self, project_id: str, force: bool = False) -> bool:
|
||||
"""
|
||||
Delete a workspace.
|
||||
|
||||
Args:
|
||||
project_id: Project identifier
|
||||
force: Force delete even if locked
|
||||
|
||||
Returns:
|
||||
True if deleted
|
||||
"""
|
||||
workspace_path = self._get_workspace_path(project_id)
|
||||
|
||||
if not workspace_path.exists():
|
||||
return True
|
||||
|
||||
# Check lock
|
||||
if not force:
|
||||
workspace = await self.get_workspace(project_id)
|
||||
if workspace and workspace.state == WorkspaceState.LOCKED:
|
||||
raise WorkspaceLockedError(project_id, workspace.lock_holder)
|
||||
|
||||
try:
|
||||
# Use shutil.rmtree for robust deletion
|
||||
shutil.rmtree(workspace_path)
|
||||
logger.info(f"Deleted workspace for project: {project_id}")
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to delete workspace {project_id}: {e}")
|
||||
return False
|
||||
|
||||
async def lock_workspace(
|
||||
self,
|
||||
project_id: str,
|
||||
holder: str,
|
||||
timeout: int | None = None,
|
||||
) -> bool:
|
||||
"""
|
||||
Acquire a lock on a workspace.
|
||||
|
||||
Args:
|
||||
project_id: Project identifier
|
||||
holder: Lock holder identifier (agent_id)
|
||||
timeout: Lock timeout in seconds
|
||||
|
||||
Returns:
|
||||
True if lock acquired
|
||||
|
||||
Raises:
|
||||
WorkspaceNotFoundError: If workspace doesn't exist
|
||||
WorkspaceLockedError: If already locked by another
|
||||
"""
|
||||
workspace = await self.get_workspace(project_id)
|
||||
|
||||
if workspace is None:
|
||||
raise WorkspaceNotFoundError(project_id)
|
||||
|
||||
# Check if already locked by someone else
|
||||
if workspace.state == WorkspaceState.LOCKED and workspace.lock_holder != holder:
|
||||
# Check if lock expired
|
||||
if workspace.lock_expires and workspace.lock_expires > datetime.now(UTC):
|
||||
raise WorkspaceLockedError(project_id, workspace.lock_holder)
|
||||
|
||||
# Calculate lock expiry
|
||||
lock_timeout = timeout or self.settings.workspace_lock_timeout
|
||||
lock_expires = datetime.now(UTC) + timedelta(seconds=lock_timeout)
|
||||
|
||||
# Update metadata with lock info
|
||||
await self._update_metadata(
|
||||
project_id,
|
||||
lock_holder=holder,
|
||||
lock_expires=lock_expires.isoformat(),
|
||||
)
|
||||
|
||||
logger.info(f"Workspace {project_id} locked by {holder}")
|
||||
return True
|
||||
|
||||
async def unlock_workspace(
|
||||
self,
|
||||
project_id: str,
|
||||
holder: str,
|
||||
force: bool = False,
|
||||
) -> bool:
|
||||
"""
|
||||
Release a lock on a workspace.
|
||||
|
||||
Args:
|
||||
project_id: Project identifier
|
||||
holder: Lock holder identifier
|
||||
force: Force unlock regardless of holder
|
||||
|
||||
Returns:
|
||||
True if unlocked
|
||||
"""
|
||||
workspace = await self.get_workspace(project_id)
|
||||
|
||||
if workspace is None:
|
||||
raise WorkspaceNotFoundError(project_id)
|
||||
|
||||
# Verify holder
|
||||
if not force and workspace.lock_holder and workspace.lock_holder != holder:
|
||||
raise WorkspaceLockedError(project_id, workspace.lock_holder)
|
||||
|
||||
# Clear lock
|
||||
await self._update_metadata(
|
||||
project_id,
|
||||
lock_holder=None,
|
||||
lock_expires=None,
|
||||
)
|
||||
|
||||
logger.info(f"Workspace {project_id} unlocked by {holder}")
|
||||
return True
|
||||
|
||||
async def touch_workspace(self, project_id: str) -> None:
|
||||
"""
|
||||
Update last accessed time for a workspace.
|
||||
|
||||
Args:
|
||||
project_id: Project identifier
|
||||
"""
|
||||
await self._update_metadata(
|
||||
project_id,
|
||||
last_accessed=datetime.now(UTC).isoformat(),
|
||||
)
|
||||
|
||||
async def update_workspace_branch(
|
||||
self,
|
||||
project_id: str,
|
||||
branch: str,
|
||||
) -> None:
|
||||
"""
|
||||
Update the current branch in workspace metadata.
|
||||
|
||||
Args:
|
||||
project_id: Project identifier
|
||||
branch: Current branch name
|
||||
"""
|
||||
await self._update_metadata(
|
||||
project_id,
|
||||
current_branch=branch,
|
||||
last_accessed=datetime.now(UTC).isoformat(),
|
||||
)
|
||||
|
||||
async def check_size_limit(self, project_id: str) -> bool:
|
||||
"""
|
||||
Check if workspace exceeds size limit.
|
||||
|
||||
Args:
|
||||
project_id: Project identifier
|
||||
|
||||
Returns:
|
||||
True if within limits
|
||||
|
||||
Raises:
|
||||
WorkspaceSizeExceededError: If size exceeds limit
|
||||
"""
|
||||
workspace_path = self._get_workspace_path(project_id)
|
||||
|
||||
if not workspace_path.exists():
|
||||
return True
|
||||
|
||||
size_bytes = await self._calculate_size(workspace_path)
|
||||
size_gb = size_bytes / (1024**3)
|
||||
max_size_gb = self.settings.workspace_max_size_gb
|
||||
|
||||
if size_gb > max_size_gb:
|
||||
raise WorkspaceSizeExceededError(project_id, size_gb, max_size_gb)
|
||||
|
||||
return True
|
||||
|
||||
async def list_workspaces(
|
||||
self,
|
||||
include_stale: bool = False,
|
||||
) -> list[WorkspaceInfo]:
|
||||
"""
|
||||
List all workspaces.
|
||||
|
||||
Args:
|
||||
include_stale: Include stale workspaces
|
||||
|
||||
Returns:
|
||||
List of WorkspaceInfo
|
||||
"""
|
||||
workspaces: list[WorkspaceInfo] = []
|
||||
|
||||
if not self.base_path.exists():
|
||||
return workspaces
|
||||
|
||||
for entry in self.base_path.iterdir():
|
||||
if entry.is_dir() and not entry.name.startswith("."):
|
||||
# Extract project_id from directory name
|
||||
workspace = await self.get_workspace(entry.name)
|
||||
if workspace:
|
||||
if not include_stale and workspace.state == WorkspaceState.STALE:
|
||||
continue
|
||||
workspaces.append(workspace)
|
||||
|
||||
return workspaces
|
||||
|
||||
async def cleanup_stale_workspaces(self) -> int:
|
||||
"""
|
||||
Clean up stale workspaces.
|
||||
|
||||
Returns:
|
||||
Number of workspaces cleaned up
|
||||
"""
|
||||
cleaned = 0
|
||||
workspaces = await self.list_workspaces(include_stale=True)
|
||||
|
||||
for workspace in workspaces:
|
||||
if workspace.state == WorkspaceState.STALE:
|
||||
try:
|
||||
await self.delete_workspace(workspace.project_id, force=True)
|
||||
cleaned += 1
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Failed to cleanup stale workspace {workspace.project_id}: {e}"
|
||||
)
|
||||
|
||||
if cleaned > 0:
|
||||
logger.info(f"Cleaned up {cleaned} stale workspaces")
|
||||
|
||||
return cleaned
|
||||
|
||||
async def get_total_size(self) -> int:
|
||||
"""
|
||||
Get total size of all workspaces.
|
||||
|
||||
Returns:
|
||||
Total size in bytes
|
||||
"""
|
||||
return await self._calculate_size(self.base_path)
|
||||
|
||||
# Private methods
|
||||
|
||||
async def _load_metadata(self, project_id: str) -> dict[str, Any] | None:
|
||||
"""Load workspace metadata from file."""
|
||||
metadata_path = self._get_metadata_path(project_id)
|
||||
|
||||
if not metadata_path.exists():
|
||||
return None
|
||||
|
||||
try:
|
||||
async with aiofiles.open(metadata_path) as f:
|
||||
content = await f.read()
|
||||
return json.loads(content)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to load metadata for {project_id}: {e}")
|
||||
return None
|
||||
|
||||
async def _save_metadata(
|
||||
self,
|
||||
project_id: str,
|
||||
metadata: dict[str, Any],
|
||||
) -> None:
|
||||
"""Save workspace metadata to file."""
|
||||
metadata_path = self._get_metadata_path(project_id)
|
||||
|
||||
# Ensure parent directory exists
|
||||
metadata_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
try:
|
||||
async with aiofiles.open(metadata_path, "w") as f:
|
||||
await f.write(json.dumps(metadata, indent=2))
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to save metadata for {project_id}: {e}")
|
||||
|
||||
async def _update_metadata(
|
||||
self,
|
||||
project_id: str,
|
||||
**updates: Any,
|
||||
) -> None:
|
||||
"""Update specific fields in workspace metadata."""
|
||||
metadata = await self._load_metadata(project_id) or {}
|
||||
|
||||
# Handle None values (to clear fields)
|
||||
for key, value in updates.items():
|
||||
if value is None:
|
||||
metadata.pop(key, None)
|
||||
else:
|
||||
metadata[key] = value
|
||||
|
||||
await self._save_metadata(project_id, metadata)
|
||||
|
||||
async def _calculate_size(self, path: Path) -> int:
|
||||
"""Calculate total size of a directory."""
|
||||
|
||||
def _calc_size() -> int:
|
||||
total = 0
|
||||
try:
|
||||
for entry in path.rglob("*"):
|
||||
if entry.is_file():
|
||||
try:
|
||||
total += entry.stat().st_size
|
||||
except OSError:
|
||||
pass
|
||||
except Exception:
|
||||
pass
|
||||
return total
|
||||
|
||||
# Run in executor for async compatibility
|
||||
loop = asyncio.get_event_loop()
|
||||
return await loop.run_in_executor(None, _calc_size)
|
||||
|
||||
|
||||
class WorkspaceLock:
    """
    Async context manager that holds a workspace lock for a block of code.

    Entering the context acquires the lock through the manager; leaving it
    always attempts to release, logging (but never propagating) failures so
    the body's own exception, if any, is preserved.
    """

    def __init__(
        self,
        manager: WorkspaceManager,
        project_id: str,
        holder: str,
        timeout: int | None = None,
    ) -> None:
        """
        Initialize workspace lock.

        Args:
            manager: WorkspaceManager instance
            project_id: Project identifier
            holder: Lock holder identifier
            timeout: Lock timeout in seconds
        """
        self.manager = manager
        self.project_id = project_id
        self.holder = holder
        self.timeout = timeout
        # Tracks whether we actually got the lock, so __aexit__ never
        # releases a lock this context did not acquire.
        self._acquired = False

    async def __aenter__(self) -> "WorkspaceLock":
        """Acquire the workspace lock on entry."""
        await self.manager.lock_workspace(self.project_id, self.holder, self.timeout)
        self._acquired = True
        return self

    async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
        """Release the workspace lock on exit (best-effort)."""
        if not self._acquired:
            return
        try:
            await self.manager.unlock_workspace(self.project_id, self.holder)
        except Exception as e:
            # A failed unlock must not mask the result of the with-block.
            logger.warning(f"Failed to release lock for {self.project_id}: {e}")
|
||||
|
||||
|
||||
class FileLockManager:
    """
    File-based locking for single-instance deployments.

    Uses filelock for local locking when Redis is not available. One lock
    file per key is created lazily under lock_dir and cached for reuse.
    """

    def __init__(self, lock_dir: Path) -> None:
        """
        Initialize file lock manager.

        Args:
            lock_dir: Directory for lock files
        """
        self.lock_dir = lock_dir
        self.lock_dir.mkdir(parents=True, exist_ok=True)
        # Cache of FileLock objects keyed by lock name.
        self._locks: dict[str, FileLock] = {}

    def _get_lock(self, key: str) -> FileLock:
        """Return the cached lock for key, creating it on first use."""
        lock = self._locks.get(key)
        if lock is None:
            lock = FileLock(self.lock_dir / f"{key}.lock")
            self._locks[key] = lock
        return lock

    def acquire(
        self,
        key: str,
        timeout: float = 10.0,
    ) -> bool:
        """
        Acquire a lock.

        Args:
            key: Lock key
            timeout: Timeout in seconds

        Returns:
            True if acquired, False if the wait timed out
        """
        try:
            self._get_lock(key).acquire(timeout=timeout)
        except Timeout:
            return False
        return True

    def release(self, key: str) -> bool:
        """
        Release a lock.

        Args:
            key: Lock key

        Returns:
            True if released, False if unknown key or release failed
        """
        lock = self._locks.get(key)
        if lock is None:
            return False
        try:
            lock.release()
        except Exception:
            return False
        return True

    def is_locked(self, key: str) -> bool:
        """Check if a key is locked."""
        return self._get_lock(key).is_locked
|
||||
@@ -1,4 +1,8 @@
|
||||
.PHONY: help install install-dev lint lint-fix format type-check test test-cov validate clean run
|
||||
.PHONY: help install install-dev lint lint-fix format format-check type-check test test-cov validate clean run
|
||||
|
||||
# Ensure commands in this project don't inherit an external Python virtualenv
|
||||
# (prevents uv warnings about mismatched VIRTUAL_ENV when running from repo root)
|
||||
unexport VIRTUAL_ENV
|
||||
|
||||
# Default target
|
||||
help:
|
||||
@@ -12,6 +16,7 @@ help:
|
||||
@echo " make lint - Run Ruff linter"
|
||||
@echo " make lint-fix - Run Ruff linter with auto-fix"
|
||||
@echo " make format - Format code with Ruff"
|
||||
@echo " make format-check - Check if code is formatted"
|
||||
@echo " make type-check - Run mypy type checker"
|
||||
@echo ""
|
||||
@echo "Testing:"
|
||||
@@ -19,7 +24,7 @@ help:
|
||||
@echo " make test-cov - Run pytest with coverage"
|
||||
@echo ""
|
||||
@echo "All-in-one:"
|
||||
@echo " make validate - Run lint, type-check, and tests"
|
||||
@echo " make validate - Run all checks (lint + format + types)"
|
||||
@echo ""
|
||||
@echo "Running:"
|
||||
@echo " make run - Run the server locally"
|
||||
@@ -49,6 +54,10 @@ format:
|
||||
@echo "Formatting code..."
|
||||
@uv run ruff format .
|
||||
|
||||
format-check:
|
||||
@echo "Checking code formatting..."
|
||||
@uv run ruff format --check .
|
||||
|
||||
type-check:
|
||||
@echo "Running mypy..."
|
||||
@uv run mypy . --ignore-missing-imports
|
||||
@@ -62,8 +71,9 @@ test-cov:
|
||||
@echo "Running tests with coverage..."
|
||||
@uv run pytest tests/ -v --cov=. --cov-report=term-missing --cov-report=html
|
||||
|
||||
|
||||
# All-in-one validation
|
||||
validate: lint type-check test
|
||||
validate: lint format-check type-check
|
||||
@echo "All validations passed!"
|
||||
|
||||
# Running
|
||||
|
||||
@@ -184,7 +184,12 @@ class ChunkerFactory:
|
||||
if file_type:
|
||||
if file_type == FileType.MARKDOWN:
|
||||
return self._get_markdown_chunker()
|
||||
elif file_type in (FileType.TEXT, FileType.JSON, FileType.YAML, FileType.TOML):
|
||||
elif file_type in (
|
||||
FileType.TEXT,
|
||||
FileType.JSON,
|
||||
FileType.YAML,
|
||||
FileType.TOML,
|
||||
):
|
||||
return self._get_text_chunker()
|
||||
else:
|
||||
# Code files
|
||||
@@ -193,7 +198,9 @@ class ChunkerFactory:
|
||||
# Default to text chunker
|
||||
return self._get_text_chunker()
|
||||
|
||||
def get_chunker_for_path(self, source_path: str) -> tuple[BaseChunker, FileType | None]:
|
||||
def get_chunker_for_path(
|
||||
self, source_path: str
|
||||
) -> tuple[BaseChunker, FileType | None]:
|
||||
"""
|
||||
Get chunker based on file path extension.
|
||||
|
||||
|
||||
@@ -151,7 +151,7 @@ class CodeChunker(BaseChunker):
|
||||
for struct_type, pattern in patterns.items():
|
||||
for match in pattern.finditer(content):
|
||||
# Convert character position to line number
|
||||
line_num = content[:match.start()].count("\n")
|
||||
line_num = content[: match.start()].count("\n")
|
||||
boundaries.append((line_num, struct_type))
|
||||
|
||||
if not boundaries:
|
||||
|
||||
@@ -69,9 +69,7 @@ class MarkdownChunker(BaseChunker):
|
||||
|
||||
if not sections:
|
||||
# No headings, chunk as plain text
|
||||
return self._chunk_text_block(
|
||||
content, source_path, file_type, metadata, []
|
||||
)
|
||||
return self._chunk_text_block(content, source_path, file_type, metadata, [])
|
||||
|
||||
chunks: list[Chunk] = []
|
||||
heading_stack: list[tuple[int, str]] = [] # (level, text)
|
||||
@@ -292,7 +290,10 @@ class MarkdownChunker(BaseChunker):
|
||||
)
|
||||
|
||||
# Overlap: include last paragraph if it fits
|
||||
if current_content and self.count_tokens(current_content[-1]) <= self.chunk_overlap:
|
||||
if (
|
||||
current_content
|
||||
and self.count_tokens(current_content[-1]) <= self.chunk_overlap
|
||||
):
|
||||
current_content = [current_content[-1]]
|
||||
current_tokens = self.count_tokens(current_content[-1])
|
||||
else:
|
||||
@@ -341,12 +342,14 @@ class MarkdownChunker(BaseChunker):
|
||||
# Start of code block - save previous paragraph
|
||||
if current_para and any(p.strip() for p in current_para):
|
||||
para_content = "\n".join(current_para)
|
||||
paragraphs.append({
|
||||
"content": para_content,
|
||||
"tokens": self.count_tokens(para_content),
|
||||
"start_line": para_start,
|
||||
"end_line": i - 1,
|
||||
})
|
||||
paragraphs.append(
|
||||
{
|
||||
"content": para_content,
|
||||
"tokens": self.count_tokens(para_content),
|
||||
"start_line": para_start,
|
||||
"end_line": i - 1,
|
||||
}
|
||||
)
|
||||
current_para = [line]
|
||||
para_start = i
|
||||
in_code_block = True
|
||||
@@ -360,12 +363,14 @@ class MarkdownChunker(BaseChunker):
|
||||
if not line.strip():
|
||||
if current_para and any(p.strip() for p in current_para):
|
||||
para_content = "\n".join(current_para)
|
||||
paragraphs.append({
|
||||
"content": para_content,
|
||||
"tokens": self.count_tokens(para_content),
|
||||
"start_line": para_start,
|
||||
"end_line": i - 1,
|
||||
})
|
||||
paragraphs.append(
|
||||
{
|
||||
"content": para_content,
|
||||
"tokens": self.count_tokens(para_content),
|
||||
"start_line": para_start,
|
||||
"end_line": i - 1,
|
||||
}
|
||||
)
|
||||
current_para = []
|
||||
para_start = i + 1
|
||||
else:
|
||||
@@ -376,12 +381,14 @@ class MarkdownChunker(BaseChunker):
|
||||
# Final paragraph
|
||||
if current_para and any(p.strip() for p in current_para):
|
||||
para_content = "\n".join(current_para)
|
||||
paragraphs.append({
|
||||
"content": para_content,
|
||||
"tokens": self.count_tokens(para_content),
|
||||
"start_line": para_start,
|
||||
"end_line": len(lines) - 1,
|
||||
})
|
||||
paragraphs.append(
|
||||
{
|
||||
"content": para_content,
|
||||
"tokens": self.count_tokens(para_content),
|
||||
"start_line": para_start,
|
||||
"end_line": len(lines) - 1,
|
||||
}
|
||||
)
|
||||
|
||||
return paragraphs
|
||||
|
||||
@@ -448,7 +455,10 @@ class MarkdownChunker(BaseChunker):
|
||||
)
|
||||
|
||||
# Overlap with last sentence
|
||||
if current_content and self.count_tokens(current_content[-1]) <= self.chunk_overlap:
|
||||
if (
|
||||
current_content
|
||||
and self.count_tokens(current_content[-1]) <= self.chunk_overlap
|
||||
):
|
||||
current_content = [current_content[-1]]
|
||||
current_tokens = self.count_tokens(current_content[-1])
|
||||
else:
|
||||
|
||||
@@ -79,9 +79,7 @@ class TextChunker(BaseChunker):
|
||||
)
|
||||
|
||||
# Fall back to sentence-based chunking
|
||||
return self._chunk_by_sentences(
|
||||
content, source_path, file_type, metadata
|
||||
)
|
||||
return self._chunk_by_sentences(content, source_path, file_type, metadata)
|
||||
|
||||
def _split_paragraphs(self, content: str) -> list[dict[str, Any]]:
|
||||
"""Split content into paragraphs."""
|
||||
@@ -97,12 +95,14 @@ class TextChunker(BaseChunker):
|
||||
continue
|
||||
|
||||
para_lines = para.count("\n") + 1
|
||||
paragraphs.append({
|
||||
"content": para,
|
||||
"tokens": self.count_tokens(para),
|
||||
"start_line": line_num,
|
||||
"end_line": line_num + para_lines - 1,
|
||||
})
|
||||
paragraphs.append(
|
||||
{
|
||||
"content": para,
|
||||
"tokens": self.count_tokens(para),
|
||||
"start_line": line_num,
|
||||
"end_line": line_num + para_lines - 1,
|
||||
}
|
||||
)
|
||||
line_num += para_lines + 1 # +1 for blank line between paragraphs
|
||||
|
||||
return paragraphs
|
||||
@@ -172,7 +172,10 @@ class TextChunker(BaseChunker):
|
||||
|
||||
# Overlap: keep last paragraph if small enough
|
||||
overlap_para = None
|
||||
if current_paras and self.count_tokens(current_paras[-1]) <= self.chunk_overlap:
|
||||
if (
|
||||
current_paras
|
||||
and self.count_tokens(current_paras[-1]) <= self.chunk_overlap
|
||||
):
|
||||
overlap_para = current_paras[-1]
|
||||
|
||||
current_paras = [overlap_para] if overlap_para else []
|
||||
@@ -266,7 +269,10 @@ class TextChunker(BaseChunker):
|
||||
|
||||
# Overlap: keep last sentence if small enough
|
||||
overlap = None
|
||||
if current_sentences and self.count_tokens(current_sentences[-1]) <= self.chunk_overlap:
|
||||
if (
|
||||
current_sentences
|
||||
and self.count_tokens(current_sentences[-1]) <= self.chunk_overlap
|
||||
):
|
||||
overlap = current_sentences[-1]
|
||||
|
||||
current_sentences = [overlap] if overlap else []
|
||||
@@ -317,14 +323,10 @@ class TextChunker(BaseChunker):
|
||||
sentences = self._split_sentences(text)
|
||||
|
||||
if len(sentences) > 1:
|
||||
return self._chunk_by_sentences(
|
||||
text, source_path, file_type, metadata
|
||||
)
|
||||
return self._chunk_by_sentences(text, source_path, file_type, metadata)
|
||||
|
||||
# Fall back to word-based splitting
|
||||
return self._chunk_by_words(
|
||||
text, source_path, file_type, metadata, base_line
|
||||
)
|
||||
return self._chunk_by_words(text, source_path, file_type, metadata, base_line)
|
||||
|
||||
def _chunk_by_words(
|
||||
self,
|
||||
|
||||
@@ -328,14 +328,18 @@ class CollectionManager:
|
||||
"source_path": chunk.source_path or source_path,
|
||||
"start_line": chunk.start_line,
|
||||
"end_line": chunk.end_line,
|
||||
"file_type": effective_file_type.value if (effective_file_type := chunk.file_type or file_type) else None,
|
||||
"file_type": effective_file_type.value
|
||||
if (effective_file_type := chunk.file_type or file_type)
|
||||
else None,
|
||||
}
|
||||
embeddings_data.append((
|
||||
chunk.content,
|
||||
embedding,
|
||||
chunk.chunk_type,
|
||||
chunk_metadata,
|
||||
))
|
||||
embeddings_data.append(
|
||||
(
|
||||
chunk.content,
|
||||
embedding,
|
||||
chunk.chunk_type,
|
||||
chunk_metadata,
|
||||
)
|
||||
)
|
||||
|
||||
# Atomically replace old embeddings with new ones
|
||||
_, chunk_ids = await self.database.replace_source_embeddings(
|
||||
|
||||
@@ -214,9 +214,7 @@ class EmbeddingGenerator:
|
||||
return cached
|
||||
|
||||
# Generate via LLM Gateway
|
||||
embeddings = await self._call_llm_gateway(
|
||||
[text], project_id, agent_id
|
||||
)
|
||||
embeddings = await self._call_llm_gateway([text], project_id, agent_id)
|
||||
|
||||
if not embeddings:
|
||||
raise EmbeddingGenerationError(
|
||||
@@ -277,9 +275,7 @@ class EmbeddingGenerator:
|
||||
|
||||
for i in range(0, len(texts_to_embed), batch_size):
|
||||
batch = texts_to_embed[i : i + batch_size]
|
||||
batch_embeddings = await self._call_llm_gateway(
|
||||
batch, project_id, agent_id
|
||||
)
|
||||
batch_embeddings = await self._call_llm_gateway(batch, project_id, agent_id)
|
||||
new_embeddings.extend(batch_embeddings)
|
||||
|
||||
# Validate dimensions
|
||||
|
||||
@@ -149,12 +149,8 @@ class IngestRequest(BaseModel):
|
||||
source_path: str | None = Field(
|
||||
default=None, description="Source file path for reference"
|
||||
)
|
||||
collection: str = Field(
|
||||
default="default", description="Collection to store in"
|
||||
)
|
||||
chunk_type: ChunkType = Field(
|
||||
default=ChunkType.TEXT, description="Type of content"
|
||||
)
|
||||
collection: str = Field(default="default", description="Collection to store in")
|
||||
chunk_type: ChunkType = Field(default=ChunkType.TEXT, description="Type of content")
|
||||
file_type: FileType | None = Field(
|
||||
default=None, description="File type for code chunking"
|
||||
)
|
||||
@@ -255,12 +251,8 @@ class DeleteRequest(BaseModel):
|
||||
|
||||
project_id: str = Field(..., description="Project ID for scoping")
|
||||
agent_id: str = Field(..., description="Agent ID making the request")
|
||||
source_path: str | None = Field(
|
||||
default=None, description="Delete by source path"
|
||||
)
|
||||
collection: str | None = Field(
|
||||
default=None, description="Delete entire collection"
|
||||
)
|
||||
source_path: str | None = Field(default=None, description="Delete by source path")
|
||||
collection: str | None = Field(default=None, description="Delete entire collection")
|
||||
chunk_ids: list[str] | None = Field(
|
||||
default=None, description="Delete specific chunks"
|
||||
)
|
||||
|
||||
@@ -145,8 +145,7 @@ class SearchEngine:
|
||||
|
||||
# Filter by threshold (keyword search scores are normalized)
|
||||
filtered = [
|
||||
(emb, score) for emb, score in results
|
||||
if score >= request.threshold
|
||||
(emb, score) for emb, score in results if score >= request.threshold
|
||||
]
|
||||
|
||||
return [
|
||||
@@ -204,10 +203,9 @@ class SearchEngine:
|
||||
)
|
||||
|
||||
# Filter by threshold and limit
|
||||
filtered = [
|
||||
result for result in fused
|
||||
if result.score >= request.threshold
|
||||
][:request.limit]
|
||||
filtered = [result for result in fused if result.score >= request.threshold][
|
||||
: request.limit
|
||||
]
|
||||
|
||||
return filtered
|
||||
|
||||
|
||||
@@ -93,6 +93,7 @@ def _validate_source_path(value: str | None) -> str | None:
|
||||
|
||||
return None
|
||||
|
||||
|
||||
# Configure logging
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
@@ -213,7 +214,9 @@ async def health_check() -> dict[str, Any]:
|
||||
if response.status_code == 200:
|
||||
status["dependencies"]["llm_gateway"] = "connected"
|
||||
else:
|
||||
status["dependencies"]["llm_gateway"] = f"unhealthy (status {response.status_code})"
|
||||
status["dependencies"]["llm_gateway"] = (
|
||||
f"unhealthy (status {response.status_code})"
|
||||
)
|
||||
is_degraded = True
|
||||
else:
|
||||
status["dependencies"]["llm_gateway"] = "not initialized"
|
||||
@@ -328,7 +331,9 @@ def _get_tool_schema(func: Any) -> dict[str, Any]:
|
||||
}
|
||||
|
||||
|
||||
def _register_tool(name: str, tool_or_func: Any, description: str | None = None) -> None:
|
||||
def _register_tool(
|
||||
name: str, tool_or_func: Any, description: str | None = None
|
||||
) -> None:
|
||||
"""Register a tool in the registry.
|
||||
|
||||
Handles both raw functions and FastMCP FunctionTool objects.
|
||||
@@ -337,7 +342,11 @@ def _register_tool(name: str, tool_or_func: Any, description: str | None = None)
|
||||
if hasattr(tool_or_func, "fn"):
|
||||
func = tool_or_func.fn
|
||||
# Use FunctionTool's description if available
|
||||
if not description and hasattr(tool_or_func, "description") and tool_or_func.description:
|
||||
if (
|
||||
not description
|
||||
and hasattr(tool_or_func, "description")
|
||||
and tool_or_func.description
|
||||
):
|
||||
description = tool_or_func.description
|
||||
else:
|
||||
func = tool_or_func
|
||||
@@ -358,11 +367,13 @@ async def list_mcp_tools() -> dict[str, Any]:
|
||||
"""
|
||||
tools = []
|
||||
for name, info in _tool_registry.items():
|
||||
tools.append({
|
||||
"name": name,
|
||||
"description": info["description"],
|
||||
"inputSchema": info["schema"],
|
||||
})
|
||||
tools.append(
|
||||
{
|
||||
"name": name,
|
||||
"description": info["description"],
|
||||
"inputSchema": info["schema"],
|
||||
}
|
||||
)
|
||||
|
||||
return {"tools": tools}
|
||||
|
||||
@@ -410,7 +421,10 @@ async def mcp_rpc(request: Request) -> JSONResponse:
|
||||
status_code=400,
|
||||
content={
|
||||
"jsonrpc": "2.0",
|
||||
"error": {"code": -32600, "message": "Invalid Request: jsonrpc must be '2.0'"},
|
||||
"error": {
|
||||
"code": -32600,
|
||||
"message": "Invalid Request: jsonrpc must be '2.0'",
|
||||
},
|
||||
"id": request_id,
|
||||
},
|
||||
)
|
||||
@@ -420,7 +434,10 @@ async def mcp_rpc(request: Request) -> JSONResponse:
|
||||
status_code=400,
|
||||
content={
|
||||
"jsonrpc": "2.0",
|
||||
"error": {"code": -32600, "message": "Invalid Request: method is required"},
|
||||
"error": {
|
||||
"code": -32600,
|
||||
"message": "Invalid Request: method is required",
|
||||
},
|
||||
"id": request_id,
|
||||
},
|
||||
)
|
||||
@@ -528,11 +545,23 @@ async def search_knowledge(
|
||||
try:
|
||||
# Validate inputs
|
||||
if error := _validate_id(project_id, "project_id"):
|
||||
return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value}
|
||||
return {
|
||||
"success": False,
|
||||
"error": error,
|
||||
"code": ErrorCode.INVALID_REQUEST.value,
|
||||
}
|
||||
if error := _validate_id(agent_id, "agent_id"):
|
||||
return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value}
|
||||
return {
|
||||
"success": False,
|
||||
"error": error,
|
||||
"code": ErrorCode.INVALID_REQUEST.value,
|
||||
}
|
||||
if collection and (error := _validate_collection(collection)):
|
||||
return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value}
|
||||
return {
|
||||
"success": False,
|
||||
"error": error,
|
||||
"code": ErrorCode.INVALID_REQUEST.value,
|
||||
}
|
||||
|
||||
# Parse search type
|
||||
try:
|
||||
@@ -644,13 +673,29 @@ async def ingest_content(
|
||||
try:
|
||||
# Validate inputs
|
||||
if error := _validate_id(project_id, "project_id"):
|
||||
return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value}
|
||||
return {
|
||||
"success": False,
|
||||
"error": error,
|
||||
"code": ErrorCode.INVALID_REQUEST.value,
|
||||
}
|
||||
if error := _validate_id(agent_id, "agent_id"):
|
||||
return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value}
|
||||
return {
|
||||
"success": False,
|
||||
"error": error,
|
||||
"code": ErrorCode.INVALID_REQUEST.value,
|
||||
}
|
||||
if error := _validate_collection(collection):
|
||||
return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value}
|
||||
return {
|
||||
"success": False,
|
||||
"error": error,
|
||||
"code": ErrorCode.INVALID_REQUEST.value,
|
||||
}
|
||||
if error := _validate_source_path(source_path):
|
||||
return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value}
|
||||
return {
|
||||
"success": False,
|
||||
"error": error,
|
||||
"code": ErrorCode.INVALID_REQUEST.value,
|
||||
}
|
||||
|
||||
# Validate content size to prevent DoS
|
||||
settings = get_settings()
|
||||
@@ -750,13 +795,29 @@ async def delete_content(
|
||||
try:
|
||||
# Validate inputs
|
||||
if error := _validate_id(project_id, "project_id"):
|
||||
return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value}
|
||||
return {
|
||||
"success": False,
|
||||
"error": error,
|
||||
"code": ErrorCode.INVALID_REQUEST.value,
|
||||
}
|
||||
if error := _validate_id(agent_id, "agent_id"):
|
||||
return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value}
|
||||
return {
|
||||
"success": False,
|
||||
"error": error,
|
||||
"code": ErrorCode.INVALID_REQUEST.value,
|
||||
}
|
||||
if collection and (error := _validate_collection(collection)):
|
||||
return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value}
|
||||
return {
|
||||
"success": False,
|
||||
"error": error,
|
||||
"code": ErrorCode.INVALID_REQUEST.value,
|
||||
}
|
||||
if error := _validate_source_path(source_path):
|
||||
return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value}
|
||||
return {
|
||||
"success": False,
|
||||
"error": error,
|
||||
"code": ErrorCode.INVALID_REQUEST.value,
|
||||
}
|
||||
|
||||
request = DeleteRequest(
|
||||
project_id=project_id,
|
||||
@@ -803,9 +864,17 @@ async def list_collections(
|
||||
try:
|
||||
# Validate inputs
|
||||
if error := _validate_id(project_id, "project_id"):
|
||||
return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value}
|
||||
return {
|
||||
"success": False,
|
||||
"error": error,
|
||||
"code": ErrorCode.INVALID_REQUEST.value,
|
||||
}
|
||||
if error := _validate_id(agent_id, "agent_id"):
|
||||
return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value}
|
||||
return {
|
||||
"success": False,
|
||||
"error": error,
|
||||
"code": ErrorCode.INVALID_REQUEST.value,
|
||||
}
|
||||
|
||||
result = await _collections.list_collections(project_id) # type: ignore[union-attr]
|
||||
|
||||
@@ -856,11 +925,23 @@ async def get_collection_stats(
|
||||
try:
|
||||
# Validate inputs
|
||||
if error := _validate_id(project_id, "project_id"):
|
||||
return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value}
|
||||
return {
|
||||
"success": False,
|
||||
"error": error,
|
||||
"code": ErrorCode.INVALID_REQUEST.value,
|
||||
}
|
||||
if error := _validate_id(agent_id, "agent_id"):
|
||||
return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value}
|
||||
return {
|
||||
"success": False,
|
||||
"error": error,
|
||||
"code": ErrorCode.INVALID_REQUEST.value,
|
||||
}
|
||||
if error := _validate_collection(collection):
|
||||
return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value}
|
||||
return {
|
||||
"success": False,
|
||||
"error": error,
|
||||
"code": ErrorCode.INVALID_REQUEST.value,
|
||||
}
|
||||
|
||||
stats = await _collections.get_collection_stats(project_id, collection) # type: ignore[union-attr]
|
||||
|
||||
@@ -874,8 +955,12 @@ async def get_collection_stats(
|
||||
"avg_chunk_size": stats.avg_chunk_size,
|
||||
"chunk_types": stats.chunk_types,
|
||||
"file_types": stats.file_types,
|
||||
"oldest_chunk": stats.oldest_chunk.isoformat() if stats.oldest_chunk else None,
|
||||
"newest_chunk": stats.newest_chunk.isoformat() if stats.newest_chunk else None,
|
||||
"oldest_chunk": stats.oldest_chunk.isoformat()
|
||||
if stats.oldest_chunk
|
||||
else None,
|
||||
"newest_chunk": stats.newest_chunk.isoformat()
|
||||
if stats.newest_chunk
|
||||
else None,
|
||||
}
|
||||
|
||||
except KnowledgeBaseError as e:
|
||||
@@ -925,13 +1010,29 @@ async def update_document(
|
||||
try:
|
||||
# Validate inputs
|
||||
if error := _validate_id(project_id, "project_id"):
|
||||
return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value}
|
||||
return {
|
||||
"success": False,
|
||||
"error": error,
|
||||
"code": ErrorCode.INVALID_REQUEST.value,
|
||||
}
|
||||
if error := _validate_id(agent_id, "agent_id"):
|
||||
return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value}
|
||||
return {
|
||||
"success": False,
|
||||
"error": error,
|
||||
"code": ErrorCode.INVALID_REQUEST.value,
|
||||
}
|
||||
if error := _validate_collection(collection):
|
||||
return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value}
|
||||
return {
|
||||
"success": False,
|
||||
"error": error,
|
||||
"code": ErrorCode.INVALID_REQUEST.value,
|
||||
}
|
||||
if error := _validate_source_path(source_path):
|
||||
return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value}
|
||||
return {
|
||||
"success": False,
|
||||
"error": error,
|
||||
"code": ErrorCode.INVALID_REQUEST.value,
|
||||
}
|
||||
|
||||
# Validate content size to prevent DoS
|
||||
settings = get_settings()
|
||||
|
||||
@@ -83,7 +83,9 @@ def mock_embeddings():
|
||||
return [0.1] * 1536
|
||||
|
||||
mock_emb.generate = AsyncMock(return_value=fake_embedding())
|
||||
mock_emb.generate_batch = AsyncMock(side_effect=lambda texts, **_kwargs: [fake_embedding() for _ in texts])
|
||||
mock_emb.generate_batch = AsyncMock(
|
||||
side_effect=lambda texts, **_kwargs: [fake_embedding() for _ in texts]
|
||||
)
|
||||
|
||||
return mock_emb
|
||||
|
||||
@@ -137,7 +139,7 @@ async def async_function() -> None:
|
||||
@pytest.fixture
|
||||
def sample_markdown():
|
||||
"""Sample Markdown content for chunking tests."""
|
||||
return '''# Project Documentation
|
||||
return """# Project Documentation
|
||||
|
||||
This is the main documentation for our project.
|
||||
|
||||
@@ -182,20 +184,20 @@ The search endpoint allows you to query the knowledge base.
|
||||
## Contributing
|
||||
|
||||
We welcome contributions! Please see our contributing guide.
|
||||
'''
|
||||
"""
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_text():
|
||||
"""Sample plain text for chunking tests."""
|
||||
return '''The quick brown fox jumps over the lazy dog. This is a sample text that we use for testing the text chunking functionality. It contains multiple sentences that should be properly split into chunks.
|
||||
return """The quick brown fox jumps over the lazy dog. This is a sample text that we use for testing the text chunking functionality. It contains multiple sentences that should be properly split into chunks.
|
||||
|
||||
Each paragraph represents a logical unit of text. The chunker should try to respect paragraph boundaries when possible. This helps maintain context and readability.
|
||||
|
||||
When chunks need to be split mid-paragraph, the chunker should prefer sentence boundaries. This ensures that each chunk contains complete thoughts and is useful for retrieval.
|
||||
|
||||
The final paragraph tests edge cases. What happens with short paragraphs? Do they get merged with adjacent content? Let's find out!
|
||||
'''
|
||||
"""
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
"""Tests for chunking module."""
|
||||
|
||||
|
||||
|
||||
class TestBaseChunker:
|
||||
"""Tests for base chunker functionality."""
|
||||
|
||||
@@ -149,7 +148,7 @@ class TestMarkdownChunker:
|
||||
"""Test that chunker respects heading hierarchy."""
|
||||
from chunking.markdown import MarkdownChunker
|
||||
|
||||
markdown = '''# Main Title
|
||||
markdown = """# Main Title
|
||||
|
||||
Introduction paragraph.
|
||||
|
||||
@@ -164,7 +163,7 @@ More detailed content.
|
||||
## Section Two
|
||||
|
||||
Content for section two.
|
||||
'''
|
||||
"""
|
||||
|
||||
chunker = MarkdownChunker(
|
||||
chunk_size=200,
|
||||
@@ -188,7 +187,7 @@ Content for section two.
|
||||
"""Test handling of code blocks in markdown."""
|
||||
from chunking.markdown import MarkdownChunker
|
||||
|
||||
markdown = '''# Code Example
|
||||
markdown = """# Code Example
|
||||
|
||||
Here's some code:
|
||||
|
||||
@@ -198,7 +197,7 @@ def hello():
|
||||
```
|
||||
|
||||
End of example.
|
||||
'''
|
||||
"""
|
||||
|
||||
chunker = MarkdownChunker(
|
||||
chunk_size=500,
|
||||
@@ -256,12 +255,12 @@ class TestTextChunker:
|
||||
"""Test that chunker respects paragraph boundaries."""
|
||||
from chunking.text import TextChunker
|
||||
|
||||
text = '''First paragraph with some content.
|
||||
text = """First paragraph with some content.
|
||||
|
||||
Second paragraph with different content.
|
||||
|
||||
Third paragraph to test chunking behavior.
|
||||
'''
|
||||
"""
|
||||
|
||||
chunker = TextChunker(
|
||||
chunk_size=100,
|
||||
|
||||
@@ -67,10 +67,14 @@ class TestCollectionManager:
|
||||
assert result.embeddings_generated == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_ingest_error_handling(self, collection_manager, sample_ingest_request):
|
||||
async def test_ingest_error_handling(
|
||||
self, collection_manager, sample_ingest_request
|
||||
):
|
||||
"""Test ingest error handling."""
|
||||
# Make embedding generation fail
|
||||
collection_manager._embeddings.generate_batch.side_effect = Exception("Embedding error")
|
||||
collection_manager._embeddings.generate_batch.side_effect = Exception(
|
||||
"Embedding error"
|
||||
)
|
||||
|
||||
result = await collection_manager.ingest(sample_ingest_request)
|
||||
|
||||
@@ -182,7 +186,9 @@ class TestCollectionManager:
|
||||
)
|
||||
collection_manager._database.get_collection_stats.return_value = expected_stats
|
||||
|
||||
stats = await collection_manager.get_collection_stats("proj-123", "test-collection")
|
||||
stats = await collection_manager.get_collection_stats(
|
||||
"proj-123", "test-collection"
|
||||
)
|
||||
|
||||
assert stats.chunk_count == 100
|
||||
assert stats.unique_sources == 10
|
||||
|
||||
@@ -17,19 +17,15 @@ class TestEmbeddingGenerator:
|
||||
response.raise_for_status = MagicMock()
|
||||
response.json.return_value = {
|
||||
"result": {
|
||||
"content": [
|
||||
{
|
||||
"text": json.dumps({
|
||||
"embeddings": [[0.1] * 1536]
|
||||
})
|
||||
}
|
||||
]
|
||||
"content": [{"text": json.dumps({"embeddings": [[0.1] * 1536]})}]
|
||||
}
|
||||
}
|
||||
return response
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_generate_single_embedding(self, settings, mock_redis, mock_http_response):
|
||||
async def test_generate_single_embedding(
|
||||
self, settings, mock_redis, mock_http_response
|
||||
):
|
||||
"""Test generating a single embedding."""
|
||||
from embeddings import EmbeddingGenerator
|
||||
|
||||
@@ -67,9 +63,9 @@ class TestEmbeddingGenerator:
|
||||
"result": {
|
||||
"content": [
|
||||
{
|
||||
"text": json.dumps({
|
||||
"embeddings": [[0.1] * 1536, [0.2] * 1536, [0.3] * 1536]
|
||||
})
|
||||
"text": json.dumps(
|
||||
{"embeddings": [[0.1] * 1536, [0.2] * 1536, [0.3] * 1536]}
|
||||
)
|
||||
}
|
||||
]
|
||||
}
|
||||
@@ -166,9 +162,11 @@ class TestEmbeddingGenerator:
|
||||
"result": {
|
||||
"content": [
|
||||
{
|
||||
"text": json.dumps({
|
||||
"embeddings": [[0.1] * 768] # Wrong dimension
|
||||
})
|
||||
"text": json.dumps(
|
||||
{
|
||||
"embeddings": [[0.1] * 768] # Wrong dimension
|
||||
}
|
||||
)
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
"""Tests for exception classes."""
|
||||
|
||||
|
||||
|
||||
class TestErrorCode:
|
||||
"""Tests for ErrorCode enum."""
|
||||
|
||||
@@ -10,8 +9,13 @@ class TestErrorCode:
|
||||
from exceptions import ErrorCode
|
||||
|
||||
assert ErrorCode.UNKNOWN_ERROR.value == "KB_UNKNOWN_ERROR"
|
||||
assert ErrorCode.DATABASE_CONNECTION_ERROR.value == "KB_DATABASE_CONNECTION_ERROR"
|
||||
assert ErrorCode.EMBEDDING_GENERATION_ERROR.value == "KB_EMBEDDING_GENERATION_ERROR"
|
||||
assert (
|
||||
ErrorCode.DATABASE_CONNECTION_ERROR.value == "KB_DATABASE_CONNECTION_ERROR"
|
||||
)
|
||||
assert (
|
||||
ErrorCode.EMBEDDING_GENERATION_ERROR.value
|
||||
== "KB_EMBEDDING_GENERATION_ERROR"
|
||||
)
|
||||
assert ErrorCode.CHUNKING_ERROR.value == "KB_CHUNKING_ERROR"
|
||||
assert ErrorCode.SEARCH_ERROR.value == "KB_SEARCH_ERROR"
|
||||
assert ErrorCode.COLLECTION_NOT_FOUND.value == "KB_COLLECTION_NOT_FOUND"
|
||||
|
||||
@@ -59,7 +59,9 @@ class TestSearchEngine:
|
||||
]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_semantic_search(self, search_engine, sample_search_request, sample_db_results):
|
||||
async def test_semantic_search(
|
||||
self, search_engine, sample_search_request, sample_db_results
|
||||
):
|
||||
"""Test semantic search."""
|
||||
from models import SearchType
|
||||
|
||||
@@ -74,7 +76,9 @@ class TestSearchEngine:
|
||||
search_engine._database.semantic_search.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_keyword_search(self, search_engine, sample_search_request, sample_db_results):
|
||||
async def test_keyword_search(
|
||||
self, search_engine, sample_search_request, sample_db_results
|
||||
):
|
||||
"""Test keyword search."""
|
||||
from models import SearchType
|
||||
|
||||
@@ -88,7 +92,9 @@ class TestSearchEngine:
|
||||
search_engine._database.keyword_search.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_hybrid_search(self, search_engine, sample_search_request, sample_db_results):
|
||||
async def test_hybrid_search(
|
||||
self, search_engine, sample_search_request, sample_db_results
|
||||
):
|
||||
"""Test hybrid search."""
|
||||
from models import SearchType
|
||||
|
||||
@@ -105,7 +111,9 @@ class TestSearchEngine:
|
||||
assert len(response.results) >= 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_search_with_collection_filter(self, search_engine, sample_search_request, sample_db_results):
|
||||
async def test_search_with_collection_filter(
|
||||
self, search_engine, sample_search_request, sample_db_results
|
||||
):
|
||||
"""Test search with collection filter."""
|
||||
from models import SearchType
|
||||
|
||||
@@ -120,7 +128,9 @@ class TestSearchEngine:
|
||||
assert call_args.kwargs["collection"] == "specific-collection"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_search_with_file_type_filter(self, search_engine, sample_search_request, sample_db_results):
|
||||
async def test_search_with_file_type_filter(
|
||||
self, search_engine, sample_search_request, sample_db_results
|
||||
):
|
||||
"""Test search with file type filter."""
|
||||
from models import FileType, SearchType
|
||||
|
||||
@@ -135,7 +145,9 @@ class TestSearchEngine:
|
||||
assert call_args.kwargs["file_types"] == [FileType.PYTHON]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_search_respects_limit(self, search_engine, sample_search_request, sample_db_results):
|
||||
async def test_search_respects_limit(
|
||||
self, search_engine, sample_search_request, sample_db_results
|
||||
):
|
||||
"""Test that search respects result limit."""
|
||||
from models import SearchType
|
||||
|
||||
@@ -148,7 +160,9 @@ class TestSearchEngine:
|
||||
assert len(response.results) <= 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_search_records_time(self, search_engine, sample_search_request, sample_db_results):
|
||||
async def test_search_records_time(
|
||||
self, search_engine, sample_search_request, sample_db_results
|
||||
):
|
||||
"""Test that search records time."""
|
||||
from models import SearchType
|
||||
|
||||
@@ -203,13 +217,21 @@ class TestReciprocalRankFusion:
|
||||
from models import SearchResult
|
||||
|
||||
semantic = [
|
||||
SearchResult(id="a", content="A", score=0.9, chunk_type="code", collection="default"),
|
||||
SearchResult(id="b", content="B", score=0.8, chunk_type="code", collection="default"),
|
||||
SearchResult(
|
||||
id="a", content="A", score=0.9, chunk_type="code", collection="default"
|
||||
),
|
||||
SearchResult(
|
||||
id="b", content="B", score=0.8, chunk_type="code", collection="default"
|
||||
),
|
||||
]
|
||||
|
||||
keyword = [
|
||||
SearchResult(id="b", content="B", score=0.85, chunk_type="code", collection="default"),
|
||||
SearchResult(id="c", content="C", score=0.7, chunk_type="code", collection="default"),
|
||||
SearchResult(
|
||||
id="b", content="B", score=0.85, chunk_type="code", collection="default"
|
||||
),
|
||||
SearchResult(
|
||||
id="c", content="C", score=0.7, chunk_type="code", collection="default"
|
||||
),
|
||||
]
|
||||
|
||||
fused = search_engine._reciprocal_rank_fusion(semantic, keyword)
|
||||
@@ -230,19 +252,23 @@ class TestReciprocalRankFusion:
|
||||
|
||||
# Same results in same order
|
||||
results = [
|
||||
SearchResult(id="a", content="A", score=0.9, chunk_type="code", collection="default"),
|
||||
SearchResult(
|
||||
id="a", content="A", score=0.9, chunk_type="code", collection="default"
|
||||
),
|
||||
]
|
||||
|
||||
# High semantic weight
|
||||
fused_semantic_heavy = search_engine._reciprocal_rank_fusion(
|
||||
results, [],
|
||||
results,
|
||||
[],
|
||||
semantic_weight=0.9,
|
||||
keyword_weight=0.1,
|
||||
)
|
||||
|
||||
# High keyword weight
|
||||
fused_keyword_heavy = search_engine._reciprocal_rank_fusion(
|
||||
[], results,
|
||||
[],
|
||||
results,
|
||||
semantic_weight=0.1,
|
||||
keyword_weight=0.9,
|
||||
)
|
||||
@@ -256,12 +282,18 @@ class TestReciprocalRankFusion:
|
||||
from models import SearchResult
|
||||
|
||||
semantic = [
|
||||
SearchResult(id="a", content="A", score=0.9, chunk_type="code", collection="default"),
|
||||
SearchResult(id="b", content="B", score=0.8, chunk_type="code", collection="default"),
|
||||
SearchResult(
|
||||
id="a", content="A", score=0.9, chunk_type="code", collection="default"
|
||||
),
|
||||
SearchResult(
|
||||
id="b", content="B", score=0.8, chunk_type="code", collection="default"
|
||||
),
|
||||
]
|
||||
|
||||
keyword = [
|
||||
SearchResult(id="c", content="C", score=0.7, chunk_type="code", collection="default"),
|
||||
SearchResult(
|
||||
id="c", content="C", score=0.7, chunk_type="code", collection="default"
|
||||
),
|
||||
]
|
||||
|
||||
fused = search_engine._reciprocal_rank_fusion(semantic, keyword)
|
||||
|
||||
@@ -1,4 +1,8 @@
|
||||
.PHONY: help install install-dev lint lint-fix format type-check test test-cov validate clean run
|
||||
.PHONY: help install install-dev lint lint-fix format format-check type-check test test-cov validate clean run
|
||||
|
||||
# Ensure commands in this project don't inherit an external Python virtualenv
|
||||
# (prevents uv warnings about mismatched VIRTUAL_ENV when running from repo root)
|
||||
unexport VIRTUAL_ENV
|
||||
|
||||
# Default target
|
||||
help:
|
||||
@@ -12,6 +16,7 @@ help:
|
||||
@echo " make lint - Run Ruff linter"
|
||||
@echo " make lint-fix - Run Ruff linter with auto-fix"
|
||||
@echo " make format - Format code with Ruff"
|
||||
@echo " make format-check - Check if code is formatted"
|
||||
@echo " make type-check - Run mypy type checker"
|
||||
@echo ""
|
||||
@echo "Testing:"
|
||||
@@ -19,7 +24,7 @@ help:
|
||||
@echo " make test-cov - Run pytest with coverage"
|
||||
@echo ""
|
||||
@echo "All-in-one:"
|
||||
@echo " make validate - Run lint, type-check, and tests"
|
||||
@echo " make validate - Run all checks (lint + format + types)"
|
||||
@echo ""
|
||||
@echo "Running:"
|
||||
@echo " make run - Run the server locally"
|
||||
@@ -49,6 +54,10 @@ format:
|
||||
@echo "Formatting code..."
|
||||
@uv run ruff format .
|
||||
|
||||
format-check:
|
||||
@echo "Checking code formatting..."
|
||||
@uv run ruff format --check .
|
||||
|
||||
type-check:
|
||||
@echo "Running mypy..."
|
||||
@uv run mypy . --ignore-missing-imports
|
||||
@@ -63,7 +72,7 @@ test-cov:
|
||||
@uv run pytest tests/ -v --cov=. --cov-report=term-missing --cov-report=html
|
||||
|
||||
# All-in-one validation
|
||||
validate: lint type-check test
|
||||
validate: lint format-check type-check
|
||||
@echo "All validations passed!"
|
||||
|
||||
# Running
|
||||
|
||||
@@ -111,7 +111,10 @@ class CircuitBreaker:
|
||||
if self._state == CircuitState.OPEN:
|
||||
time_in_open = time.time() - self._stats.state_changed_at
|
||||
# Double-check state after time calculation (for thread safety)
|
||||
if time_in_open >= self.recovery_timeout and self._state == CircuitState.OPEN:
|
||||
if (
|
||||
time_in_open >= self.recovery_timeout
|
||||
and self._state == CircuitState.OPEN
|
||||
):
|
||||
self._transition_to(CircuitState.HALF_OPEN)
|
||||
logger.info(
|
||||
f"Circuit {self.name} transitioned to HALF_OPEN "
|
||||
|
||||
Reference in New Issue
Block a user