33 Commits

Author SHA1 Message Date
Felipe Cardoso
0a624a94af **test(git-ops): add comprehensive tests for server and API tools**
- Introduced extensive test coverage for FastAPI endpoints, including health check, MCP tools, and JSON-RPC operations.
- Added tests for Git operations MCP tools, including cloning, status, branching, committing, and provider detection.
- Mocked dependencies and ensured reliable test isolation with unittest.mock and pytest fixtures.
- Validated error handling, workspace management, tool execution, and type conversion functions.
2026-01-07 09:17:32 +01:00
Felipe Cardoso
011b21bf0a refactor(tests): adjust formatting for consistency and readability
- Updated line breaks and indentation across test modules to enhance clarity and maintain consistent style.
- Applied changes to workspace, provider, server, and GitWrapper-related test cases. No functional changes introduced.
2026-01-07 09:17:26 +01:00
Felipe Cardoso
76d7de5334 **feat(git-ops): enhance MCP server with Git provider updates and SSRF protection**
- Added `mcp-git-ops` service to `docker-compose.dev.yml` with health checks and configurations.
- Integrated SSRF protection in repository URL validation for enhanced security.
- Expanded `pyproject.toml` mypy settings and adjusted code to meet stricter type checking.
- Improved workspace management and GitWrapper operations with error handling refinements.
- Updated input validation, branching, and repository operations to align with new error structure.
- Shut down thread pool executor gracefully during server cleanup.
2026-01-07 09:17:00 +01:00
Felipe Cardoso
1779239c07 feat(git-ops): add GitHub provider with auto-detection
Implements GitHub API provider following the same pattern as Gitea:
- Full PR operations (create, get, list, merge, update, close)
- Branch operations via API
- Comment and label management
- Reviewer request support
- Rate limit error handling

Server enhancements:
- Auto-detect provider from repository URL (github.com vs custom Gitea)
- Initialize GitHub provider when token is configured
- Health check includes both provider statuses
- Token selection based on repo URL for clone/push operations

Refs: #110

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-06 20:55:22 +01:00
Felipe Cardoso
9dfa76aa41 feat(mcp): implement Git Operations MCP server with Gitea provider
Implements the Git Operations MCP server (Issue #58) providing:

Core features:
- GitPython wrapper for local repository operations (clone, commit, push, pull, diff, log)
- Branch management (create, delete, list, checkout)
- Workspace isolation per project with file-based locking
- Gitea provider for remote PR operations

MCP Tools (17 registered):
- clone_repository, git_status, create_branch, list_branches
- checkout, commit, push, pull, diff, log
- create_pull_request, get_pull_request, list_pull_requests
- merge_pull_request, get_workspace, lock_workspace, unlock_workspace

Technical details:
- FastMCP + FastAPI with JSON-RPC 2.0 protocol
- pydantic-settings for configuration (env prefix: GIT_OPS_)
- Comprehensive error hierarchy with structured codes
- 131 tests passing with 67% coverage
- Async operations via ThreadPoolExecutor

Closes: #105, #106, #107, #108, #109

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-06 20:48:20 +01:00
Felipe Cardoso
4ad3d20cf2 chore(agents): update sort_order values for agent types to improve logical grouping 2026-01-06 18:43:29 +01:00
Felipe Cardoso
8623eb56f5 feat(agents): add sorting by sort_order and include category & display fields in agent actions
- Implemented sorting of agent types by `sort_order` in Agents page.
- Added support for category, icon, color, sort_order, typical_tasks, and collaboration_hints fields in agent creation and update actions.
2026-01-06 18:20:04 +01:00
Felipe Cardoso
3cb6c8d13b feat(agents): implement grid/list view toggle and enhance filters
- Added grid and list view modes to AgentTypeList with user preference management.
- Enhanced filtering with category selection alongside existing search and status filters.
- Updated AgentTypeDetail with category badges and improved layout.
- Added unit tests for grid/list views and category filtering in AgentTypeList.
- Introduced `@radix-ui/react-toggle-group` for view mode toggle in AgentTypeList.
2026-01-06 18:17:46 +01:00
Felipe Cardoso
8e16e2645e test(forms): add unit tests for FormTextarea and FormSelect components
- Add comprehensive test coverage for FormTextarea and FormSelect components to validate rendering, accessibility, props forwarding, error handling, and behavior.
- Introduced function-scoped fixtures in e2e tests to ensure test isolation and address event loop issues with pytest-asyncio and SQLAlchemy.
2026-01-06 17:54:49 +01:00
Felipe Cardoso
82c3a6ba47 chore(makefiles): add format-check target and unify formatting logic
- Introduced `format-check` for verification without modification in `llm-gateway` and `knowledge-base` Makefiles.
- Updated `validate` to include `format-check`.
- Added `format-all` to root Makefile for consistent formatting across all components.
- Unexported `VIRTUAL_ENV` to prevent virtual environment warnings.
2026-01-06 17:25:21 +01:00
Felipe Cardoso
b6c38cac88 refactor(llm-gateway): adjust if-condition formatting for thread safety check
Updated line breaks and indentation for improved readability in circuit state recovery logic, ensuring consistent style.
2026-01-06 17:20:49 +01:00
Felipe Cardoso
51404216ae refactor(knowledge-base mcp server): adjust formatting for consistency and readability
Improved code formatting, line breaks, and indentation across chunking logic and multiple test modules to enhance code clarity and maintain consistent style. No functional changes made.
2026-01-06 17:20:31 +01:00
Felipe Cardoso
3f23bc3db3 refactor(migrations): replace hardcoded database URL with configurable environment variable and update command syntax to use consistent quoting style 2026-01-06 17:19:28 +01:00
Felipe Cardoso
a0ec5fa2cc test(agents): add validation tests for category and display fields
Added comprehensive unit and API tests to validate AgentType category and display fields:
- Category validation for valid, null, and invalid values
- Icon, color, and sort_order field constraints
- Typical tasks and collaboration hints handling (stripping, removing empty strings, normalization)
- New API tests for field creation, filtering, updating, and grouping
2026-01-06 17:19:21 +01:00
Felipe Cardoso
f262d08be2 test(project-events): add tests for demo configuration defaults
Added unit test cases to verify that the `demo.enabled` field is properly initialized to `false` in configurations and mock overrides.
2026-01-06 17:08:35 +01:00
Felipe Cardoso
b3f371e0a3 test(agents): add tests for AgentTypeForm enhancements
Added unit tests to cover new AgentTypeForm features:
- Category & Display fields (category select, sort order, icon, color)
- Typical Tasks management (add, remove, and prevent duplicates)
- Collaboration Hints management (add, remove, lowercase, and prevent duplicates)

This ensures thorough validation of recent form updates.
2026-01-06 17:07:21 +01:00
Felipe Cardoso
93cc37224c feat(agents): add category and display fields to AgentTypeForm
Add new "Category & Display" card in Basic Info tab with:
- Category dropdown to select agent category
- Sort order input for display ordering
- Icon text input with Lucide icon name
- Color picker with hex input and visual color selector
- Typical tasks tag input for agent capabilities
- Collaboration hints tag input for agent relationships

Updates include:
- TAB_FIELD_MAPPING with new field mappings
- State and handlers for typical_tasks and collaboration_hints
- Fix tests to use getAllByRole for multiple Add buttons

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-06 16:21:28 +01:00
Felipe Cardoso
5717bffd63 feat(agents): add frontend types and validation for category fields
Frontend changes to support new AgentType category and display fields:

Types (agentTypes.ts):
- Add AgentTypeCategory union type with 8 categories
- Add CATEGORY_METADATA constant with labels, descriptions, colors
- Update all interfaces with new fields (category, icon, color, etc.)
- Add AgentTypeGroupedResponse type

Validation (agentType.ts):
- Add AGENT_TYPE_CATEGORIES constant with metadata
- Add AVAILABLE_ICONS constant for icon picker
- Add COLOR_PALETTE constant for color selection
- Update agentTypeFormSchema with new field validators
- Update defaultAgentTypeValues with new fields

Form updates:
- Transform function now maps category and display fields from API

Test updates:
- Add new fields to mock AgentTypeResponse objects

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-06 16:16:21 +01:00
Felipe Cardoso
9339ea30a1 feat(agents): add category and display fields to AgentType model
Add 6 new fields to AgentType for better organization and UI display:
- category: enum for grouping (development, design, quality, etc.)
- icon: Lucide icon identifier for UI
- color: hex color code for visual distinction
- sort_order: display ordering within categories
- typical_tasks: list of tasks the agent excels at
- collaboration_hints: agent slugs that work well together

Backend changes:
- Add AgentTypeCategory enum to enums.py
- Update AgentType model with 6 new columns and indexes
- Update schemas with validators for new fields
- Add category filter and /grouped endpoint to routes
- Update CRUD with get_grouped_by_category method
- Update seed data with categories for all 27 agents
- Add migration 0007

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-06 16:11:22 +01:00
Felipe Cardoso
79cb6bfd7b feat(agents): comprehensive agent types with rich personalities
Major revamp of agent types based on SOTA personality design research:
- Expanded from 6 to 27 specialized agent types
- Rich personality prompts following Anthropic and CrewAI best practices
- Each agent has structured prompt with Core Identity, Expertise,
  Principles, and Scenario Handling sections

Agent Categories:
- Core Development (8): Product Owner, PM, BA, Architect, Full Stack,
  Backend, Frontend, Mobile Engineers
- Design (2): UI/UX Designer, UX Researcher
- Quality & Operations (3): QA, DevOps, Security Engineers
- AI/ML (5): AI/ML Engineer, Researcher, CV, NLP, MLOps Engineers
- Data (2): Data Scientist, Data Engineer
- Leadership (2): Technical Lead, Scrum Master
- Domain Specialists (5): Financial, Healthcare, Scientific,
  Behavioral Psychology Experts, Technical Writer

Research applied:
- Anthropic Claude persona design guidelines
- CrewAI role/backstory/goal patterns
- Role prompting research on detailed vs generic personas
- Temperature tuning per agent type (0.2-0.7 based on role)

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-06 14:25:13 +01:00
Felipe Cardoso
45025bb2f1 fix(forms): handle nullable fields in deepMergeWithDefaults
When default value is null but source has a value (e.g., description
field), the merge was discarding the source value because typeof null
!== typeof string. Now properly accepts source values for nullable fields.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-06 13:54:18 +01:00
Felipe Cardoso
3c6b14d2bf refactor(forms): extract reusable form utilities and components
- Add getFirstValidationError utility for nested FieldErrors extraction
- Add mergeWithDefaults utilities (deepMergeWithDefaults, type guards)
- Add useValidationErrorHandler hook for toast + tab navigation
- Add FormSelect component with Controller integration
- Add FormTextarea component with register integration
- Refactor AgentTypeForm to use new utilities
- Remove verbose debug logging (now handled by hook)
- Add comprehensive tests (53 new tests, 100 total)

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-06 13:50:36 +01:00
Felipe Cardoso
6b21a6fadd debug(agents): add comprehensive logging to form submission
Adds console.log statements throughout the form submission flow:
- Form submit triggered
- Current form values
- Form state (isDirty, isValid, isSubmitting, errors)
- Validation pass/fail
- onSubmit call and completion

This will help diagnose why the save button appears to do nothing.
Check browser console for '[AgentTypeForm]' logs.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-06 11:56:54 +01:00
Felipe Cardoso
600657adc4 fix(agents): properly initialize form with API data defaults
Root cause: The demo data's model_params was missing `top_p`, but the
Zod schema required all three fields (temperature, max_tokens, top_p).
This caused silent validation failures when editing agent types.

Fixes:
1. Add getInitialValues() that ensures all required fields have defaults
2. Handle nested validation errors in handleFormError (e.g., model_params.top_p)
3. Add useEffect to reset form when agentType changes
4. Add console.error logging for debugging validation failures
5. Update demo data to include top_p in all agent types

The form now properly initializes with safe defaults for any missing
fields from the API response, preventing silent validation failures.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-06 11:54:45 +01:00
Felipe Cardoso
c9d0d079b3 fix(frontend): show validation errors when agent type form fails
When form validation fails (e.g., personality_prompt is empty), the form
would silently not submit. Now it shows a toast with the first error
and navigates to the tab containing the error field.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-06 11:29:01 +01:00
Felipe Cardoso
4c8f81368c fix(docker): add NEXT_PUBLIC_API_BASE_URL to frontend containers
When running in Docker, the frontend needs to use 'http://backend:8000'
as the backend URL for Next.js rewrites. This env var is set to use
the Docker service name for proper container-to-container communication.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-06 09:23:50 +01:00
Felipe Cardoso
efbe91ce14 fix(frontend): use configurable backend URL in Next.js rewrite
The rewrite was using 'http://backend:8000' which only resolves inside
Docker network. When running Next.js locally (npm run dev), the hostname
'backend' doesn't exist, causing ENOTFOUND errors.

Now uses NEXT_PUBLIC_API_BASE_URL env var with fallback to localhost:8000
for local development. In Docker, set NEXT_PUBLIC_API_BASE_URL=http://backend:8000.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-06 09:22:44 +01:00
Felipe Cardoso
5d646779c9 fix(frontend): preserve /api prefix in Next.js rewrite
The rewrite was incorrectly configured:
- Before: /api/:path* -> http://backend:8000/:path* (strips /api)
- After: /api/:path* -> http://backend:8000/api/:path* (preserves /api)

This was causing requests to /api/v1/agent-types to be sent to
http://backend:8000/v1/agent-types instead of the correct path.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-06 03:12:08 +01:00
Felipe Cardoso
5a4d93df26 feat(dashboard): use real API data and add 3 more demo projects
Dashboard changes:
- Update useDashboard hook to fetch real projects from API
- Calculate stats (active projects, agents, issues) from real data
- Keep pending approvals as mock (no backend endpoint yet)

Demo data additions:
- API Gateway Modernization project (active, complex)
- Customer Analytics Dashboard project (completed)
- DevOps Pipeline Automation project (active, complex)
- Added sprints, agent instances, and issues for each new project

Total demo data: 6 projects, 14 agents, 22 issues

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-06 03:10:10 +01:00
Felipe Cardoso
7ef217be39 feat(demo): tie all demo projects to admin user
- Update demo_data.json to use "__admin__" as owner_email for all projects
- Add admin user lookup in load_demo_data() with special "__admin__" key
- Remove notification_email from project settings (not a valid field)

This ensures demo projects are visible to the admin user when logged in.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-06 03:00:07 +01:00
Felipe Cardoso
20159c5865 fix(knowledge-base): ensure pgvector extension before pool creation
register_vector() requires the vector type to exist in PostgreSQL before
it can register the type codec. Move CREATE EXTENSION to a separate
_ensure_pgvector_extension() method that runs before pool creation.

This fixes the "unknown type: public.vector" error on fresh databases.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-06 02:55:02 +01:00
Felipe Cardoso
f9a72fcb34 fix(models): use enum values instead of names for PostgreSQL
Add values_callable to all enum columns so SQLAlchemy serializes using
the enum's .value (lowercase) instead of .name (uppercase). PostgreSQL
enum types defined in migrations use lowercase values.

Fixes: invalid input value for enum autonomy_level: "MILESTONE"

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-06 02:53:45 +01:00
Felipe Cardoso
fcb0a5f86a fix(models): add explicit enum names to match migration types
SQLAlchemy's Enum() auto-generates type names from Python class names
(e.g., AutonomyLevel -> autonomylevel), but migrations defined them
with underscores (e.g., autonomy_level). This mismatch caused:

  "type 'autonomylevel' does not exist"

Added explicit name parameters to all enum columns to match the
migration-defined type names:
- autonomy_level, project_status, project_complexity, client_mode
- agent_status, sprint_status
- issue_type, issue_status, issue_priority, sync_status

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-06 02:48:10 +01:00
96 changed files with 20391 additions and 549 deletions

View File

@@ -1,5 +1,5 @@
.PHONY: help dev dev-full prod down logs logs-dev clean clean-slate drop-db reset-db push-images deploy
.PHONY: test test-backend test-mcp test-frontend test-all test-cov test-integration validate validate-all
.PHONY: test test-backend test-mcp test-frontend test-all test-cov test-integration validate validate-all format-all
VERSION ?= latest
REGISTRY ?= ghcr.io/cardosofelipe/pragma-stack
@@ -22,6 +22,9 @@ help:
@echo " make test-cov - Run all tests with coverage reports"
@echo " make test-integration - Run MCP integration tests (requires running stack)"
@echo ""
@echo "Formatting:"
@echo " make format-all - Format code in backend + MCP servers + frontend"
@echo ""
@echo "Validation:"
@echo " make validate - Validate backend + MCP servers (lint, type-check, test)"
@echo " make validate-all - Validate everything including frontend"
@@ -44,6 +47,7 @@ help:
@echo " cd backend && make help - Backend-specific commands"
@echo " cd mcp-servers/llm-gateway && make - LLM Gateway commands"
@echo " cd mcp-servers/knowledge-base && make - Knowledge Base commands"
@echo " cd mcp-servers/git-ops && make - Git Operations commands"
@echo " cd frontend && npm run - Frontend-specific commands"
# ============================================================================
@@ -135,6 +139,9 @@ test-mcp:
@echo ""
@echo "=== Knowledge Base ==="
@cd mcp-servers/knowledge-base && uv run pytest tests/ -v
@echo ""
@echo "=== Git Operations ==="
@cd mcp-servers/git-ops && IS_TEST=True uv run pytest tests/ -v
test-frontend:
@echo "Running frontend tests..."
@@ -155,12 +162,37 @@ test-cov:
@echo ""
@echo "=== Knowledge Base Coverage ==="
@cd mcp-servers/knowledge-base && uv run pytest tests/ -v --cov=. --cov-report=term-missing
@echo ""
@echo "=== Git Operations Coverage ==="
@cd mcp-servers/git-ops && IS_TEST=True uv run pytest tests/ -v --cov=. --cov-report=term-missing
test-integration:
@echo "Running MCP integration tests..."
@echo "Note: Requires running stack (make dev first)"
@cd backend && RUN_INTEGRATION_TESTS=true IS_TEST=True uv run pytest tests/integration/ -v
# ============================================================================
# Formatting
# ============================================================================
format-all:
@echo "Formatting backend..."
@cd backend && make format
@echo ""
@echo "Formatting LLM Gateway..."
@cd mcp-servers/llm-gateway && make format
@echo ""
@echo "Formatting Knowledge Base..."
@cd mcp-servers/knowledge-base && make format
@echo ""
@echo "Formatting Git Operations..."
@cd mcp-servers/git-ops && make format
@echo ""
@echo "Formatting frontend..."
@cd frontend && npm run format
@echo ""
@echo "All code formatted!"
# ============================================================================
# Validation (lint + type-check + test)
# ============================================================================
@@ -175,6 +207,9 @@ validate:
@echo "Validating Knowledge Base..."
@cd mcp-servers/knowledge-base && make validate
@echo ""
@echo "Validating Git Operations..."
@cd mcp-servers/git-ops && make validate
@echo ""
@echo "All validations passed!"
validate-all: validate

View File

@@ -0,0 +1,90 @@
"""Add category and display fields to agent_types table
Revision ID: 0007
Revises: 0006
Create Date: 2026-01-06
This migration adds:
- category: String(50) for grouping agents by role type
- icon: String(50) for Lucide icon identifier
- color: String(7) for hex color code
- sort_order: Integer for display ordering within categories
- typical_tasks: JSONB list of tasks this agent excels at
- collaboration_hints: JSONB list of agent slugs that work well together
"""
from collections.abc import Sequence
import sqlalchemy as sa
from alembic import op
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision: str = "0007"
down_revision: str | None = "0006"
branch_labels: str | Sequence[str] | None = None
depends_on: str | Sequence[str] | None = None
def upgrade() -> None:
"""Add category and display fields to agent_types table."""
# Add new columns
op.add_column(
"agent_types",
sa.Column("category", sa.String(length=50), nullable=True),
)
op.add_column(
"agent_types",
sa.Column("icon", sa.String(length=50), nullable=True, server_default="bot"),
)
op.add_column(
"agent_types",
sa.Column(
"color", sa.String(length=7), nullable=True, server_default="#3B82F6"
),
)
op.add_column(
"agent_types",
sa.Column("sort_order", sa.Integer(), nullable=False, server_default="0"),
)
op.add_column(
"agent_types",
sa.Column(
"typical_tasks",
postgresql.JSONB(astext_type=sa.Text()),
nullable=False,
server_default="[]",
),
)
op.add_column(
"agent_types",
sa.Column(
"collaboration_hints",
postgresql.JSONB(astext_type=sa.Text()),
nullable=False,
server_default="[]",
),
)
# Add indexes for category and sort_order
op.create_index("ix_agent_types_category", "agent_types", ["category"])
op.create_index("ix_agent_types_sort_order", "agent_types", ["sort_order"])
op.create_index(
"ix_agent_types_category_sort", "agent_types", ["category", "sort_order"]
)
def downgrade() -> None:
"""Remove category and display fields from agent_types table."""
# Drop indexes
op.drop_index("ix_agent_types_category_sort", table_name="agent_types")
op.drop_index("ix_agent_types_sort_order", table_name="agent_types")
op.drop_index("ix_agent_types_category", table_name="agent_types")
# Drop columns
op.drop_column("agent_types", "collaboration_hints")
op.drop_column("agent_types", "typical_tasks")
op.drop_column("agent_types", "sort_order")
op.drop_column("agent_types", "color")
op.drop_column("agent_types", "icon")
op.drop_column("agent_types", "category")

View File

@@ -81,6 +81,13 @@ def _build_agent_type_response(
mcp_servers=agent_type.mcp_servers,
tool_permissions=agent_type.tool_permissions,
is_active=agent_type.is_active,
# Category and display fields
category=agent_type.category,
icon=agent_type.icon,
color=agent_type.color,
sort_order=agent_type.sort_order,
typical_tasks=agent_type.typical_tasks or [],
collaboration_hints=agent_type.collaboration_hints or [],
created_at=agent_type.created_at,
updated_at=agent_type.updated_at,
instance_count=instance_count,
@@ -300,6 +307,7 @@ async def list_agent_types(
request: Request,
pagination: PaginationParams = Depends(),
is_active: bool = Query(True, description="Filter by active status"),
category: str | None = Query(None, description="Filter by category"),
search: str | None = Query(None, description="Search by name, slug, description"),
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
@@ -314,6 +322,7 @@ async def list_agent_types(
request: FastAPI request object
pagination: Pagination parameters (page, limit)
is_active: Filter by active status (default: True)
category: Filter by category (e.g., "development", "design")
search: Optional search term for name, slug, description
current_user: Authenticated user
db: Database session
@@ -328,6 +337,7 @@ async def list_agent_types(
skip=pagination.offset,
limit=pagination.limit,
is_active=is_active,
category=category,
search=search,
)
@@ -354,6 +364,51 @@ async def list_agent_types(
raise
@router.get(
"/grouped",
response_model=dict[str, list[AgentTypeResponse]],
summary="List Agent Types Grouped by Category",
description="Get all agent types organized by category",
operation_id="list_agent_types_grouped",
)
@limiter.limit(f"{60 * RATE_MULTIPLIER}/minute")
async def list_agent_types_grouped(
request: Request,
is_active: bool = Query(True, description="Filter by active status"),
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
) -> Any:
"""
Get agent types grouped by category.
Returns a dictionary where keys are category names and values
are lists of agent types, sorted by sort_order within each category.
Args:
request: FastAPI request object
is_active: Filter by active status (default: True)
current_user: Authenticated user
db: Database session
Returns:
Dictionary mapping category to list of agent types
"""
try:
grouped = await agent_type_crud.get_grouped_by_category(db, is_active=is_active)
# Transform to response objects
result: dict[str, list[AgentTypeResponse]] = {}
for category, types in grouped.items():
result[category] = [
_build_agent_type_response(t, instance_count=0) for t in types
]
return result
except Exception as e:
logger.error(f"Error getting grouped agent types: {e!s}", exc_info=True)
raise
@router.get(
"/{agent_type_id}",
response_model=AgentTypeResponse,

View File

@@ -43,6 +43,13 @@ class CRUDAgentType(CRUDBase[AgentType, AgentTypeCreate, AgentTypeUpdate]):
mcp_servers=obj_in.mcp_servers,
tool_permissions=obj_in.tool_permissions,
is_active=obj_in.is_active,
# Category and display fields
category=obj_in.category.value if obj_in.category else None,
icon=obj_in.icon,
color=obj_in.color,
sort_order=obj_in.sort_order,
typical_tasks=obj_in.typical_tasks,
collaboration_hints=obj_in.collaboration_hints,
)
db.add(db_obj)
await db.commit()
@@ -68,6 +75,7 @@ class CRUDAgentType(CRUDBase[AgentType, AgentTypeCreate, AgentTypeUpdate]):
skip: int = 0,
limit: int = 100,
is_active: bool | None = None,
category: str | None = None,
search: str | None = None,
sort_by: str = "created_at",
sort_order: str = "desc",
@@ -85,6 +93,9 @@ class CRUDAgentType(CRUDBase[AgentType, AgentTypeCreate, AgentTypeUpdate]):
if is_active is not None:
query = query.where(AgentType.is_active == is_active)
if category:
query = query.where(AgentType.category == category)
if search:
search_filter = or_(
AgentType.name.ilike(f"%{search}%"),
@@ -162,6 +173,7 @@ class CRUDAgentType(CRUDBase[AgentType, AgentTypeCreate, AgentTypeUpdate]):
skip: int = 0,
limit: int = 100,
is_active: bool | None = None,
category: str | None = None,
search: str | None = None,
) -> tuple[list[dict[str, Any]], int]:
"""
@@ -177,6 +189,7 @@ class CRUDAgentType(CRUDBase[AgentType, AgentTypeCreate, AgentTypeUpdate]):
skip=skip,
limit=limit,
is_active=is_active,
category=category,
search=search,
)
@@ -260,6 +273,44 @@ class CRUDAgentType(CRUDBase[AgentType, AgentTypeCreate, AgentTypeUpdate]):
)
raise
async def get_grouped_by_category(
self,
db: AsyncSession,
*,
is_active: bool = True,
) -> dict[str, list[AgentType]]:
"""
Get agent types grouped by category, sorted by sort_order within each group.
Args:
db: Database session
is_active: Filter by active status (default: True)
Returns:
Dictionary mapping category to list of agent types
"""
try:
query = (
select(AgentType)
.where(AgentType.is_active == is_active)
.order_by(AgentType.category, AgentType.sort_order, AgentType.name)
)
result = await db.execute(query)
agent_types = list(result.scalars().all())
# Group by category
grouped: dict[str, list[AgentType]] = {}
for at in agent_types:
cat: str = str(at.category) if at.category else "uncategorized"
if cat not in grouped:
grouped[cat] = []
grouped[cat].append(at)
return grouped
except Exception as e:
logger.error(f"Error getting grouped agent types: {e!s}", exc_info=True)
raise
# Create a singleton instance for use across the application
agent_type = CRUDAgentType(AgentType)

View File

@@ -149,6 +149,13 @@ async def load_default_agent_types(session: AsyncSession) -> None:
mcp_servers=agent_type_data.get("mcp_servers", []),
tool_permissions=agent_type_data.get("tool_permissions", {}),
is_active=agent_type_data.get("is_active", True),
# Category and display fields
category=agent_type_data.get("category"),
icon=agent_type_data.get("icon", "bot"),
color=agent_type_data.get("color", "#3B82F6"),
sort_order=agent_type_data.get("sort_order", 0),
typical_tasks=agent_type_data.get("typical_tasks", []),
collaboration_hints=agent_type_data.get("collaboration_hints", []),
)
await agent_type_crud.create(session, obj_in=agent_type_in)
@@ -267,6 +274,15 @@ async def load_demo_data(session: AsyncSession) -> None:
await session.flush()
# Add admin user to map with special "__admin__" key
# This allows demo data to reference the admin user as owner
superuser_email = settings.FIRST_SUPERUSER_EMAIL or "admin@example.com"
admin_user = await user_crud.get_by_email(session, email=superuser_email)
if admin_user:
user_map["__admin__"] = admin_user
user_map[str(admin_user.email)] = admin_user
logger.debug(f"Added admin user to map: {admin_user.email}")
# ========================
# 3. Load Agent Types Map (for FK resolution)
# ========================

View File

@@ -62,7 +62,11 @@ class AgentInstance(Base, UUIDMixin, TimestampMixin):
# Status tracking
status: Column[AgentStatus] = Column(
Enum(AgentStatus),
Enum(
AgentStatus,
name="agent_status",
values_callable=lambda x: [e.value for e in x],
),
default=AgentStatus.IDLE,
nullable=False,
index=True,

View File

@@ -6,7 +6,7 @@ An AgentType is a template that defines the capabilities, personality,
and model configuration for agent instances.
"""
from sqlalchemy import Boolean, Column, Index, String, Text
from sqlalchemy import Boolean, Column, Index, Integer, String, Text
from sqlalchemy.dialects.postgresql import JSONB
from sqlalchemy.orm import relationship
@@ -56,6 +56,24 @@ class AgentType(Base, UUIDMixin, TimestampMixin):
# Whether this agent type is available for new instances
is_active = Column(Boolean, default=True, nullable=False, index=True)
# Category for grouping agents (development, design, quality, etc.)
category = Column(String(50), nullable=True, index=True)
# Lucide icon identifier for UI display (e.g., "code", "palette", "shield")
icon = Column(String(50), nullable=True, default="bot")
# Hex color code for visual distinction (e.g., "#3B82F6")
color = Column(String(7), nullable=True, default="#3B82F6")
# Display ordering within category (lower = first)
sort_order = Column(Integer, nullable=False, default=0, index=True)
# List of typical tasks this agent excels at
typical_tasks = Column(JSONB, default=list, nullable=False)
# List of agent slugs that collaborate well with this type
collaboration_hints = Column(JSONB, default=list, nullable=False)
# Relationships
instances = relationship(
"AgentInstance",
@@ -66,6 +84,7 @@ class AgentType(Base, UUIDMixin, TimestampMixin):
__table_args__ = (
Index("ix_agent_types_slug_active", "slug", "is_active"),
Index("ix_agent_types_name_active", "name", "is_active"),
Index("ix_agent_types_category_sort", "category", "sort_order"),
)
def __repr__(self) -> str:

View File

@@ -167,3 +167,29 @@ class SprintStatus(str, PyEnum):
IN_REVIEW = "in_review"
COMPLETED = "completed"
CANCELLED = "cancelled"
class AgentTypeCategory(str, PyEnum):
"""
Category classification for agent types.
Used for grouping and filtering agents in the UI.
DEVELOPMENT: Product, project, and engineering roles
DESIGN: UI/UX and design research roles
QUALITY: QA and security engineering
OPERATIONS: DevOps and MLOps
AI_ML: Machine learning and AI specialists
DATA: Data science and engineering
LEADERSHIP: Technical leadership roles
DOMAIN_EXPERT: Industry and domain specialists
"""
DEVELOPMENT = "development"
DESIGN = "design"
QUALITY = "quality"
OPERATIONS = "operations"
AI_ML = "ai_ml"
DATA = "data"
LEADERSHIP = "leadership"
DOMAIN_EXPERT = "domain_expert"

View File

@@ -59,7 +59,9 @@ class Issue(Base, UUIDMixin, TimestampMixin):
# Issue type (Epic, Story, Task, Bug)
type: Column[IssueType] = Column(
Enum(IssueType),
Enum(
IssueType, name="issue_type", values_callable=lambda x: [e.value for e in x]
),
default=IssueType.TASK,
nullable=False,
index=True,
@@ -78,14 +80,22 @@ class Issue(Base, UUIDMixin, TimestampMixin):
# Status and priority
status: Column[IssueStatus] = Column(
Enum(IssueStatus),
Enum(
IssueStatus,
name="issue_status",
values_callable=lambda x: [e.value for e in x],
),
default=IssueStatus.OPEN,
nullable=False,
index=True,
)
priority: Column[IssuePriority] = Column(
Enum(IssuePriority),
Enum(
IssuePriority,
name="issue_priority",
values_callable=lambda x: [e.value for e in x],
),
default=IssuePriority.MEDIUM,
nullable=False,
index=True,
@@ -132,7 +142,11 @@ class Issue(Base, UUIDMixin, TimestampMixin):
# Sync status with external tracker
sync_status: Column[SyncStatus] = Column(
Enum(SyncStatus),
Enum(
SyncStatus,
name="sync_status",
values_callable=lambda x: [e.value for e in x],
),
default=SyncStatus.SYNCED,
nullable=False,
# Note: Index defined in __table_args__ as ix_issues_sync_status

View File

@@ -35,28 +35,44 @@ class Project(Base, UUIDMixin, TimestampMixin):
description = Column(Text, nullable=True)
autonomy_level: Column[AutonomyLevel] = Column(
Enum(AutonomyLevel),
Enum(
AutonomyLevel,
name="autonomy_level",
values_callable=lambda x: [e.value for e in x],
),
default=AutonomyLevel.MILESTONE,
nullable=False,
index=True,
)
status: Column[ProjectStatus] = Column(
Enum(ProjectStatus),
Enum(
ProjectStatus,
name="project_status",
values_callable=lambda x: [e.value for e in x],
),
default=ProjectStatus.ACTIVE,
nullable=False,
index=True,
)
complexity: Column[ProjectComplexity] = Column(
Enum(ProjectComplexity),
Enum(
ProjectComplexity,
name="project_complexity",
values_callable=lambda x: [e.value for e in x],
),
default=ProjectComplexity.MEDIUM,
nullable=False,
index=True,
)
client_mode: Column[ClientMode] = Column(
Enum(ClientMode),
Enum(
ClientMode,
name="client_mode",
values_callable=lambda x: [e.value for e in x],
),
default=ClientMode.AUTO,
nullable=False,
index=True,

View File

@@ -57,7 +57,11 @@ class Sprint(Base, UUIDMixin, TimestampMixin):
# Status
status: Column[SprintStatus] = Column(
Enum(SprintStatus),
Enum(
SprintStatus,
name="sprint_status",
values_callable=lambda x: [e.value for e in x],
),
default=SprintStatus.PLANNED,
nullable=False,
index=True,

View File

@@ -10,6 +10,8 @@ from uuid import UUID
from pydantic import BaseModel, ConfigDict, Field, field_validator
from app.models.syndarix.enums import AgentTypeCategory
class AgentTypeBase(BaseModel):
"""Base agent type schema with common fields."""
@@ -26,6 +28,14 @@ class AgentTypeBase(BaseModel):
tool_permissions: dict[str, Any] = Field(default_factory=dict)
is_active: bool = True
# Category and display fields
category: AgentTypeCategory | None = None
icon: str | None = Field(None, max_length=50)
color: str | None = Field(None, pattern=r"^#[0-9A-Fa-f]{6}$")
sort_order: int = Field(default=0, ge=0, le=1000)
typical_tasks: list[str] = Field(default_factory=list)
collaboration_hints: list[str] = Field(default_factory=list)
@field_validator("slug")
@classmethod
def validate_slug(cls, v: str | None) -> str | None:
@@ -62,6 +72,18 @@ class AgentTypeBase(BaseModel):
"""Validate MCP server list."""
return [s.strip() for s in v if s.strip()]
@field_validator("typical_tasks")
@classmethod
def validate_typical_tasks(cls, v: list[str]) -> list[str]:
"""Validate and normalize typical tasks list."""
return [t.strip() for t in v if t.strip()]
@field_validator("collaboration_hints")
@classmethod
def validate_collaboration_hints(cls, v: list[str]) -> list[str]:
"""Validate and normalize collaboration hints (agent slugs)."""
return [h.strip().lower() for h in v if h.strip()]
class AgentTypeCreate(AgentTypeBase):
"""Schema for creating a new agent type."""
@@ -87,6 +109,14 @@ class AgentTypeUpdate(BaseModel):
tool_permissions: dict[str, Any] | None = None
is_active: bool | None = None
# Category and display fields (all optional for updates)
category: AgentTypeCategory | None = None
icon: str | None = Field(None, max_length=50)
color: str | None = Field(None, pattern=r"^#[0-9A-Fa-f]{6}$")
sort_order: int | None = Field(None, ge=0, le=1000)
typical_tasks: list[str] | None = None
collaboration_hints: list[str] | None = None
@field_validator("slug")
@classmethod
def validate_slug(cls, v: str | None) -> str | None:
@@ -119,6 +149,22 @@ class AgentTypeUpdate(BaseModel):
return v
return [e.strip().lower() for e in v if e.strip()]
@field_validator("typical_tasks")
@classmethod
def validate_typical_tasks(cls, v: list[str] | None) -> list[str] | None:
"""Validate and normalize typical tasks list."""
if v is None:
return v
return [t.strip() for t in v if t.strip()]
@field_validator("collaboration_hints")
@classmethod
def validate_collaboration_hints(cls, v: list[str] | None) -> list[str] | None:
"""Validate and normalize collaboration hints (agent slugs)."""
if v is None:
return v
return [h.strip().lower() for h in v if h.strip()]
class AgentTypeInDB(AgentTypeBase):
"""Schema for agent type in database."""

File diff suppressed because it is too large Load Diff

View File

@@ -368,21 +368,20 @@
"name": "E-Commerce Platform Redesign",
"slug": "ecommerce-redesign",
"description": "Complete redesign of the e-commerce platform with modern UX, improved checkout flow, and mobile-first approach.",
"owner_email": "demo@example.com",
"owner_email": "__admin__",
"autonomy_level": "milestone",
"status": "active",
"complexity": "complex",
"client_mode": "technical",
"settings": {
"mcp_servers": ["gitea", "knowledge-base"],
"notification_email": "demo@example.com"
"mcp_servers": ["gitea", "knowledge-base"]
}
},
{
"name": "Mobile Banking App",
"slug": "mobile-banking",
"description": "Secure mobile banking application with biometric authentication, transaction history, and real-time notifications.",
"owner_email": "alice@acme.com",
"owner_email": "__admin__",
"autonomy_level": "full_control",
"status": "active",
"complexity": "complex",
@@ -396,7 +395,7 @@
"name": "Internal HR Portal",
"slug": "hr-portal",
"description": "Employee self-service portal for leave requests, performance reviews, and document management.",
"owner_email": "carol@globex.com",
"owner_email": "__admin__",
"autonomy_level": "autonomous",
"status": "active",
"complexity": "medium",
@@ -404,6 +403,45 @@
"settings": {
"mcp_servers": ["gitea", "knowledge-base"]
}
},
{
"name": "API Gateway Modernization",
"slug": "api-gateway",
"description": "Migrate legacy REST API gateway to modern GraphQL-based architecture with improved caching and rate limiting.",
"owner_email": "__admin__",
"autonomy_level": "milestone",
"status": "active",
"complexity": "complex",
"client_mode": "technical",
"settings": {
"mcp_servers": ["gitea", "knowledge-base"]
}
},
{
"name": "Customer Analytics Dashboard",
"slug": "analytics-dashboard",
"description": "Real-time analytics dashboard for customer behavior insights, cohort analysis, and predictive modeling.",
"owner_email": "__admin__",
"autonomy_level": "autonomous",
"status": "completed",
"complexity": "medium",
"client_mode": "auto",
"settings": {
"mcp_servers": ["gitea", "knowledge-base"]
}
},
{
"name": "DevOps Pipeline Automation",
"slug": "devops-automation",
"description": "Automate CI/CD pipelines with AI-assisted deployments, rollback detection, and infrastructure as code.",
"owner_email": "__admin__",
"autonomy_level": "full_control",
"status": "active",
"complexity": "complex",
"client_mode": "technical",
"settings": {
"mcp_servers": ["gitea", "knowledge-base"]
}
}
],
"sprints": [
@@ -446,6 +484,56 @@
"end_date": "2026-01-20",
"status": "active",
"planned_points": 18
},
{
"project_slug": "api-gateway",
"name": "Sprint 1: GraphQL Schema",
"number": 1,
"goal": "Define GraphQL schema and implement core resolvers for existing REST endpoints.",
"start_date": "2025-12-23",
"end_date": "2026-01-06",
"status": "completed",
"planned_points": 21
},
{
"project_slug": "api-gateway",
"name": "Sprint 2: Caching Layer",
"number": 2,
"goal": "Implement Redis-based caching layer and query batching.",
"start_date": "2026-01-06",
"end_date": "2026-01-20",
"status": "active",
"planned_points": 26
},
{
"project_slug": "analytics-dashboard",
"name": "Sprint 1: Data Pipeline",
"number": 1,
"goal": "Set up data ingestion pipeline and real-time event processing.",
"start_date": "2025-11-15",
"end_date": "2025-11-29",
"status": "completed",
"planned_points": 18
},
{
"project_slug": "analytics-dashboard",
"name": "Sprint 2: Dashboard UI",
"number": 2,
"goal": "Build interactive dashboard with charts and filtering capabilities.",
"start_date": "2025-11-29",
"end_date": "2025-12-13",
"status": "completed",
"planned_points": 21
},
{
"project_slug": "devops-automation",
"name": "Sprint 1: Pipeline Templates",
"number": 1,
"goal": "Create reusable CI/CD pipeline templates for common deployment patterns.",
"start_date": "2026-01-06",
"end_date": "2026-01-20",
"status": "active",
"planned_points": 24
}
],
"agent_instances": [
@@ -501,6 +589,40 @@
"name": "Atlas",
"status": "working",
"current_task": "Building employee dashboard API"
},
{
"project_slug": "api-gateway",
"agent_type_slug": "solutions-architect",
"name": "Orion",
"status": "working",
"current_task": "Designing caching strategy for GraphQL queries"
},
{
"project_slug": "api-gateway",
"agent_type_slug": "senior-engineer",
"name": "Cleo",
"status": "working",
"current_task": "Implementing Redis cache invalidation"
},
{
"project_slug": "devops-automation",
"agent_type_slug": "devops-engineer",
"name": "Volt",
"status": "working",
"current_task": "Creating Terraform modules for AWS ECS"
},
{
"project_slug": "devops-automation",
"agent_type_slug": "senior-engineer",
"name": "Sage",
"status": "idle"
},
{
"project_slug": "devops-automation",
"agent_type_slug": "qa-engineer",
"name": "Echo",
"status": "waiting",
"current_task": "Waiting for pipeline templates to test"
}
],
"issues": [
@@ -639,6 +761,119 @@
"priority": "medium",
"labels": ["backend", "infrastructure", "storage"],
"story_points": 5
},
{
"project_slug": "api-gateway",
"sprint_number": 2,
"type": "story",
"title": "Implement Redis caching layer",
"body": "As an API consumer, I want responses to be cached for improved performance.\n\n## Requirements\n- Cache GraphQL query results\n- Configurable TTL per query type\n- Cache invalidation on mutations\n- Cache hit/miss metrics",
"status": "in_progress",
"priority": "critical",
"labels": ["backend", "performance", "redis"],
"story_points": 8,
"assigned_agent_name": "Cleo"
},
{
"project_slug": "api-gateway",
"sprint_number": 2,
"type": "task",
"title": "Set up query batching and deduplication",
"body": "Implement DataLoader pattern for:\n- Batching multiple queries into single database calls\n- Deduplicating identical queries within request scope\n- N+1 query prevention",
"status": "open",
"priority": "high",
"labels": ["backend", "performance", "graphql"],
"story_points": 5
},
{
"project_slug": "api-gateway",
"sprint_number": 2,
"type": "task",
"title": "Implement rate limiting middleware",
"body": "Add rate limiting to prevent API abuse:\n- Per-user rate limits\n- Per-IP fallback for anonymous requests\n- Sliding window algorithm\n- Custom limits per operation type",
"status": "open",
"priority": "high",
"labels": ["backend", "security", "middleware"],
"story_points": 5,
"assigned_agent_name": "Orion"
},
{
"project_slug": "api-gateway",
"sprint_number": 2,
"type": "bug",
"title": "Fix N+1 query in user resolver",
"body": "The user resolver is making separate database calls for each user's organization.\n\n## Steps to Reproduce\n1. Query users with organization field\n2. Check database logs\n3. Observe N+1 queries",
"status": "open",
"priority": "high",
"labels": ["bug", "performance", "graphql"],
"story_points": 3
},
{
"project_slug": "analytics-dashboard",
"sprint_number": 2,
"type": "story",
"title": "Build cohort analysis charts",
"body": "As a product manager, I want to analyze user cohorts over time.\n\n## Features\n- Weekly/monthly cohort grouping\n- Retention curve visualization\n- Cohort comparison view",
"status": "closed",
"priority": "high",
"labels": ["frontend", "charts", "analytics"],
"story_points": 8
},
{
"project_slug": "analytics-dashboard",
"sprint_number": 2,
"type": "task",
"title": "Implement real-time event streaming",
"body": "Set up WebSocket connection for live event updates:\n- Event type filtering\n- Buffering for high-volume periods\n- Reconnection handling",
"status": "closed",
"priority": "high",
"labels": ["backend", "websocket", "realtime"],
"story_points": 5
},
{
"project_slug": "devops-automation",
"sprint_number": 1,
"type": "epic",
"title": "CI/CD Pipeline Templates",
"body": "Create reusable pipeline templates for common deployment patterns.\n\n## Templates Needed\n- Node.js applications\n- Python applications\n- Docker-based deployments\n- Kubernetes deployments",
"status": "in_progress",
"priority": "critical",
"labels": ["infrastructure", "cicd", "templates"],
"story_points": null
},
{
"project_slug": "devops-automation",
"sprint_number": 1,
"type": "story",
"title": "Create Terraform modules for AWS ECS",
"body": "As a DevOps engineer, I want Terraform modules for ECS deployments.\n\n## Modules\n- ECS cluster configuration\n- Service and task definitions\n- Load balancer integration\n- Auto-scaling policies",
"status": "in_progress",
"priority": "high",
"labels": ["terraform", "aws", "ecs"],
"story_points": 8,
"assigned_agent_name": "Volt"
},
{
"project_slug": "devops-automation",
"sprint_number": 1,
"type": "task",
"title": "Set up Gitea Actions runners",
"body": "Configure self-hosted Gitea Actions runners:\n- Docker-in-Docker support\n- Caching for npm/pip\n- Secrets management\n- Resource limits",
"status": "open",
"priority": "high",
"labels": ["infrastructure", "gitea", "cicd"],
"story_points": 5
},
{
"project_slug": "devops-automation",
"sprint_number": 1,
"type": "task",
"title": "Implement rollback detection system",
"body": "AI-assisted rollback detection:\n- Monitor deployment health metrics\n- Automatic rollback triggers\n- Notification system\n- Post-rollback analysis",
"status": "open",
"priority": "medium",
"labels": ["ai", "monitoring", "automation"],
"story_points": 8
}
]
}

View File

@@ -26,6 +26,7 @@ Usage:
# Inside Docker (without --local flag):
python migrate.py auto "Add new field"
"""
import argparse
import os
import subprocess
@@ -44,13 +45,14 @@ def setup_database_url(use_local: bool) -> str:
# Override DATABASE_URL to use localhost instead of Docker hostname
local_url = os.environ.get(
"LOCAL_DATABASE_URL",
"postgresql://postgres:postgres@localhost:5432/app"
"postgresql://postgres:postgres@localhost:5432/syndarix",
)
os.environ["DATABASE_URL"] = local_url
return local_url
# Use the configured DATABASE_URL from environment/.env
from app.core.config import settings
return settings.database_url
@@ -61,6 +63,7 @@ def check_models():
try:
# Import all models through the models package
from app.models import __all__ as all_models
print(f"Found {len(all_models)} model(s):")
for model in all_models:
print(f" - {model}")
@@ -110,7 +113,9 @@ def generate_migration(message, rev_id=None, auto_rev_id=True, offline=False):
# Look for the revision ID, which is typically 12 hex characters
parts = line.split()
for part in parts:
if len(part) >= 12 and all(c in "0123456789abcdef" for c in part[:12]):
if len(part) >= 12 and all(
c in "0123456789abcdef" for c in part[:12]
):
revision = part[:12]
break
except Exception as e:
@@ -185,6 +190,7 @@ def check_database_connection():
db_url = os.environ.get("DATABASE_URL")
if not db_url:
from app.core.config import settings
db_url = settings.database_url
engine = create_engine(db_url)
@@ -270,8 +276,8 @@ def generate_offline_migration(message, rev_id):
content = f'''"""{message}
Revision ID: {rev_id}
Revises: {down_revision or ''}
Create Date: {datetime.now().strftime('%Y-%m-%d %H:%M:%S.%f')}
Revises: {down_revision or ""}
Create Date: {datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f")}
"""
@@ -320,6 +326,7 @@ def reset_alembic_version():
db_url = os.environ.get("DATABASE_URL")
if not db_url:
from app.core.config import settings
db_url = settings.database_url
try:
@@ -338,82 +345,80 @@ def reset_alembic_version():
def main():
"""Main function"""
parser = argparse.ArgumentParser(
description='Database migration helper for Generative Models Arena'
description="Database migration helper for Generative Models Arena"
)
# Global options
parser.add_argument(
'--local', '-l',
action='store_true',
help='Use localhost instead of Docker hostname (for local development)'
"--local",
"-l",
action="store_true",
help="Use localhost instead of Docker hostname (for local development)",
)
subparsers = parser.add_subparsers(dest='command', help='Command to run')
subparsers = parser.add_subparsers(dest="command", help="Command to run")
# Generate command
generate_parser = subparsers.add_parser('generate', help='Generate a migration')
generate_parser.add_argument('message', help='Migration message')
generate_parser = subparsers.add_parser("generate", help="Generate a migration")
generate_parser.add_argument("message", help="Migration message")
generate_parser.add_argument(
'--rev-id',
help='Custom revision ID (e.g., 0001, 0002 for sequential naming)'
"--rev-id", help="Custom revision ID (e.g., 0001, 0002 for sequential naming)"
)
generate_parser.add_argument(
'--offline',
action='store_true',
help='Generate empty migration template without database connection'
"--offline",
action="store_true",
help="Generate empty migration template without database connection",
)
# Apply command
apply_parser = subparsers.add_parser('apply', help='Apply migrations')
apply_parser.add_argument('--revision', help='Specific revision to apply to')
apply_parser = subparsers.add_parser("apply", help="Apply migrations")
apply_parser.add_argument("--revision", help="Specific revision to apply to")
# List command
subparsers.add_parser('list', help='List migrations')
subparsers.add_parser("list", help="List migrations")
# Current command
subparsers.add_parser('current', help='Show current revision')
subparsers.add_parser("current", help="Show current revision")
# Check command
subparsers.add_parser('check', help='Check database connection and models')
subparsers.add_parser("check", help="Check database connection and models")
# Next command (show next revision ID)
subparsers.add_parser('next', help='Show the next sequential revision ID')
subparsers.add_parser("next", help="Show the next sequential revision ID")
# Reset command (clear alembic_version table)
subparsers.add_parser(
'reset',
help='Reset alembic_version table (use after deleting all migrations)'
"reset", help="Reset alembic_version table (use after deleting all migrations)"
)
# Auto command (generate and apply)
auto_parser = subparsers.add_parser('auto', help='Generate and apply migration')
auto_parser.add_argument('message', help='Migration message')
auto_parser = subparsers.add_parser("auto", help="Generate and apply migration")
auto_parser.add_argument("message", help="Migration message")
auto_parser.add_argument(
'--rev-id',
help='Custom revision ID (e.g., 0001, 0002 for sequential naming)'
"--rev-id", help="Custom revision ID (e.g., 0001, 0002 for sequential naming)"
)
auto_parser.add_argument(
'--offline',
action='store_true',
help='Generate empty migration template without database connection'
"--offline",
action="store_true",
help="Generate empty migration template without database connection",
)
args = parser.parse_args()
# Commands that don't need database connection
if args.command == 'next':
if args.command == "next":
show_next_rev_id()
return
# Check if offline mode is requested
offline = getattr(args, 'offline', False)
offline = getattr(args, "offline", False)
# Offline generate doesn't need database or model check
if args.command == 'generate' and offline:
if args.command == "generate" and offline:
generate_migration(args.message, rev_id=args.rev_id, offline=True)
return
if args.command == 'auto' and offline:
if args.command == "auto" and offline:
generate_migration(args.message, rev_id=args.rev_id, offline=True)
print("\nOffline migration generated. Apply it later with:")
print(" python migrate.py --local apply")
@@ -423,27 +428,27 @@ def main():
db_url = setup_database_url(args.local)
print(f"Using database URL: {db_url}")
if args.command == 'generate':
if args.command == "generate":
check_models()
generate_migration(args.message, rev_id=args.rev_id)
elif args.command == 'apply':
elif args.command == "apply":
apply_migration(args.revision)
elif args.command == 'list':
elif args.command == "list":
list_migrations()
elif args.command == 'current':
elif args.command == "current":
show_current()
elif args.command == 'check':
elif args.command == "check":
check_database_connection()
check_models()
elif args.command == 'reset':
elif args.command == "reset":
reset_alembic_version()
elif args.command == 'auto':
elif args.command == "auto":
check_models()
revision = generate_migration(args.message, rev_id=args.rev_id)
if revision:

View File

@@ -745,3 +745,230 @@ class TestAgentTypeInstanceCount:
for agent_type in data["data"]:
assert "instance_count" in agent_type
assert isinstance(agent_type["instance_count"], int)
@pytest.mark.asyncio
class TestAgentTypeCategoryFields:
"""Tests for agent type category and display fields."""
async def test_create_agent_type_with_category_fields(
self, client, superuser_token
):
"""Test creating agent type with all category and display fields."""
unique_slug = f"category-type-{uuid.uuid4().hex[:8]}"
response = await client.post(
"/api/v1/agent-types",
json={
"name": "Categorized Agent Type",
"slug": unique_slug,
"description": "An agent type with category fields",
"expertise": ["python"],
"personality_prompt": "You are a helpful assistant.",
"primary_model": "claude-opus-4-5-20251101",
# Category and display fields
"category": "development",
"icon": "code",
"color": "#3B82F6",
"sort_order": 10,
"typical_tasks": ["Write code", "Review PRs"],
"collaboration_hints": ["backend-engineer", "qa-engineer"],
},
headers={"Authorization": f"Bearer {superuser_token}"},
)
assert response.status_code == status.HTTP_201_CREATED
data = response.json()
assert data["category"] == "development"
assert data["icon"] == "code"
assert data["color"] == "#3B82F6"
assert data["sort_order"] == 10
assert data["typical_tasks"] == ["Write code", "Review PRs"]
assert data["collaboration_hints"] == ["backend-engineer", "qa-engineer"]
async def test_create_agent_type_with_nullable_category(
self, client, superuser_token
):
"""Test creating agent type with null category."""
unique_slug = f"null-category-{uuid.uuid4().hex[:8]}"
response = await client.post(
"/api/v1/agent-types",
json={
"name": "Uncategorized Agent",
"slug": unique_slug,
"expertise": ["general"],
"personality_prompt": "You are a helpful assistant.",
"primary_model": "claude-opus-4-5-20251101",
"category": None,
},
headers={"Authorization": f"Bearer {superuser_token}"},
)
assert response.status_code == status.HTTP_201_CREATED
data = response.json()
assert data["category"] is None
async def test_create_agent_type_invalid_color_format(
self, client, superuser_token
):
"""Test that invalid color format is rejected."""
unique_slug = f"invalid-color-{uuid.uuid4().hex[:8]}"
response = await client.post(
"/api/v1/agent-types",
json={
"name": "Invalid Color Agent",
"slug": unique_slug,
"expertise": ["python"],
"personality_prompt": "You are a helpful assistant.",
"primary_model": "claude-opus-4-5-20251101",
"color": "not-a-hex-color",
},
headers={"Authorization": f"Bearer {superuser_token}"},
)
assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY
async def test_create_agent_type_invalid_category(self, client, superuser_token):
"""Test that invalid category value is rejected."""
unique_slug = f"invalid-category-{uuid.uuid4().hex[:8]}"
response = await client.post(
"/api/v1/agent-types",
json={
"name": "Invalid Category Agent",
"slug": unique_slug,
"expertise": ["python"],
"personality_prompt": "You are a helpful assistant.",
"primary_model": "claude-opus-4-5-20251101",
"category": "not_a_valid_category",
},
headers={"Authorization": f"Bearer {superuser_token}"},
)
assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY
async def test_update_agent_type_category_fields(
self, client, superuser_token, test_agent_type
):
"""Test updating category and display fields."""
agent_type_id = test_agent_type["id"]
response = await client.patch(
f"/api/v1/agent-types/{agent_type_id}",
json={
"category": "ai_ml",
"icon": "brain",
"color": "#8B5CF6",
"sort_order": 50,
"typical_tasks": ["Train models", "Analyze data"],
"collaboration_hints": ["data-scientist"],
},
headers={"Authorization": f"Bearer {superuser_token}"},
)
assert response.status_code == status.HTTP_200_OK
data = response.json()
assert data["category"] == "ai_ml"
assert data["icon"] == "brain"
assert data["color"] == "#8B5CF6"
assert data["sort_order"] == 50
assert data["typical_tasks"] == ["Train models", "Analyze data"]
assert data["collaboration_hints"] == ["data-scientist"]
@pytest.mark.asyncio
class TestAgentTypeCategoryFilter:
"""Tests for agent type category filtering."""
async def test_list_agent_types_filter_by_category(
self, client, superuser_token, user_token
):
"""Test filtering agent types by category."""
# Create agent types in different categories
for cat in ["development", "design"]:
unique_slug = f"filter-test-{cat}-{uuid.uuid4().hex[:8]}"
await client.post(
"/api/v1/agent-types",
json={
"name": f"Filter Test {cat.capitalize()}",
"slug": unique_slug,
"expertise": ["python"],
"personality_prompt": "Test prompt",
"primary_model": "claude-opus-4-5-20251101",
"category": cat,
},
headers={"Authorization": f"Bearer {superuser_token}"},
)
# Filter by development category
response = await client.get(
"/api/v1/agent-types",
params={"category": "development"},
headers={"Authorization": f"Bearer {user_token}"},
)
assert response.status_code == status.HTTP_200_OK
data = response.json()
# All returned types should have development category
for agent_type in data["data"]:
assert agent_type["category"] == "development"
@pytest.mark.asyncio
class TestAgentTypeGroupedEndpoint:
"""Tests for the grouped by category endpoint."""
async def test_list_agent_types_grouped(self, client, superuser_token, user_token):
"""Test getting agent types grouped by category."""
# Create agent types in different categories
categories = ["development", "design", "quality"]
for cat in categories:
unique_slug = f"grouped-test-{cat}-{uuid.uuid4().hex[:8]}"
await client.post(
"/api/v1/agent-types",
json={
"name": f"Grouped Test {cat.capitalize()}",
"slug": unique_slug,
"expertise": ["python"],
"personality_prompt": "Test prompt",
"primary_model": "claude-opus-4-5-20251101",
"category": cat,
"sort_order": 10,
},
headers={"Authorization": f"Bearer {superuser_token}"},
)
# Get grouped agent types
response = await client.get(
"/api/v1/agent-types/grouped",
headers={"Authorization": f"Bearer {user_token}"},
)
assert response.status_code == status.HTTP_200_OK
data = response.json()
# Should be a dict with category keys
assert isinstance(data, dict)
# Check that at least one of our created categories exists
assert any(cat in data for cat in categories)
async def test_list_agent_types_grouped_filter_inactive(
self, client, superuser_token, user_token
):
"""Test grouped endpoint with is_active filter."""
response = await client.get(
"/api/v1/agent-types/grouped",
params={"is_active": False},
headers={"Authorization": f"Bearer {user_token}"},
)
assert response.status_code == status.HTTP_200_OK
data = response.json()
assert isinstance(data, dict)
async def test_list_agent_types_grouped_unauthenticated(self, client):
"""Test that unauthenticated users cannot access grouped endpoint."""
response = await client.get("/api/v1/agent-types/grouped")
assert response.status_code == status.HTTP_401_UNAUTHORIZED

View File

@@ -368,3 +368,9 @@ async def e2e_org_with_members(e2e_client, e2e_superuser):
"user_id": member_id,
},
}
# NOTE: Class-scoped fixtures for E2E tests were attempted but have fundamental
# issues with pytest-asyncio + SQLAlchemy/asyncpg event loop management.
# The function-scoped fixtures above provide proper test isolation.
# Performance optimization would require significant infrastructure changes.

View File

@@ -316,3 +316,325 @@ class TestAgentTypeJsonFields:
)
assert agent_type.fallback_models == models
class TestAgentTypeCategoryFieldsValidation:
"""Tests for AgentType category and display field validation."""
def test_valid_category_values(self):
"""Test that all valid category values are accepted."""
valid_categories = [
"development",
"design",
"quality",
"operations",
"ai_ml",
"data",
"leadership",
"domain_expert",
]
for category in valid_categories:
agent_type = AgentTypeCreate(
name="Test Agent",
slug="test-agent",
personality_prompt="Test",
primary_model="claude-opus-4-5-20251101",
category=category,
)
assert agent_type.category.value == category
def test_category_null_allowed(self):
"""Test that null category is allowed."""
agent_type = AgentTypeCreate(
name="Test Agent",
slug="test-agent",
personality_prompt="Test",
primary_model="claude-opus-4-5-20251101",
category=None,
)
assert agent_type.category is None
def test_invalid_category_rejected(self):
"""Test that invalid category values are rejected."""
with pytest.raises(ValidationError):
AgentTypeCreate(
name="Test Agent",
slug="test-agent",
personality_prompt="Test",
primary_model="claude-opus-4-5-20251101",
category="invalid_category",
)
def test_valid_hex_color(self):
"""Test that valid hex colors are accepted."""
valid_colors = ["#3B82F6", "#EC4899", "#10B981", "#ffffff", "#000000"]
for color in valid_colors:
agent_type = AgentTypeCreate(
name="Test Agent",
slug="test-agent",
personality_prompt="Test",
primary_model="claude-opus-4-5-20251101",
color=color,
)
assert agent_type.color == color
def test_invalid_hex_color_rejected(self):
"""Test that invalid hex colors are rejected."""
invalid_colors = [
"not-a-color",
"3B82F6", # Missing #
"#3B82F", # Too short
"#3B82F6A", # Too long
"#GGGGGG", # Invalid hex chars
"rgb(59, 130, 246)", # RGB format not supported
]
for color in invalid_colors:
with pytest.raises(ValidationError):
AgentTypeCreate(
name="Test Agent",
slug="test-agent",
personality_prompt="Test",
primary_model="claude-opus-4-5-20251101",
color=color,
)
def test_color_null_allowed(self):
"""Test that null color is allowed."""
agent_type = AgentTypeCreate(
name="Test Agent",
slug="test-agent",
personality_prompt="Test",
primary_model="claude-opus-4-5-20251101",
color=None,
)
assert agent_type.color is None
def test_sort_order_valid_range(self):
"""Test that valid sort_order values are accepted."""
for sort_order in [0, 1, 500, 1000]:
agent_type = AgentTypeCreate(
name="Test Agent",
slug="test-agent",
personality_prompt="Test",
primary_model="claude-opus-4-5-20251101",
sort_order=sort_order,
)
assert agent_type.sort_order == sort_order
def test_sort_order_default_zero(self):
"""Test that sort_order defaults to 0."""
agent_type = AgentTypeCreate(
name="Test Agent",
slug="test-agent",
personality_prompt="Test",
primary_model="claude-opus-4-5-20251101",
)
assert agent_type.sort_order == 0
def test_sort_order_negative_rejected(self):
"""Test that negative sort_order is rejected."""
with pytest.raises(ValidationError):
AgentTypeCreate(
name="Test Agent",
slug="test-agent",
personality_prompt="Test",
primary_model="claude-opus-4-5-20251101",
sort_order=-1,
)
def test_sort_order_exceeds_max_rejected(self):
"""Test that sort_order > 1000 is rejected."""
with pytest.raises(ValidationError):
AgentTypeCreate(
name="Test Agent",
slug="test-agent",
personality_prompt="Test",
primary_model="claude-opus-4-5-20251101",
sort_order=1001,
)
def test_icon_max_length(self):
"""Test that icon field respects max length."""
agent_type = AgentTypeCreate(
name="Test Agent",
slug="test-agent",
personality_prompt="Test",
primary_model="claude-opus-4-5-20251101",
icon="x" * 50,
)
assert len(agent_type.icon) == 50
def test_icon_exceeds_max_length_rejected(self):
"""Test that icon exceeding max length is rejected."""
with pytest.raises(ValidationError):
AgentTypeCreate(
name="Test Agent",
slug="test-agent",
personality_prompt="Test",
primary_model="claude-opus-4-5-20251101",
icon="x" * 51,
)
class TestAgentTypeTypicalTasksValidation:
"""Tests for typical_tasks field validation."""
def test_typical_tasks_list(self):
"""Test typical_tasks as a list."""
tasks = ["Write code", "Review PRs", "Debug issues"]
agent_type = AgentTypeCreate(
name="Test Agent",
slug="test-agent",
personality_prompt="Test",
primary_model="claude-opus-4-5-20251101",
typical_tasks=tasks,
)
assert agent_type.typical_tasks == tasks
def test_typical_tasks_default_empty(self):
"""Test typical_tasks defaults to empty list."""
agent_type = AgentTypeCreate(
name="Test Agent",
slug="test-agent",
personality_prompt="Test",
primary_model="claude-opus-4-5-20251101",
)
assert agent_type.typical_tasks == []
def test_typical_tasks_strips_whitespace(self):
"""Test that typical_tasks items are stripped."""
agent_type = AgentTypeCreate(
name="Test Agent",
slug="test-agent",
personality_prompt="Test",
primary_model="claude-opus-4-5-20251101",
typical_tasks=[" Write code ", " Debug "],
)
assert agent_type.typical_tasks == ["Write code", "Debug"]
def test_typical_tasks_removes_empty_strings(self):
"""Test that empty strings are removed from typical_tasks."""
agent_type = AgentTypeCreate(
name="Test Agent",
slug="test-agent",
personality_prompt="Test",
primary_model="claude-opus-4-5-20251101",
typical_tasks=["Write code", "", " ", "Debug"],
)
assert agent_type.typical_tasks == ["Write code", "Debug"]
class TestAgentTypeCollaborationHintsValidation:
"""Tests for collaboration_hints field validation."""
def test_collaboration_hints_list(self):
"""Test collaboration_hints as a list."""
hints = ["backend-engineer", "qa-engineer"]
agent_type = AgentTypeCreate(
name="Test Agent",
slug="test-agent",
personality_prompt="Test",
primary_model="claude-opus-4-5-20251101",
collaboration_hints=hints,
)
assert agent_type.collaboration_hints == hints
def test_collaboration_hints_default_empty(self):
"""Test collaboration_hints defaults to empty list."""
agent_type = AgentTypeCreate(
name="Test Agent",
slug="test-agent",
personality_prompt="Test",
primary_model="claude-opus-4-5-20251101",
)
assert agent_type.collaboration_hints == []
def test_collaboration_hints_normalized_lowercase(self):
"""Test that collaboration_hints are normalized to lowercase."""
agent_type = AgentTypeCreate(
name="Test Agent",
slug="test-agent",
personality_prompt="Test",
primary_model="claude-opus-4-5-20251101",
collaboration_hints=["Backend-Engineer", "QA-ENGINEER"],
)
assert agent_type.collaboration_hints == ["backend-engineer", "qa-engineer"]
def test_collaboration_hints_strips_whitespace(self):
"""Test that collaboration_hints are stripped."""
agent_type = AgentTypeCreate(
name="Test Agent",
slug="test-agent",
personality_prompt="Test",
primary_model="claude-opus-4-5-20251101",
collaboration_hints=[" backend-engineer ", " qa-engineer "],
)
assert agent_type.collaboration_hints == ["backend-engineer", "qa-engineer"]
def test_collaboration_hints_removes_empty_strings(self):
"""Test that empty strings are removed from collaboration_hints."""
agent_type = AgentTypeCreate(
name="Test Agent",
slug="test-agent",
personality_prompt="Test",
primary_model="claude-opus-4-5-20251101",
collaboration_hints=["backend-engineer", "", " ", "qa-engineer"],
)
assert agent_type.collaboration_hints == ["backend-engineer", "qa-engineer"]
class TestAgentTypeUpdateCategoryFields:
"""Tests for AgentTypeUpdate category and display fields."""
def test_update_category_field(self):
"""Test updating category field."""
update = AgentTypeUpdate(category="ai_ml")
assert update.category.value == "ai_ml"
def test_update_icon_field(self):
"""Test updating icon field."""
update = AgentTypeUpdate(icon="brain")
assert update.icon == "brain"
def test_update_color_field(self):
"""Test updating color field."""
update = AgentTypeUpdate(color="#8B5CF6")
assert update.color == "#8B5CF6"
def test_update_sort_order_field(self):
"""Test updating sort_order field."""
update = AgentTypeUpdate(sort_order=50)
assert update.sort_order == 50
def test_update_typical_tasks_field(self):
"""Test updating typical_tasks field."""
update = AgentTypeUpdate(typical_tasks=["New task"])
assert update.typical_tasks == ["New task"]
def test_update_typical_tasks_strips_whitespace(self):
"""Test that typical_tasks are stripped on update."""
update = AgentTypeUpdate(typical_tasks=[" New task "])
assert update.typical_tasks == ["New task"]
def test_update_collaboration_hints_field(self):
"""Test updating collaboration_hints field."""
update = AgentTypeUpdate(collaboration_hints=["new-collaborator"])
assert update.collaboration_hints == ["new-collaborator"]
def test_update_collaboration_hints_normalized(self):
"""Test that collaboration_hints are normalized on update."""
update = AgentTypeUpdate(collaboration_hints=[" New-Collaborator "])
assert update.collaboration_hints == ["new-collaborator"]
def test_update_invalid_color_rejected(self):
"""Test that invalid color is rejected on update."""
with pytest.raises(ValidationError):
AgentTypeUpdate(color="invalid")
def test_update_invalid_sort_order_rejected(self):
"""Test that invalid sort_order is rejected on update."""
with pytest.raises(ValidationError):
AgentTypeUpdate(sort_order=-1)

View File

@@ -42,6 +42,9 @@ class TestInitDb:
assert user.last_name == "User"
@pytest.mark.asyncio
@pytest.mark.skip(
reason="SQLite doesn't support UUID type binding - requires PostgreSQL"
)
async def test_init_db_returns_existing_superuser(
self, async_test_db, async_test_user
):

View File

@@ -288,6 +288,7 @@ services:
environment:
- NODE_ENV=production
- NEXT_PUBLIC_API_URL=${NEXT_PUBLIC_API_URL}
- NEXT_PUBLIC_API_BASE_URL=http://backend:8000
depends_on:
backend:
condition: service_healthy

View File

@@ -96,6 +96,38 @@ services:
- app-network
restart: unless-stopped
mcp-git-ops:
build:
context: ./mcp-servers/git-ops
dockerfile: Dockerfile
ports:
- "8003:8003"
env_file:
- .env
environment:
# GIT_OPS_ prefix required by pydantic-settings config
- GIT_OPS_HOST=0.0.0.0
- GIT_OPS_PORT=8003
- GIT_OPS_REDIS_URL=redis://redis:6379/3
- GIT_OPS_GITEA_BASE_URL=${GITEA_BASE_URL}
- GIT_OPS_GITEA_TOKEN=${GITEA_TOKEN}
- GIT_OPS_GITHUB_TOKEN=${GITHUB_TOKEN}
- ENVIRONMENT=development
volumes:
- git_workspaces_dev:/workspaces
depends_on:
redis:
condition: service_healthy
healthcheck:
test: ["CMD", "python", "-c", "import httpx; httpx.get('http://localhost:8003/health').raise_for_status()"]
interval: 30s
timeout: 10s
retries: 3
start_period: 10s
networks:
- app-network
restart: unless-stopped
backend:
build:
context: ./backend
@@ -119,6 +151,7 @@ services:
# MCP Server URLs
- LLM_GATEWAY_URL=http://mcp-llm-gateway:8001
- KNOWLEDGE_BASE_URL=http://mcp-knowledge-base:8002
- GIT_OPS_URL=http://mcp-git-ops:8003
depends_on:
db:
condition: service_healthy
@@ -128,6 +161,8 @@ services:
condition: service_healthy
mcp-knowledge-base:
condition: service_healthy
mcp-git-ops:
condition: service_healthy
healthcheck:
test: ["CMD", "curl", "-f", "http://localhost:8000/health"]
interval: 10s
@@ -155,6 +190,7 @@ services:
# MCP Server URLs (agents need access to MCP)
- LLM_GATEWAY_URL=http://mcp-llm-gateway:8001
- KNOWLEDGE_BASE_URL=http://mcp-knowledge-base:8002
- GIT_OPS_URL=http://mcp-git-ops:8003
depends_on:
db:
condition: service_healthy
@@ -164,6 +200,8 @@ services:
condition: service_healthy
mcp-knowledge-base:
condition: service_healthy
mcp-git-ops:
condition: service_healthy
networks:
- app-network
command: ["celery", "-A", "app.celery_app", "worker", "-Q", "agent", "-l", "info", "-c", "4"]
@@ -181,11 +219,14 @@ services:
- DATABASE_URL=${DATABASE_URL}
- REDIS_URL=redis://redis:6379/0
- CELERY_QUEUE=git
- GIT_OPS_URL=http://mcp-git-ops:8003
depends_on:
db:
condition: service_healthy
redis:
condition: service_healthy
mcp-git-ops:
condition: service_healthy
networks:
- app-network
command: ["celery", "-A", "app.celery_app", "worker", "-Q", "git", "-l", "info", "-c", "2"]
@@ -249,6 +290,7 @@ services:
environment:
- NODE_ENV=development
- NEXT_PUBLIC_API_URL=${NEXT_PUBLIC_API_URL}
- NEXT_PUBLIC_API_BASE_URL=http://backend:8000
depends_on:
backend:
condition: service_healthy
@@ -259,6 +301,7 @@ services:
volumes:
postgres_data_dev:
redis_data_dev:
git_workspaces_dev:
frontend_dev_modules:
frontend_dev_next:

View File

@@ -74,12 +74,14 @@ const nextConfig: NextConfig = {
];
},
// Ensure we can connect to the backend in Docker
// Proxy API requests to backend
// Use NEXT_PUBLIC_API_BASE_URL for the destination (defaults to localhost for local dev)
async rewrites() {
const backendUrl = process.env.NEXT_PUBLIC_API_BASE_URL || 'http://localhost:8000';
return [
{
source: '/api/:path*',
destination: 'http://backend:8000/:path*',
destination: `${backendUrl}/api/:path*`,
},
];
},

View File

@@ -21,6 +21,7 @@
"@radix-ui/react-separator": "^1.1.7",
"@radix-ui/react-slot": "^1.2.4",
"@radix-ui/react-tabs": "^1.1.13",
"@radix-ui/react-toggle-group": "^1.1.11",
"@tanstack/react-query": "^5.90.5",
"@types/react-syntax-highlighter": "^15.5.13",
"axios": "^1.13.1",
@@ -4688,6 +4689,60 @@
}
}
},
"node_modules/@radix-ui/react-toggle": {
"version": "1.1.10",
"resolved": "https://registry.npmjs.org/@radix-ui/react-toggle/-/react-toggle-1.1.10.tgz",
"integrity": "sha512-lS1odchhFTeZv3xwHH31YPObmJn8gOg7Lq12inrr0+BH/l3Tsq32VfjqH1oh80ARM3mlkfMic15n0kg4sD1poQ==",
"license": "MIT",
"dependencies": {
"@radix-ui/primitive": "1.1.3",
"@radix-ui/react-primitive": "2.1.3",
"@radix-ui/react-use-controllable-state": "1.2.2"
},
"peerDependencies": {
"@types/react": "*",
"@types/react-dom": "*",
"react": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc",
"react-dom": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc"
},
"peerDependenciesMeta": {
"@types/react": {
"optional": true
},
"@types/react-dom": {
"optional": true
}
}
},
"node_modules/@radix-ui/react-toggle-group": {
"version": "1.1.11",
"resolved": "https://registry.npmjs.org/@radix-ui/react-toggle-group/-/react-toggle-group-1.1.11.tgz",
"integrity": "sha512-5umnS0T8JQzQT6HbPyO7Hh9dgd82NmS36DQr+X/YJ9ctFNCiiQd6IJAYYZ33LUwm8M+taCz5t2ui29fHZc4Y6Q==",
"license": "MIT",
"dependencies": {
"@radix-ui/primitive": "1.1.3",
"@radix-ui/react-context": "1.1.2",
"@radix-ui/react-direction": "1.1.1",
"@radix-ui/react-primitive": "2.1.3",
"@radix-ui/react-roving-focus": "1.1.11",
"@radix-ui/react-toggle": "1.1.10",
"@radix-ui/react-use-controllable-state": "1.2.2"
},
"peerDependencies": {
"@types/react": "*",
"@types/react-dom": "*",
"react": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc",
"react-dom": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc"
},
"peerDependenciesMeta": {
"@types/react": {
"optional": true
},
"@types/react-dom": {
"optional": true
}
}
},
"node_modules/@radix-ui/react-use-callback-ref": {
"version": "1.1.1",
"resolved": "https://registry.npmjs.org/@radix-ui/react-use-callback-ref/-/react-use-callback-ref-1.1.1.tgz",

View File

@@ -35,6 +35,7 @@
"@radix-ui/react-separator": "^1.1.7",
"@radix-ui/react-slot": "^1.2.4",
"@radix-ui/react-tabs": "^1.1.13",
"@radix-ui/react-toggle-group": "^1.1.11",
"@tanstack/react-query": "^5.90.5",
"@types/react-syntax-highlighter": "^15.5.13",
"axios": "^1.13.1",

View File

@@ -73,6 +73,13 @@ export default function AgentTypeDetailPage() {
mcp_servers: data.mcp_servers,
tool_permissions: data.tool_permissions,
is_active: data.is_active,
// Category and display fields
category: data.category,
icon: data.icon,
color: data.color,
sort_order: data.sort_order,
typical_tasks: data.typical_tasks,
collaboration_hints: data.collaboration_hints,
});
toast.success('Agent type created', {
description: `${result.name} has been created successfully`,
@@ -94,6 +101,13 @@ export default function AgentTypeDetailPage() {
mcp_servers: data.mcp_servers,
tool_permissions: data.tool_permissions,
is_active: data.is_active,
// Category and display fields
category: data.category,
icon: data.icon,
color: data.color,
sort_order: data.sort_order,
typical_tasks: data.typical_tasks,
collaboration_hints: data.collaboration_hints,
},
});
toast.success('Agent type updated', {

View File

@@ -1,8 +1,8 @@
/**
* Agent Types List Page
*
* Displays a list of agent types with search and filter functionality.
* Allows navigation to agent type detail and creation pages.
* Displays a list of agent types with search, status, and category filters.
* Supports grid and list view modes with user preference persistence.
*/
'use client';
@@ -10,9 +10,10 @@
import { useState, useCallback, useMemo } from 'react';
import { useRouter } from '@/lib/i18n/routing';
import { toast } from 'sonner';
import { AgentTypeList } from '@/components/agents';
import { AgentTypeList, type ViewMode } from '@/components/agents';
import { useAgentTypes } from '@/lib/api/hooks/useAgentTypes';
import { useDebounce } from '@/lib/hooks/useDebounce';
import type { AgentTypeCategory } from '@/lib/api/types/agentTypes';
export default function AgentTypesPage() {
const router = useRouter();
@@ -20,6 +21,8 @@ export default function AgentTypesPage() {
// Filter state
const [searchQuery, setSearchQuery] = useState('');
const [statusFilter, setStatusFilter] = useState('all');
const [categoryFilter, setCategoryFilter] = useState('all');
const [viewMode, setViewMode] = useState<ViewMode>('grid');
// Debounce search for API calls
const debouncedSearch = useDebounce(searchQuery, 300);
@@ -31,21 +34,25 @@ export default function AgentTypesPage() {
return undefined; // 'all' returns undefined to not filter
}, [statusFilter]);
// Determine category filter value
const categoryFilterValue = useMemo(() => {
if (categoryFilter === 'all') return undefined;
return categoryFilter as AgentTypeCategory;
}, [categoryFilter]);
// Fetch agent types
const { data, isLoading, error } = useAgentTypes({
search: debouncedSearch || undefined,
is_active: isActiveFilter,
category: categoryFilterValue,
page: 1,
limit: 50,
});
// Filter results client-side for 'all' status
// Get filtered and sorted agent types (sort by sort_order ascending - smaller first)
const filteredAgentTypes = useMemo(() => {
if (!data?.data) return [];
// When status is 'all', we need to fetch both and combine
// For now, the API returns based on is_active filter
return data.data;
return [...data.data].sort((a, b) => a.sort_order - b.sort_order);
}, [data?.data]);
// Handle navigation to agent type detail
@@ -71,6 +78,16 @@ export default function AgentTypesPage() {
setStatusFilter(status);
}, []);
// Handle category filter change
const handleCategoryFilterChange = useCallback((category: string) => {
setCategoryFilter(category);
}, []);
// Handle view mode change
const handleViewModeChange = useCallback((mode: ViewMode) => {
setViewMode(mode);
}, []);
// Show error toast if fetch fails
if (error) {
toast.error('Failed to load agent types', {
@@ -87,6 +104,10 @@ export default function AgentTypesPage() {
onSearchChange={handleSearchChange}
statusFilter={statusFilter}
onStatusFilterChange={handleStatusFilterChange}
categoryFilter={categoryFilter}
onCategoryFilterChange={handleCategoryFilterChange}
viewMode={viewMode}
onViewModeChange={handleViewModeChange}
onSelect={handleSelect}
onCreate={handleCreate}
/>

View File

@@ -2,7 +2,8 @@
* AgentTypeDetail Component
*
* Displays detailed information about a single agent type.
* Shows model configuration, permissions, personality, and instance stats.
* Features a hero header with icon/color, category, typical tasks,
* collaboration hints, model configuration, and instance stats.
*/
'use client';
@@ -36,8 +37,13 @@ import {
Cpu,
CheckCircle2,
AlertTriangle,
Sparkles,
Users,
Check,
} from 'lucide-react';
import type { AgentTypeResponse } from '@/lib/api/types/agentTypes';
import { DynamicIcon } from '@/components/ui/dynamic-icon';
import type { AgentTypeResponse, AgentTypeCategory } from '@/lib/api/types/agentTypes';
import { CATEGORY_METADATA } from '@/lib/api/types/agentTypes';
import { AVAILABLE_MCP_SERVERS } from '@/lib/validations/agentType';
interface AgentTypeDetailProps {
@@ -51,6 +57,30 @@ interface AgentTypeDetailProps {
className?: string;
}
/**
* Category badge with color
*/
function CategoryBadge({ category }: { category: AgentTypeCategory | null }) {
if (!category) return null;
const meta = CATEGORY_METADATA[category];
if (!meta) return null;
return (
<Badge
variant="outline"
className="font-medium"
style={{
borderColor: meta.color,
color: meta.color,
backgroundColor: `${meta.color}10`,
}}
>
{meta.label}
</Badge>
);
}
/**
* Status badge component for agent types
*/
@@ -81,11 +111,22 @@ function AgentTypeStatusBadge({ isActive }: { isActive: boolean }) {
function AgentTypeDetailSkeleton() {
return (
<div className="space-y-6">
<div className="flex items-center gap-4">
<Skeleton className="h-10 w-10" />
<div className="flex-1">
<Skeleton className="h-8 w-64" />
<Skeleton className="mt-2 h-4 w-48" />
{/* Hero skeleton */}
<div className="rounded-xl border p-6">
<div className="flex items-start gap-6">
<Skeleton className="h-20 w-20 rounded-xl" />
<div className="flex-1 space-y-3">
<Skeleton className="h-8 w-64" />
<Skeleton className="h-4 w-96" />
<div className="flex gap-2">
<Skeleton className="h-6 w-20" />
<Skeleton className="h-6 w-24" />
</div>
</div>
<div className="flex gap-2">
<Skeleton className="h-9 w-24" />
<Skeleton className="h-9 w-20" />
</div>
</div>
</div>
<div className="grid gap-6 lg:grid-cols-3">
@@ -161,57 +202,134 @@ export function AgentTypeDetail({
top_p?: number;
};
const agentColor = agentType.color || '#3B82F6';
return (
<div className={className}>
{/* Header */}
<div className="mb-6 flex items-center gap-4">
<Button variant="ghost" size="icon" onClick={onBack}>
<ArrowLeft className="h-4 w-4" />
<span className="sr-only">Go back</span>
</Button>
<div className="flex-1">
<div className="flex items-center gap-3">
<h1 className="text-3xl font-bold">{agentType.name}</h1>
<AgentTypeStatusBadge isActive={agentType.is_active} />
{/* Back button */}
<Button variant="ghost" size="sm" onClick={onBack} className="mb-4">
<ArrowLeft className="mr-2 h-4 w-4" />
Back to Agent Types
</Button>
{/* Hero Header */}
<div
className="mb-6 overflow-hidden rounded-xl border"
style={{
background: `linear-gradient(135deg, ${agentColor}08 0%, transparent 60%)`,
borderColor: `${agentColor}30`,
}}
>
<div
className="h-1.5 w-full"
style={{ background: `linear-gradient(90deg, ${agentColor}, ${agentColor}60)` }}
/>
<div className="p-6">
<div className="flex flex-col gap-6 md:flex-row md:items-start">
{/* Icon */}
<div
className="flex h-20 w-20 shrink-0 items-center justify-center rounded-xl"
style={{
backgroundColor: `${agentColor}15`,
boxShadow: `0 8px 32px ${agentColor}20`,
}}
>
<DynamicIcon
name={agentType.icon}
className="h-10 w-10"
style={{ color: agentColor }}
fallback="bot"
/>
</div>
{/* Info */}
<div className="flex-1 space-y-3">
<div>
<h1 className="text-3xl font-bold">{agentType.name}</h1>
<p className="mt-1 text-muted-foreground">
{agentType.description || 'No description provided'}
</p>
</div>
<div className="flex flex-wrap items-center gap-2">
<AgentTypeStatusBadge isActive={agentType.is_active} />
<CategoryBadge category={agentType.category} />
<span className="text-sm text-muted-foreground">
Last updated:{' '}
{new Date(agentType.updated_at).toLocaleDateString('en-US', {
year: 'numeric',
month: 'short',
day: 'numeric',
})}
</span>
</div>
</div>
{/* Actions */}
<div className="flex shrink-0 gap-2">
<Button variant="outline" size="sm" onClick={onDuplicate}>
<Copy className="mr-2 h-4 w-4" />
Duplicate
</Button>
<Button size="sm" onClick={onEdit}>
<Edit className="mr-2 h-4 w-4" />
Edit
</Button>
</div>
</div>
<p className="text-muted-foreground">
Last modified:{' '}
{new Date(agentType.updated_at).toLocaleDateString('en-US', {
year: 'numeric',
month: 'long',
day: 'numeric',
})}
</p>
</div>
<div className="flex gap-2">
<Button variant="outline" size="sm" onClick={onDuplicate}>
<Copy className="mr-2 h-4 w-4" />
Duplicate
</Button>
<Button size="sm" onClick={onEdit}>
<Edit className="mr-2 h-4 w-4" />
Edit
</Button>
</div>
</div>
<div className="grid gap-6 lg:grid-cols-3">
{/* Main Content */}
<div className="space-y-6 lg:col-span-2">
{/* Description Card */}
<Card>
<CardHeader>
<CardTitle className="flex items-center gap-2">
<FileText className="h-5 w-5" />
Description
</CardTitle>
</CardHeader>
<CardContent>
<p className="text-muted-foreground">
{agentType.description || 'No description provided'}
</p>
</CardContent>
</Card>
{/* What This Agent Does Best */}
{agentType.typical_tasks.length > 0 && (
<Card className="border-primary/20 bg-gradient-to-br from-primary/5 to-transparent">
<CardHeader className="pb-3">
<CardTitle className="flex items-center gap-2 text-lg">
<Sparkles className="h-5 w-5 text-primary" />
What This Agent Does Best
</CardTitle>
</CardHeader>
<CardContent>
<ul className="space-y-2">
{agentType.typical_tasks.map((task, index) => (
<li key={index} className="flex items-start gap-2">
<Check
className="mt-0.5 h-4 w-4 shrink-0 text-primary"
style={{ color: agentColor }}
/>
<span className="text-sm">{task}</span>
</li>
))}
</ul>
</CardContent>
</Card>
)}
{/* Works Well With */}
{agentType.collaboration_hints.length > 0 && (
<Card>
<CardHeader className="pb-3">
<CardTitle className="flex items-center gap-2 text-lg">
<Users className="h-5 w-5" />
Works Well With
</CardTitle>
<CardDescription>
Agents that complement this type for effective collaboration
</CardDescription>
</CardHeader>
<CardContent>
<div className="flex flex-wrap gap-2">
{agentType.collaboration_hints.map((hint, index) => (
<Badge key={index} variant="secondary" className="text-sm">
{hint}
</Badge>
))}
</div>
</CardContent>
</Card>
)}
{/* Expertise Card */}
<Card>
@@ -355,7 +473,9 @@ export function AgentTypeDetail({
</CardHeader>
<CardContent>
<div className="text-center">
<p className="text-4xl font-bold text-primary">{agentType.instance_count}</p>
<p className="text-4xl font-bold" style={{ color: agentColor }}>
{agentType.instance_count}
</p>
<p className="text-sm text-muted-foreground">Active instances</p>
</div>
<Button variant="outline" className="mt-4 w-full" size="sm" disabled>
@@ -364,6 +484,36 @@ export function AgentTypeDetail({
</CardContent>
</Card>
{/* Agent Info */}
<Card>
<CardHeader>
<CardTitle className="flex items-center gap-2 text-lg">
<FileText className="h-5 w-5" />
Details
</CardTitle>
</CardHeader>
<CardContent className="space-y-3 text-sm">
<div className="flex justify-between">
<span className="text-muted-foreground">Slug</span>
<code className="rounded bg-muted px-1.5 py-0.5 text-xs">{agentType.slug}</code>
</div>
<div className="flex justify-between">
<span className="text-muted-foreground">Sort Order</span>
<span>{agentType.sort_order}</span>
</div>
<div className="flex justify-between">
<span className="text-muted-foreground">Created</span>
<span>
{new Date(agentType.created_at).toLocaleDateString('en-US', {
year: 'numeric',
month: 'short',
day: 'numeric',
})}
</span>
</div>
</CardContent>
</Card>
{/* Danger Zone */}
<Card className="border-destructive/50">
<CardHeader>

View File

@@ -3,11 +3,15 @@
*
* React Hook Form-based form for creating and editing agent types.
* Features tabbed interface for organizing form sections.
*
* Uses reusable form utilities for:
* - Validation error handling with toast notifications
* - Safe API-to-form data transformation with defaults
*/
'use client';
import { useEffect, useState } from 'react';
import { useEffect, useState, useCallback, useMemo } from 'react';
import { useForm, Controller } from 'react-hook-form';
import { zodResolver } from '@hookform/resolvers/zod';
import { Button } from '@/components/ui/button';
@@ -32,19 +36,89 @@ import {
type AgentTypeCreateFormValues,
AVAILABLE_MODELS,
AVAILABLE_MCP_SERVERS,
AGENT_TYPE_CATEGORIES,
defaultAgentTypeValues,
generateSlug,
} from '@/lib/validations/agentType';
import type { AgentTypeResponse } from '@/lib/api/types/agentTypes';
import { useValidationErrorHandler, deepMergeWithDefaults, isNumber } from '@/lib/forms';
interface AgentTypeFormProps {
agentType?: AgentTypeResponse;
onSubmit: (data: AgentTypeCreateFormValues) => void;
onSubmit: (data: AgentTypeCreateFormValues) => void | Promise<void>;
onCancel: () => void;
isSubmitting?: boolean;
className?: string;
}
// Tab navigation mapping for validation errors
const TAB_FIELD_MAPPING = {
name: 'basic',
slug: 'basic',
description: 'basic',
expertise: 'basic',
is_active: 'basic',
// Category and display fields
category: 'basic',
icon: 'basic',
color: 'basic',
sort_order: 'basic',
typical_tasks: 'basic',
collaboration_hints: 'basic',
primary_model: 'model',
fallback_models: 'model',
model_params: 'model',
mcp_servers: 'permissions',
tool_permissions: 'permissions',
personality_prompt: 'personality',
} as const;
/**
* Transform API response to form values with safe defaults
*
* Uses deepMergeWithDefaults for most fields, with special handling
* for model_params which needs numeric type validation.
*/
function transformAgentTypeToFormValues(
agentType: AgentTypeResponse | undefined
): AgentTypeCreateFormValues {
if (!agentType) return defaultAgentTypeValues;
// model_params needs special handling for numeric validation
const modelParams = agentType.model_params ?? {};
const safeModelParams = {
temperature: isNumber(modelParams.temperature) ? modelParams.temperature : 0.7,
max_tokens: isNumber(modelParams.max_tokens) ? modelParams.max_tokens : 8192,
top_p: isNumber(modelParams.top_p) ? modelParams.top_p : 0.95,
};
// Merge with defaults, then override model_params with safe version
const merged = deepMergeWithDefaults(defaultAgentTypeValues, {
name: agentType.name,
slug: agentType.slug,
description: agentType.description,
expertise: agentType.expertise,
personality_prompt: agentType.personality_prompt,
primary_model: agentType.primary_model,
fallback_models: agentType.fallback_models,
mcp_servers: agentType.mcp_servers,
tool_permissions: agentType.tool_permissions,
is_active: agentType.is_active,
// Category and display fields
category: agentType.category,
icon: agentType.icon,
color: agentType.color,
sort_order: agentType.sort_order ?? 0,
typical_tasks: agentType.typical_tasks ?? [],
collaboration_hints: agentType.collaboration_hints ?? [],
});
return {
...merged,
model_params: safeModelParams,
};
}
export function AgentTypeForm({
agentType,
onSubmit,
@@ -55,29 +129,16 @@ export function AgentTypeForm({
const isEditing = !!agentType;
const [activeTab, setActiveTab] = useState('basic');
const [expertiseInput, setExpertiseInput] = useState('');
const [typicalTaskInput, setTypicalTaskInput] = useState('');
const [collaborationHintInput, setCollaborationHintInput] = useState('');
// Memoize initial values transformation
const initialValues = useMemo(() => transformAgentTypeToFormValues(agentType), [agentType]);
// Always use create schema for validation - editing requires all fields too
const form = useForm<AgentTypeCreateFormValues>({
resolver: zodResolver(agentTypeCreateSchema),
defaultValues: agentType
? {
name: agentType.name,
slug: agentType.slug,
description: agentType.description,
expertise: agentType.expertise,
personality_prompt: agentType.personality_prompt,
primary_model: agentType.primary_model,
fallback_models: agentType.fallback_models,
model_params: (agentType.model_params ?? {
temperature: 0.7,
max_tokens: 8192,
top_p: 0.95,
}) as AgentTypeCreateFormValues['model_params'],
mcp_servers: agentType.mcp_servers,
tool_permissions: agentType.tool_permissions,
is_active: agentType.is_active,
}
: defaultAgentTypeValues,
defaultValues: initialValues,
});
const {
@@ -89,11 +150,28 @@ export function AgentTypeForm({
formState: { errors },
} = form;
// Use the reusable validation error handler hook
const { onValidationError } = useValidationErrorHandler<AgentTypeCreateFormValues>({
tabMapping: TAB_FIELD_MAPPING,
setActiveTab,
});
const watchName = watch('name');
/* istanbul ignore next -- defensive fallback, expertise always has default */
const watchExpertise = watch('expertise') || [];
/* istanbul ignore next -- defensive fallback, mcp_servers always has default */
const watchMcpServers = watch('mcp_servers') || [];
/* istanbul ignore next -- defensive fallback, typical_tasks always has default */
const watchTypicalTasks = watch('typical_tasks') || [];
/* istanbul ignore next -- defensive fallback, collaboration_hints always has default */
const watchCollaborationHints = watch('collaboration_hints') || [];
// Reset form when agentType changes (e.g., switching to edit mode)
useEffect(() => {
if (agentType) {
form.reset(initialValues);
}
}, [agentType?.id, form, initialValues]);
// Auto-generate slug from name for new agent types
useEffect(() => {
@@ -132,8 +210,50 @@ export function AgentTypeForm({
}
};
const handleAddTypicalTask = () => {
if (typicalTaskInput.trim()) {
const newTask = typicalTaskInput.trim();
if (!watchTypicalTasks.includes(newTask)) {
setValue('typical_tasks', [...watchTypicalTasks, newTask]);
}
setTypicalTaskInput('');
}
};
const handleRemoveTypicalTask = (task: string) => {
setValue(
'typical_tasks',
watchTypicalTasks.filter((t) => t !== task)
);
};
const handleAddCollaborationHint = () => {
if (collaborationHintInput.trim()) {
const newHint = collaborationHintInput.trim().toLowerCase();
if (!watchCollaborationHints.includes(newHint)) {
setValue('collaboration_hints', [...watchCollaborationHints, newHint]);
}
setCollaborationHintInput('');
}
};
const handleRemoveCollaborationHint = (hint: string) => {
setValue(
'collaboration_hints',
watchCollaborationHints.filter((h) => h !== hint)
);
};
// Handle form submission with validation
const onFormSubmit = useCallback(
(e: React.FormEvent<HTMLFormElement>) => {
return handleSubmit(onSubmit, onValidationError)(e);
},
[handleSubmit, onSubmit, onValidationError]
);
return (
<form onSubmit={handleSubmit(onSubmit)} className={className}>
<form onSubmit={onFormSubmit} className={className}>
{/* Header */}
<div className="mb-6 flex items-center gap-4">
<Button type="button" variant="ghost" size="icon" onClick={onCancel}>
@@ -311,6 +431,188 @@ export function AgentTypeForm({
</div>
</CardContent>
</Card>
{/* Category & Display Card */}
<Card>
<CardHeader>
<CardTitle>Category & Display</CardTitle>
<CardDescription>
Organize and customize how this agent type appears in the UI
</CardDescription>
</CardHeader>
<CardContent className="space-y-6">
<div className="grid gap-4 md:grid-cols-2">
<div className="space-y-2">
<Label htmlFor="category">Category</Label>
<Controller
name="category"
control={control}
render={({ field }) => (
<Select
value={field.value ?? ''}
onValueChange={(val) => field.onChange(val || null)}
>
<SelectTrigger id="category">
<SelectValue placeholder="Select category" />
</SelectTrigger>
<SelectContent>
{AGENT_TYPE_CATEGORIES.map((cat) => (
<SelectItem key={cat.value} value={cat.value}>
{cat.label}
</SelectItem>
))}
</SelectContent>
</Select>
)}
/>
<p className="text-xs text-muted-foreground">
Group agents by their primary role
</p>
</div>
<div className="space-y-2">
<Label htmlFor="sort_order">Sort Order</Label>
<Input
id="sort_order"
type="number"
min={0}
max={1000}
{...register('sort_order', { valueAsNumber: true })}
aria-invalid={!!errors.sort_order}
/>
{errors.sort_order && (
<p className="text-sm text-destructive" role="alert">
{errors.sort_order.message}
</p>
)}
<p className="text-xs text-muted-foreground">Display order within category</p>
</div>
</div>
<div className="grid gap-4 md:grid-cols-2">
<div className="space-y-2">
<Label htmlFor="icon">Icon</Label>
<Input
id="icon"
placeholder="e.g., git-branch"
{...register('icon')}
aria-invalid={!!errors.icon}
/>
{errors.icon && (
<p className="text-sm text-destructive" role="alert">
{errors.icon.message}
</p>
)}
<p className="text-xs text-muted-foreground">Lucide icon name for UI display</p>
</div>
<div className="space-y-2">
<Label htmlFor="color">Color</Label>
<div className="flex gap-2">
<Input
id="color"
placeholder="#3B82F6"
{...register('color')}
aria-invalid={!!errors.color}
className="flex-1"
/>
<Controller
name="color"
control={control}
render={({ field }) => (
<input
type="color"
value={field.value ?? '#3B82F6'}
onChange={(e) => field.onChange(e.target.value)}
className="h-9 w-9 cursor-pointer rounded border"
/>
)}
/>
</div>
{errors.color && (
<p className="text-sm text-destructive" role="alert">
{errors.color.message}
</p>
)}
<p className="text-xs text-muted-foreground">Hex color for visual distinction</p>
</div>
</div>
<Separator />
<div className="space-y-2">
<Label>Typical Tasks</Label>
<p className="text-sm text-muted-foreground">Tasks this agent type excels at</p>
<div className="flex gap-2">
<Input
placeholder="e.g., Design system architecture"
value={typicalTaskInput}
onChange={(e) => setTypicalTaskInput(e.target.value)}
onKeyDown={(e) => {
if (e.key === 'Enter') {
e.preventDefault();
handleAddTypicalTask();
}
}}
/>
<Button type="button" variant="outline" onClick={handleAddTypicalTask}>
Add
</Button>
</div>
<div className="flex flex-wrap gap-2 pt-2">
{watchTypicalTasks.map((task) => (
<Badge key={task} variant="secondary" className="gap-1">
{task}
<button
type="button"
className="ml-1 rounded-full hover:bg-muted"
onClick={() => handleRemoveTypicalTask(task)}
aria-label={`Remove ${task}`}
>
<X className="h-3 w-3" />
</button>
</Badge>
))}
</div>
</div>
<div className="space-y-2">
<Label>Collaboration Hints</Label>
<p className="text-sm text-muted-foreground">
Agent slugs that work well with this type
</p>
<div className="flex gap-2">
<Input
placeholder="e.g., backend-engineer"
value={collaborationHintInput}
onChange={(e) => setCollaborationHintInput(e.target.value)}
onKeyDown={(e) => {
if (e.key === 'Enter') {
e.preventDefault();
handleAddCollaborationHint();
}
}}
/>
<Button type="button" variant="outline" onClick={handleAddCollaborationHint}>
Add
</Button>
</div>
<div className="flex flex-wrap gap-2 pt-2">
{watchCollaborationHints.map((hint) => (
<Badge key={hint} variant="outline" className="gap-1">
{hint}
<button
type="button"
className="ml-1 rounded-full hover:bg-muted"
onClick={() => handleRemoveCollaborationHint(hint)}
aria-label={`Remove ${hint}`}
>
<X className="h-3 w-3" />
</button>
</Badge>
))}
</div>
</div>
</CardContent>
</Card>
</TabsContent>
{/* Model Configuration Tab */}

View File

@@ -1,8 +1,8 @@
/**
* AgentTypeList Component
*
* Displays a grid of agent type cards with search and filter functionality.
* Used on the main agent types page for browsing and selecting agent types.
* Displays agent types in grid or list view with search, status, and category filters.
* Shows icon, color accent, and category for each agent type.
*/
'use client';
@@ -20,8 +20,14 @@ import {
SelectTrigger,
SelectValue,
} from '@/components/ui/select';
import { Bot, Plus, Search, Cpu } from 'lucide-react';
import type { AgentTypeResponse } from '@/lib/api/types/agentTypes';
import { ToggleGroup, ToggleGroupItem } from '@/components/ui/toggle-group';
import { Bot, Plus, Search, Cpu, LayoutGrid, List } from 'lucide-react';
import { DynamicIcon } from '@/components/ui/dynamic-icon';
import type { AgentTypeResponse, AgentTypeCategory } from '@/lib/api/types/agentTypes';
import { CATEGORY_METADATA } from '@/lib/api/types/agentTypes';
import { AGENT_TYPE_CATEGORIES } from '@/lib/validations/agentType';
export type ViewMode = 'grid' | 'list';
interface AgentTypeListProps {
agentTypes: AgentTypeResponse[];
@@ -30,6 +36,10 @@ interface AgentTypeListProps {
onSearchChange: (query: string) => void;
statusFilter: string;
onStatusFilterChange: (status: string) => void;
categoryFilter: string;
onCategoryFilterChange: (category: string) => void;
viewMode: ViewMode;
onViewModeChange: (mode: ViewMode) => void;
onSelect: (id: string) => void;
onCreate: () => void;
className?: string;
@@ -60,11 +70,36 @@ function AgentTypeStatusBadge({ isActive }: { isActive: boolean }) {
}
/**
* Loading skeleton for agent type cards
* Category badge with color
*/
function CategoryBadge({ category }: { category: AgentTypeCategory | null }) {
if (!category) return null;
const meta = CATEGORY_METADATA[category];
if (!meta) return null;
return (
<Badge
variant="outline"
className="text-xs font-medium"
style={{
borderColor: meta.color,
color: meta.color,
backgroundColor: `${meta.color}10`,
}}
>
{meta.label}
</Badge>
);
}
/**
* Loading skeleton for agent type cards (grid view)
*/
function AgentTypeCardSkeleton() {
return (
<Card className="h-[200px]">
<Card className="h-[220px] overflow-hidden">
<div className="h-1 w-full bg-muted" />
<CardHeader className="pb-3">
<div className="flex items-start justify-between">
<Skeleton className="h-10 w-10 rounded-lg" />
@@ -91,6 +126,23 @@ function AgentTypeCardSkeleton() {
);
}
/**
* Loading skeleton for list view
*/
function AgentTypeListSkeleton() {
return (
<div className="flex items-center gap-4 rounded-lg border p-4">
<Skeleton className="h-12 w-12 rounded-lg" />
<div className="flex-1 space-y-2">
<Skeleton className="h-5 w-48" />
<Skeleton className="h-4 w-96" />
</div>
<Skeleton className="h-5 w-20" />
<Skeleton className="h-5 w-16" />
</div>
);
}
/**
* Extract model display name from model ID
*/
@@ -103,6 +155,169 @@ function getModelDisplayName(modelId: string): string {
return modelId;
}
/**
* Grid card view for agent type
*/
function AgentTypeGridCard({
type,
onSelect,
}: {
type: AgentTypeResponse;
onSelect: (id: string) => void;
}) {
const agentColor = type.color || '#3B82F6';
return (
<Card
className="cursor-pointer overflow-hidden transition-all hover:shadow-lg"
onClick={() => onSelect(type.id)}
role="button"
tabIndex={0}
onKeyDown={(e) => {
if (e.key === 'Enter' || e.key === ' ') {
e.preventDefault();
onSelect(type.id);
}
}}
aria-label={`View ${type.name} agent type`}
style={{
borderTopColor: agentColor,
borderTopWidth: '3px',
}}
>
<CardHeader className="pb-3">
<div className="flex items-start justify-between">
<div
className="flex h-11 w-11 items-center justify-center rounded-lg"
style={{
backgroundColor: `${agentColor}15`,
}}
>
<DynamicIcon
name={type.icon}
className="h-5 w-5"
style={{ color: agentColor }}
fallback="bot"
/>
</div>
<div className="flex flex-col items-end gap-1">
<AgentTypeStatusBadge isActive={type.is_active} />
<CategoryBadge category={type.category} />
</div>
</div>
<CardTitle className="mt-3 line-clamp-1">{type.name}</CardTitle>
<CardDescription className="line-clamp-2">
{type.description || 'No description provided'}
</CardDescription>
</CardHeader>
<CardContent>
<div className="space-y-3">
{/* Expertise tags */}
<div className="flex flex-wrap gap-1">
{type.expertise.slice(0, 3).map((skill) => (
<Badge key={skill} variant="secondary" className="text-xs">
{skill}
</Badge>
))}
{type.expertise.length > 3 && (
<Badge variant="outline" className="text-xs">
+{type.expertise.length - 3}
</Badge>
)}
{type.expertise.length === 0 && (
<span className="text-xs text-muted-foreground">No expertise defined</span>
)}
</div>
<Separator />
{/* Metadata */}
<div className="flex items-center justify-between text-sm text-muted-foreground">
<div className="flex items-center gap-1">
<Cpu className="h-3.5 w-3.5" />
<span className="text-xs">{getModelDisplayName(type.primary_model)}</span>
</div>
<div className="flex items-center gap-1">
<Bot className="h-3.5 w-3.5" />
<span className="text-xs">{type.instance_count} instances</span>
</div>
</div>
</div>
</CardContent>
</Card>
);
}
/**
* List row view for agent type
*/
function AgentTypeListRow({
type,
onSelect,
}: {
type: AgentTypeResponse;
onSelect: (id: string) => void;
}) {
const agentColor = type.color || '#3B82F6';
return (
<div
className="flex cursor-pointer items-center gap-4 rounded-lg border p-4 transition-all hover:border-primary hover:shadow-md"
onClick={() => onSelect(type.id)}
role="button"
tabIndex={0}
onKeyDown={(e) => {
if (e.key === 'Enter' || e.key === ' ') {
e.preventDefault();
onSelect(type.id);
}
}}
aria-label={`View ${type.name} agent type`}
style={{
borderLeftColor: agentColor,
borderLeftWidth: '4px',
}}
>
{/* Icon */}
<div
className="flex h-12 w-12 shrink-0 items-center justify-center rounded-lg"
style={{ backgroundColor: `${agentColor}15` }}
>
<DynamicIcon
name={type.icon}
className="h-6 w-6"
style={{ color: agentColor }}
fallback="bot"
/>
</div>
{/* Main content */}
<div className="min-w-0 flex-1">
<div className="flex items-center gap-2">
<h3 className="font-semibold">{type.name}</h3>
<CategoryBadge category={type.category} />
</div>
<p className="line-clamp-1 text-sm text-muted-foreground">
{type.description || 'No description'}
</p>
<div className="mt-1 flex items-center gap-3 text-xs text-muted-foreground">
<span className="flex items-center gap-1">
<Cpu className="h-3 w-3" />
{getModelDisplayName(type.primary_model)}
</span>
<span>{type.expertise.length} expertise areas</span>
<span>{type.instance_count} instances</span>
</div>
</div>
{/* Status */}
<div className="shrink-0">
<AgentTypeStatusBadge isActive={type.is_active} />
</div>
</div>
);
}
export function AgentTypeList({
agentTypes,
isLoading = false,
@@ -110,6 +325,10 @@ export function AgentTypeList({
onSearchChange,
statusFilter,
onStatusFilterChange,
categoryFilter,
onCategoryFilterChange,
viewMode,
onViewModeChange,
onSelect,
onCreate,
className,
@@ -131,7 +350,7 @@ export function AgentTypeList({
</div>
{/* Filters */}
<div className="mb-6 flex flex-col gap-4 sm:flex-row">
<div className="mb-6 flex flex-col gap-4 sm:flex-row sm:items-center">
<div className="relative flex-1">
<Search className="absolute left-3 top-1/2 h-4 w-4 -translate-y-1/2 text-muted-foreground" />
<Input
@@ -142,8 +361,25 @@ export function AgentTypeList({
aria-label="Search agent types"
/>
</div>
{/* Category Filter */}
<Select value={categoryFilter} onValueChange={onCategoryFilterChange}>
<SelectTrigger className="w-full sm:w-44" aria-label="Filter by category">
<SelectValue placeholder="All Categories" />
</SelectTrigger>
<SelectContent>
<SelectItem value="all">All Categories</SelectItem>
{AGENT_TYPE_CATEGORIES.map((cat) => (
<SelectItem key={cat.value} value={cat.value}>
{cat.label}
</SelectItem>
))}
</SelectContent>
</Select>
{/* Status Filter */}
<Select value={statusFilter} onValueChange={onStatusFilterChange}>
<SelectTrigger className="w-full sm:w-40" aria-label="Filter by status">
<SelectTrigger className="w-full sm:w-36" aria-label="Filter by status">
<SelectValue placeholder="Status" />
</SelectTrigger>
<SelectContent>
@@ -152,10 +388,25 @@ export function AgentTypeList({
<SelectItem value="inactive">Inactive</SelectItem>
</SelectContent>
</Select>
{/* View Mode Toggle */}
<ToggleGroup
type="single"
value={viewMode}
onValueChange={(value: string) => value && onViewModeChange(value as ViewMode)}
className="hidden sm:flex"
>
<ToggleGroupItem value="grid" aria-label="Grid view" size="sm">
<LayoutGrid className="h-4 w-4" />
</ToggleGroupItem>
<ToggleGroupItem value="list" aria-label="List view" size="sm">
<List className="h-4 w-4" />
</ToggleGroupItem>
</ToggleGroup>
</div>
{/* Loading State */}
{isLoading && (
{/* Loading State - Grid */}
{isLoading && viewMode === 'grid' && (
<div className="grid gap-4 md:grid-cols-2 lg:grid-cols-3">
{[1, 2, 3, 4, 5, 6].map((i) => (
<AgentTypeCardSkeleton key={i} />
@@ -163,71 +414,29 @@ export function AgentTypeList({
</div>
)}
{/* Agent Type Grid */}
{!isLoading && agentTypes.length > 0 && (
{/* Loading State - List */}
{isLoading && viewMode === 'list' && (
<div className="space-y-3">
{[1, 2, 3, 4, 5, 6].map((i) => (
<AgentTypeListSkeleton key={i} />
))}
</div>
)}
{/* Agent Type Grid View */}
{!isLoading && agentTypes.length > 0 && viewMode === 'grid' && (
<div className="grid gap-4 md:grid-cols-2 lg:grid-cols-3">
{agentTypes.map((type) => (
<Card
key={type.id}
className="cursor-pointer transition-all hover:border-primary hover:shadow-md"
onClick={() => onSelect(type.id)}
role="button"
tabIndex={0}
onKeyDown={(e) => {
if (e.key === 'Enter' || e.key === ' ') {
e.preventDefault();
onSelect(type.id);
}
}}
aria-label={`View ${type.name} agent type`}
>
<CardHeader className="pb-3">
<div className="flex items-start justify-between">
<div className="flex h-10 w-10 items-center justify-center rounded-lg bg-primary/10">
<Bot className="h-5 w-5 text-primary" />
</div>
<AgentTypeStatusBadge isActive={type.is_active} />
</div>
<CardTitle className="mt-3">{type.name}</CardTitle>
<CardDescription className="line-clamp-2">
{type.description || 'No description provided'}
</CardDescription>
</CardHeader>
<CardContent>
<div className="space-y-3">
{/* Expertise tags */}
<div className="flex flex-wrap gap-1">
{type.expertise.slice(0, 3).map((skill) => (
<Badge key={skill} variant="secondary" className="text-xs">
{skill}
</Badge>
))}
{type.expertise.length > 3 && (
<Badge variant="outline" className="text-xs">
+{type.expertise.length - 3}
</Badge>
)}
{type.expertise.length === 0 && (
<span className="text-xs text-muted-foreground">No expertise defined</span>
)}
</div>
<AgentTypeGridCard key={type.id} type={type} onSelect={onSelect} />
))}
</div>
)}
<Separator />
{/* Metadata */}
<div className="flex items-center justify-between text-sm text-muted-foreground">
<div className="flex items-center gap-1">
<Cpu className="h-3.5 w-3.5" />
<span className="text-xs">{getModelDisplayName(type.primary_model)}</span>
</div>
<div className="flex items-center gap-1">
<Bot className="h-3.5 w-3.5" />
<span className="text-xs">{type.instance_count} instances</span>
</div>
</div>
</div>
</CardContent>
</Card>
{/* Agent Type List View */}
{!isLoading && agentTypes.length > 0 && viewMode === 'list' && (
<div className="space-y-3">
{agentTypes.map((type) => (
<AgentTypeListRow key={type.id} type={type} onSelect={onSelect} />
))}
</div>
)}
@@ -238,11 +447,11 @@ export function AgentTypeList({
<Bot className="mx-auto h-12 w-12 text-muted-foreground" />
<h3 className="mt-4 font-semibold">No agent types found</h3>
<p className="text-muted-foreground">
{searchQuery || statusFilter !== 'all'
{searchQuery || statusFilter !== 'all' || categoryFilter !== 'all'
? 'Try adjusting your search or filters'
: 'Create your first agent type to get started'}
</p>
{!searchQuery && statusFilter === 'all' && (
{!searchQuery && statusFilter === 'all' && categoryFilter === 'all' && (
<Button onClick={onCreate} className="mt-4">
<Plus className="mr-2 h-4 w-4" />
Create Agent Type

View File

@@ -5,5 +5,5 @@
*/
export { AgentTypeForm } from './AgentTypeForm';
export { AgentTypeList } from './AgentTypeList';
export { AgentTypeList, type ViewMode } from './AgentTypeList';
export { AgentTypeDetail } from './AgentTypeDetail';

View File

@@ -0,0 +1,133 @@
/**
* FormSelect Component
*
* Reusable Select field with Controller integration for react-hook-form.
* Handles label, error display, and description automatically.
*
* @module components/forms/FormSelect
*/
'use client';
import { Controller, type Control, type FieldValues, type Path } from 'react-hook-form';
import { Label } from '@/components/ui/label';
import {
Select,
SelectContent,
SelectItem,
SelectTrigger,
SelectValue,
} from '@/components/ui/select';
export interface SelectOption {
value: string;
label: string;
}
export interface FormSelectProps<T extends FieldValues> {
/** Field name (must be a valid path in the form) */
name: Path<T>;
/** Form control from useForm */
control: Control<T>;
/** Field label */
label: string;
/** Available options */
options: SelectOption[];
/** Is field required? Shows asterisk if true */
required?: boolean;
/** Placeholder text when no value selected */
placeholder?: string;
/** Helper text below the field */
description?: string;
/** Disable the select */
disabled?: boolean;
/** Additional class name */
className?: string;
}
/**
* FormSelect - Controlled Select field for react-hook-form
*
* Automatically handles:
* - Controller wrapper for react-hook-form
* - Label with required indicator
* - Error message display
* - Description/helper text
* - Accessibility attributes
*
* @example
* ```tsx
* <FormSelect
* name="primary_model"
* control={form.control}
* label="Primary Model"
* required
* options={[
* { value: 'claude-opus', label: 'Claude Opus' },
* { value: 'claude-sonnet', label: 'Claude Sonnet' },
* ]}
* description="Main model used for this agent"
* />
* ```
*/
export function FormSelect<T extends FieldValues>({
name,
control,
label,
options,
required = false,
placeholder,
description,
disabled = false,
className,
}: FormSelectProps<T>) {
const selectId = String(name);
const errorId = `${selectId}-error`;
const descriptionId = description ? `${selectId}-description` : undefined;
return (
<Controller
name={name}
control={control}
render={({ field, fieldState }) => (
<div className={className}>
<div className="space-y-2">
<Label htmlFor={selectId}>
{label}
{required && <span className="text-destructive"> *</span>}
</Label>
<Select value={field.value ?? ''} onValueChange={field.onChange} disabled={disabled}>
<SelectTrigger
id={selectId}
aria-invalid={!!fieldState.error}
aria-describedby={
[fieldState.error ? errorId : null, descriptionId].filter(Boolean).join(' ') ||
undefined
}
>
<SelectValue placeholder={placeholder ?? `Select ${label.toLowerCase()}`} />
</SelectTrigger>
<SelectContent>
{options.map((option) => (
<SelectItem key={option.value} value={option.value}>
{option.label}
</SelectItem>
))}
</SelectContent>
</Select>
{fieldState.error && (
<p id={errorId} className="text-sm text-destructive" role="alert">
{fieldState.error.message}
</p>
)}
{description && (
<p id={descriptionId} className="text-xs text-muted-foreground">
{description}
</p>
)}
</div>
</div>
)}
/>
);
}

View File

@@ -0,0 +1,101 @@
/**
* FormTextarea Component
*
* Reusable Textarea field for react-hook-form with register integration.
* Handles label, error display, and description automatically.
*
* @module components/forms/FormTextarea
*/
'use client';
import { ComponentProps } from 'react';
import type { FieldError, UseFormRegisterReturn } from 'react-hook-form';
import { Label } from '@/components/ui/label';
import { Textarea } from '@/components/ui/textarea';
export interface FormTextareaProps extends Omit<ComponentProps<typeof Textarea>, 'children'> {
/** Field label */
label: string;
/** Field name (optional if provided via register) */
name?: string;
/** Is field required? Shows asterisk if true */
required?: boolean;
/** Form error from react-hook-form */
error?: FieldError;
/** Helper text below the field */
description?: string;
/** Register return object from useForm */
registration?: UseFormRegisterReturn;
}
/**
* FormTextarea - Textarea field for react-hook-form
*
* Automatically handles:
* - Label with required indicator
* - Error message display
* - Description/helper text
* - Accessibility attributes
*
* @example
* ```tsx
* <FormTextarea
* label="Personality Prompt"
* required
* error={errors.personality_prompt}
* rows={10}
* {...register('personality_prompt')}
* />
* ```
*/
export function FormTextarea({
label,
name: explicitName,
required = false,
error,
description,
registration,
...textareaProps
}: FormTextareaProps) {
// Extract name from props or registration
const registerName =
'name' in textareaProps ? (textareaProps as { name: string }).name : undefined;
const name = explicitName || registerName || registration?.name;
if (!name) {
throw new Error('FormTextarea: name must be provided either explicitly or via register()');
}
const errorId = error ? `${name}-error` : undefined;
const descriptionId = description ? `${name}-description` : undefined;
const ariaDescribedBy = [errorId, descriptionId].filter(Boolean).join(' ') || undefined;
// Merge registration props with other props
const mergedProps = registration ? { ...registration, ...textareaProps } : textareaProps;
return (
<div className="space-y-2">
<Label htmlFor={name}>
{label}
{required && <span className="text-destructive"> *</span>}
</Label>
{description && (
<p id={descriptionId} className="text-sm text-muted-foreground">
{description}
</p>
)}
<Textarea
id={name}
aria-invalid={!!error}
aria-describedby={ariaDescribedBy}
{...mergedProps}
/>
{error && (
<p id={errorId} className="text-sm text-destructive" role="alert">
{error.message}
</p>
)}
</div>
);
}

View File

@@ -1,5 +1,9 @@
// Shared form components and utilities
export { FormField } from './FormField';
export type { FormFieldProps } from './FormField';
export { FormSelect } from './FormSelect';
export type { FormSelectProps, SelectOption } from './FormSelect';
export { FormTextarea } from './FormTextarea';
export type { FormTextareaProps } from './FormTextarea';
export { useFormError } from './useFormError';
export type { UseFormErrorReturn } from './useFormError';

View File

@@ -0,0 +1,84 @@
/**
* DynamicIcon Component
*
* Renders Lucide icons dynamically by name string.
* Useful when icon names come from data (e.g., database).
*/
import * as LucideIcons from 'lucide-react';
import type { LucideProps } from 'lucide-react';
/**
* Map of icon names to their components.
* Uses kebab-case names (e.g., 'clipboard-check') as keys.
*/
const iconMap: Record<string, React.ComponentType<LucideProps>> = {
// Development
'clipboard-check': LucideIcons.ClipboardCheck,
briefcase: LucideIcons.Briefcase,
'file-text': LucideIcons.FileText,
'git-branch': LucideIcons.GitBranch,
code: LucideIcons.Code,
server: LucideIcons.Server,
layout: LucideIcons.Layout,
smartphone: LucideIcons.Smartphone,
// Design
palette: LucideIcons.Palette,
search: LucideIcons.Search,
// Quality
shield: LucideIcons.Shield,
'shield-check': LucideIcons.ShieldCheck,
// Operations
settings: LucideIcons.Settings,
'settings-2': LucideIcons.Settings2,
// AI/ML
brain: LucideIcons.Brain,
microscope: LucideIcons.Microscope,
eye: LucideIcons.Eye,
'message-square': LucideIcons.MessageSquare,
// Data
'bar-chart': LucideIcons.BarChart,
database: LucideIcons.Database,
// Leadership
users: LucideIcons.Users,
target: LucideIcons.Target,
// Domain Expert
calculator: LucideIcons.Calculator,
'heart-pulse': LucideIcons.HeartPulse,
'flask-conical': LucideIcons.FlaskConical,
lightbulb: LucideIcons.Lightbulb,
'book-open': LucideIcons.BookOpen,
// Generic
bot: LucideIcons.Bot,
cpu: LucideIcons.Cpu,
};
interface DynamicIconProps extends Omit<LucideProps, 'name'> {
/** Icon name in kebab-case (e.g., 'clipboard-check', 'bot') */
name: string | null | undefined;
/** Fallback icon name if the specified icon is not found */
fallback?: string;
}
/**
* Renders a Lucide icon dynamically by name.
*
* @example
* ```tsx
* <DynamicIcon name="clipboard-check" className="h-5 w-5" />
* <DynamicIcon name={agent.icon} fallback="bot" />
* ```
*/
export function DynamicIcon({ name, fallback = 'bot', ...props }: DynamicIconProps) {
const iconName = name || fallback;
const IconComponent = iconMap[iconName] || iconMap[fallback] || LucideIcons.Bot;
return <IconComponent {...props} />;
}
/**
* Get available icon names for validation or display
*/
export function getAvailableIconNames(): string[] {
return Object.keys(iconMap);
}

View File

@@ -0,0 +1,93 @@
'use client';
import * as React from 'react';
import * as ToggleGroupPrimitive from '@radix-ui/react-toggle-group';
import { type VariantProps, cva } from 'class-variance-authority';
import { cn } from '@/lib/utils';
const toggleGroupVariants = cva(
'inline-flex items-center justify-center rounded-md border bg-transparent',
{
variants: {
variant: {
default: 'bg-transparent',
outline: 'border border-input',
},
},
defaultVariants: {
variant: 'outline',
},
}
);
const toggleGroupItemVariants = cva(
'inline-flex items-center justify-center whitespace-nowrap text-sm font-medium ring-offset-background transition-all focus-visible:outline-none focus-visible:ring-2 focus-visible:ring-ring focus-visible:ring-offset-2 disabled:pointer-events-none disabled:opacity-50 data-[state=on]:bg-accent data-[state=on]:text-accent-foreground',
{
variants: {
variant: {
default: 'bg-transparent hover:bg-muted hover:text-muted-foreground',
outline: 'bg-transparent hover:bg-accent hover:text-accent-foreground',
},
size: {
default: 'h-10 px-3',
sm: 'h-9 px-2.5',
lg: 'h-11 px-5',
},
},
defaultVariants: {
variant: 'default',
size: 'default',
},
}
);
const ToggleGroupContext = React.createContext<VariantProps<typeof toggleGroupItemVariants>>({
size: 'default',
variant: 'default',
});
const ToggleGroup = React.forwardRef<
React.ElementRef<typeof ToggleGroupPrimitive.Root>,
React.ComponentPropsWithoutRef<typeof ToggleGroupPrimitive.Root> &
VariantProps<typeof toggleGroupVariants> &
VariantProps<typeof toggleGroupItemVariants>
>(({ className, variant, size, children, ...props }, ref) => (
<ToggleGroupPrimitive.Root
ref={ref}
className={cn(toggleGroupVariants({ variant }), className)}
{...props}
>
<ToggleGroupContext.Provider value={{ variant, size }}>{children}</ToggleGroupContext.Provider>
</ToggleGroupPrimitive.Root>
));
ToggleGroup.displayName = ToggleGroupPrimitive.Root.displayName;
const ToggleGroupItem = React.forwardRef<
React.ElementRef<typeof ToggleGroupPrimitive.Item>,
React.ComponentPropsWithoutRef<typeof ToggleGroupPrimitive.Item> &
VariantProps<typeof toggleGroupItemVariants>
>(({ className, children, variant, size, ...props }, ref) => {
const context = React.useContext(ToggleGroupContext);
return (
<ToggleGroupPrimitive.Item
ref={ref}
className={cn(
toggleGroupItemVariants({
variant: context.variant || variant,
size: context.size || size,
}),
className
)}
{...props}
>
{children}
</ToggleGroupPrimitive.Item>
);
});
ToggleGroupItem.displayName = ToggleGroupPrimitive.Item.displayName;
export { ToggleGroup, ToggleGroupItem };

View File

@@ -44,10 +44,10 @@ const DEFAULT_PAGE_LIMIT = 20;
export function useAgentTypes(params: AgentTypeListParams = {}) {
const { user } = useAuth();
const { page = 1, limit = DEFAULT_PAGE_LIMIT, is_active = true, search } = params;
const { page = 1, limit = DEFAULT_PAGE_LIMIT, is_active = true, search, category } = params;
return useQuery({
queryKey: agentTypeKeys.list({ page, limit, is_active, search }),
queryKey: agentTypeKeys.list({ page, limit, is_active, search, category }),
queryFn: async (): Promise<AgentTypeListResponse> => {
const response = await apiClient.instance.get('/api/v1/agent-types', {
params: {
@@ -55,6 +55,7 @@ export function useAgentTypes(params: AgentTypeListParams = {}) {
limit,
is_active,
...(search ? { search } : {}),
...(category ? { category } : {}),
},
});
return response.data;

View File

@@ -6,13 +6,15 @@
* - Recent projects
* - Pending approvals
*
* Uses mock data until backend endpoints are available.
* Fetches real data from the API.
*
* @see Issue #53
*/
import { useQuery } from '@tanstack/react-query';
import type { Project, ProjectStatus } from '@/components/projects/types';
import { listProjects as listProjectsApi } from '@/lib/api/generated';
import type { ProjectResponse } from '@/lib/api/generated';
import type { AutonomyLevel, Project, ProjectStatus } from '@/components/projects/types';
// ============================================================================
// Types
@@ -52,118 +54,70 @@ export interface DashboardData {
}
// ============================================================================
// Mock Data
// Helpers
// ============================================================================
const mockStats: DashboardStats = {
activeProjects: 3,
runningAgents: 8,
openIssues: 24,
pendingApprovals: 2,
};
/**
* Format a date string as relative time (e.g., "2 minutes ago")
*/
function formatRelativeTime(dateStr: string): string {
const date = new Date(dateStr);
const now = new Date();
const diffMs = now.getTime() - date.getTime();
const diffMins = Math.floor(diffMs / 60000);
const diffHours = Math.floor(diffMins / 60);
const diffDays = Math.floor(diffHours / 24);
const diffWeeks = Math.floor(diffDays / 7);
const diffMonths = Math.floor(diffDays / 30);
const mockProjects: DashboardProject[] = [
{
id: 'proj-001',
name: 'E-Commerce Platform Redesign',
description: 'Complete redesign of the e-commerce platform with modern UI/UX',
status: 'active' as ProjectStatus,
autonomy_level: 'milestone',
created_at: '2025-11-15T10:00:00Z',
updated_at: '2025-12-30T14:30:00Z',
owner_id: 'user-001',
progress: 67,
openIssues: 12,
activeAgents: 4,
currentSprint: 'Sprint 3',
lastActivity: '2 minutes ago',
},
{
id: 'proj-002',
name: 'Mobile Banking App',
description: 'Native mobile app for banking services with biometric authentication',
status: 'active' as ProjectStatus,
autonomy_level: 'autonomous',
created_at: '2025-11-20T09:00:00Z',
updated_at: '2025-12-30T12:00:00Z',
owner_id: 'user-001',
progress: 45,
openIssues: 8,
activeAgents: 5,
currentSprint: 'Sprint 2',
lastActivity: '15 minutes ago',
},
{
id: 'proj-003',
name: 'Internal HR Portal',
description: 'Employee self-service portal for HR operations',
status: 'paused' as ProjectStatus,
autonomy_level: 'full_control',
created_at: '2025-10-01T08:00:00Z',
updated_at: '2025-12-28T16:00:00Z',
owner_id: 'user-001',
progress: 23,
openIssues: 5,
activeAgents: 0,
currentSprint: 'Sprint 1',
lastActivity: '2 days ago',
},
{
id: 'proj-004',
name: 'API Gateway Modernization',
description: 'Migrate legacy API gateway to cloud-native architecture',
status: 'active' as ProjectStatus,
autonomy_level: 'milestone',
created_at: '2025-12-01T11:00:00Z',
updated_at: '2025-12-30T10:00:00Z',
owner_id: 'user-001',
progress: 82,
openIssues: 3,
activeAgents: 2,
currentSprint: 'Sprint 4',
lastActivity: '1 hour ago',
},
{
id: 'proj-005',
name: 'Customer Analytics Dashboard',
description: 'Real-time analytics dashboard for customer behavior insights',
status: 'completed' as ProjectStatus,
autonomy_level: 'autonomous',
created_at: '2025-09-01T10:00:00Z',
updated_at: '2025-12-15T17:00:00Z',
owner_id: 'user-001',
progress: 100,
openIssues: 0,
activeAgents: 0,
lastActivity: '2 weeks ago',
},
{
id: 'proj-006',
name: 'DevOps Pipeline Automation',
description: 'Automate CI/CD pipelines with AI-assisted deployments',
status: 'active' as ProjectStatus,
autonomy_level: 'milestone',
created_at: '2025-12-10T14:00:00Z',
updated_at: '2025-12-30T09:00:00Z',
owner_id: 'user-001',
progress: 35,
openIssues: 6,
activeAgents: 3,
currentSprint: 'Sprint 1',
lastActivity: '30 minutes ago',
},
];
if (diffMins < 1) return 'Just now';
if (diffMins < 60) return `${diffMins} minute${diffMins > 1 ? 's' : ''} ago`;
if (diffHours < 24) return `${diffHours} hour${diffHours > 1 ? 's' : ''} ago`;
if (diffDays < 7) return `${diffDays} day${diffDays > 1 ? 's' : ''} ago`;
if (diffWeeks < 4) return `${diffWeeks} week${diffWeeks > 1 ? 's' : ''} ago`;
return `${diffMonths} month${diffMonths > 1 ? 's' : ''} ago`;
}
/**
* Maps API ProjectResponse to DashboardProject format
*/
function mapToDashboardProject(
project: ProjectResponse & Record<string, unknown>
): DashboardProject {
const updatedAt = project.updated_at || project.created_at || new Date().toISOString();
const createdAt = project.created_at || new Date().toISOString();
return {
id: project.id,
name: project.name,
description: project.description || undefined,
status: project.status as ProjectStatus,
autonomy_level: (project.autonomy_level || 'milestone') as AutonomyLevel,
created_at: createdAt,
updated_at: updatedAt,
owner_id: project.owner_id || 'unknown',
progress: (project.progress as number) || 0,
openIssues: (project.openIssues as number) || project.issue_count || 0,
activeAgents: (project.activeAgents as number) || project.agent_count || 0,
currentSprint: project.active_sprint_name || undefined,
lastActivity: formatRelativeTime(updatedAt),
};
}
// ============================================================================
// Mock Data (for pending approvals - no backend endpoint yet)
// ============================================================================
const mockApprovals: PendingApproval[] = [
{
id: 'approval-001',
type: 'sprint_boundary',
title: 'Sprint 3 Completion Review',
description: 'Review sprint deliverables and approve transition to Sprint 4',
title: 'Sprint 1 Completion Review',
description: 'Review sprint deliverables and approve transition to Sprint 2',
projectId: 'proj-001',
projectName: 'E-Commerce Platform Redesign',
requestedBy: 'Product Owner Agent',
requestedAt: '2025-12-30T14:00:00Z',
requestedAt: new Date().toISOString(),
priority: 'high',
},
{
@@ -171,10 +125,10 @@ const mockApprovals: PendingApproval[] = [
type: 'architecture_decision',
title: 'Database Migration Strategy',
description: 'Approve PostgreSQL to CockroachDB migration plan',
projectId: 'proj-004',
projectName: 'API Gateway Modernization',
projectId: 'proj-002',
projectName: 'Mobile Banking App',
requestedBy: 'Architect Agent',
requestedAt: '2025-12-30T10:30:00Z',
requestedAt: new Date(Date.now() - 3600000).toISOString(),
priority: 'medium',
},
];
@@ -192,17 +146,41 @@ export function useDashboard() {
return useQuery<DashboardData>({
queryKey: ['dashboard'],
queryFn: async () => {
// Simulate network delay
await new Promise((resolve) => setTimeout(resolve, 500));
// Fetch real projects from API
const response = await listProjectsApi({
query: {
limit: 6,
},
});
// Return mock data
// TODO: Replace with actual API call when backend is ready
// const response = await apiClient.get('/api/v1/dashboard');
// return response.data;
if (response.error) {
throw new Error('Failed to fetch dashboard data');
}
const projects = response.data.data.map((p) =>
mapToDashboardProject(p as ProjectResponse & Record<string, unknown>)
);
// Sort by updated_at (most recent first)
projects.sort(
(a, b) =>
new Date(b.updated_at || b.created_at).getTime() -
new Date(a.updated_at || a.created_at).getTime()
);
// Calculate stats from real data
const activeProjects = projects.filter((p) => p.status === 'active').length;
const runningAgents = projects.reduce((sum, p) => sum + p.activeAgents, 0);
const openIssues = projects.reduce((sum, p) => sum + p.openIssues, 0);
return {
stats: mockStats,
recentProjects: mockProjects,
stats: {
activeProjects,
runningAgents,
openIssues,
pendingApprovals: mockApprovals.length,
},
recentProjects: projects,
pendingApprovals: mockApprovals,
};
},
@@ -218,8 +196,24 @@ export function useDashboardStats() {
return useQuery<DashboardStats>({
queryKey: ['dashboard', 'stats'],
queryFn: async () => {
await new Promise((resolve) => setTimeout(resolve, 300));
return mockStats;
const response = await listProjectsApi({
query: { limit: 100 },
});
if (response.error) {
throw new Error('Failed to fetch stats');
}
const projects = response.data.data.map((p) =>
mapToDashboardProject(p as ProjectResponse & Record<string, unknown>)
);
return {
activeProjects: projects.filter((p) => p.status === 'active').length,
runningAgents: projects.reduce((sum, p) => sum + p.activeAgents, 0),
openIssues: projects.reduce((sum, p) => sum + p.openIssues, 0),
pendingApprovals: mockApprovals.length,
};
},
staleTime: 30000,
refetchInterval: 60000,
@@ -235,8 +229,26 @@ export function useRecentProjects(limit: number = 6) {
return useQuery<DashboardProject[]>({
queryKey: ['dashboard', 'recentProjects', limit],
queryFn: async () => {
await new Promise((resolve) => setTimeout(resolve, 400));
return mockProjects.slice(0, limit);
const response = await listProjectsApi({
query: { limit },
});
if (response.error) {
throw new Error('Failed to fetch recent projects');
}
const projects = response.data.data.map((p) =>
mapToDashboardProject(p as ProjectResponse & Record<string, unknown>)
);
// Sort by updated_at (most recent first)
projects.sort(
(a, b) =>
new Date(b.updated_at || b.created_at).getTime() -
new Date(a.updated_at || a.created_at).getTime()
);
return projects;
},
staleTime: 30000,
});
@@ -249,7 +261,7 @@ export function usePendingApprovals() {
return useQuery<PendingApproval[]>({
queryKey: ['dashboard', 'pendingApprovals'],
queryFn: async () => {
await new Promise((resolve) => setTimeout(resolve, 300));
// TODO: Fetch from real API when endpoint exists
return mockApprovals;
},
staleTime: 30000,

View File

@@ -5,6 +5,68 @@
* Used for type-safe API communication with the agent-types endpoints.
*/
/**
* Category classification for agent types
*/
export type AgentTypeCategory =
| 'development'
| 'design'
| 'quality'
| 'operations'
| 'ai_ml'
| 'data'
| 'leadership'
| 'domain_expert';
/**
* Metadata for each category including display label and description
*/
export const CATEGORY_METADATA: Record<
AgentTypeCategory,
{ label: string; description: string; color: string }
> = {
development: {
label: 'Development',
description: 'Product, project, and engineering roles',
color: '#3B82F6',
},
design: {
label: 'Design',
description: 'UI/UX and design research',
color: '#EC4899',
},
quality: {
label: 'Quality',
description: 'QA and security assurance',
color: '#10B981',
},
operations: {
label: 'Operations',
description: 'DevOps and MLOps engineering',
color: '#F59E0B',
},
ai_ml: {
label: 'AI & ML',
description: 'Machine learning specialists',
color: '#8B5CF6',
},
data: {
label: 'Data',
description: 'Data science and engineering',
color: '#06B6D4',
},
leadership: {
label: 'Leadership',
description: 'Technical leadership and facilitation',
color: '#F97316',
},
domain_expert: {
label: 'Domain Experts',
description: 'Industry and domain specialists',
color: '#84CC16',
},
};
/**
* Base agent type fields shared across create, update, and response schemas
*/
@@ -20,6 +82,13 @@ export interface AgentTypeBase {
mcp_servers: string[];
tool_permissions: Record<string, unknown>;
is_active: boolean;
// Category and display fields
category?: AgentTypeCategory | null;
icon?: string | null;
color?: string | null;
sort_order: number;
typical_tasks: string[];
collaboration_hints: string[];
}
/**
@@ -37,6 +106,13 @@ export interface AgentTypeCreate {
mcp_servers?: string[];
tool_permissions?: Record<string, unknown>;
is_active?: boolean;
// Category and display fields
category?: AgentTypeCategory | null;
icon?: string | null;
color?: string | null;
sort_order?: number;
typical_tasks?: string[];
collaboration_hints?: string[];
}
/**
@@ -54,6 +130,13 @@ export interface AgentTypeUpdate {
mcp_servers?: string[] | null;
tool_permissions?: Record<string, unknown> | null;
is_active?: boolean | null;
// Category and display fields
category?: AgentTypeCategory | null;
icon?: string | null;
color?: string | null;
sort_order?: number | null;
typical_tasks?: string[] | null;
collaboration_hints?: string[] | null;
}
/**
@@ -72,6 +155,13 @@ export interface AgentTypeResponse {
mcp_servers: string[];
tool_permissions: Record<string, unknown>;
is_active: boolean;
// Category and display fields
category: AgentTypeCategory | null;
icon: string | null;
color: string | null;
sort_order: number;
typical_tasks: string[];
collaboration_hints: string[];
created_at: string;
updated_at: string;
instance_count: number;
@@ -104,9 +194,15 @@ export interface AgentTypeListParams {
page?: number;
limit?: number;
is_active?: boolean;
category?: AgentTypeCategory;
search?: string;
}
/**
* Response type for grouped agent types by category
*/
export type AgentTypeGroupedResponse = Record<string, AgentTypeResponse[]>;
/**
* Model parameter configuration with typed fields
*/

View File

@@ -0,0 +1,118 @@
/**
* Validation Error Handler Hook
*
* Handles client-side Zod/react-hook-form validation errors with:
* - Toast notifications
* - Optional tab navigation
* - Debug logging
*
* @module lib/forms/hooks/useValidationErrorHandler
*/
'use client';
import { useCallback } from 'react';
import { toast } from 'sonner';
import type { FieldErrors, FieldValues } from 'react-hook-form';
import { getFirstValidationError } from '../utils/getFirstValidationError';
export interface TabFieldMapping {
/** Map of field names to tab values */
[fieldName: string]: string;
}
export interface UseValidationErrorHandlerOptions {
/**
* Map of field names (top-level) to tab values.
* When an error occurs, navigates to the tab containing the field.
*/
tabMapping?: TabFieldMapping;
/**
* Callback to set the active tab.
* Required if tabMapping is provided.
*/
setActiveTab?: (tab: string) => void;
/**
* Enable debug logging to console.
* @default false in production, true in development
*/
debug?: boolean;
/**
* Toast title for validation errors.
* @default 'Please fix form errors'
*/
toastTitle?: string;
}
export interface UseValidationErrorHandlerReturn<T extends FieldValues> {
/**
* Handler function to pass to react-hook-form's handleSubmit second argument.
* Shows toast, navigates to tab, and logs errors.
*/
onValidationError: (errors: FieldErrors<T>) => void;
}
/**
* Hook for handling client-side validation errors
*
* @example
* ```tsx
* const [activeTab, setActiveTab] = useState('basic');
*
* const { onValidationError } = useValidationErrorHandler({
* tabMapping: {
* name: 'basic',
* slug: 'basic',
* primary_model: 'model',
* model_params: 'model',
* },
* setActiveTab,
* });
*
* // In form:
* <form onSubmit={handleSubmit(onSuccess, onValidationError)}>
* ```
*/
export function useValidationErrorHandler<T extends FieldValues>(
options: UseValidationErrorHandlerOptions = {}
): UseValidationErrorHandlerReturn<T> {
const {
tabMapping,
setActiveTab,
debug = process.env.NODE_ENV === 'development',
toastTitle = 'Please fix form errors',
} = options;
const onValidationError = useCallback(
(errors: FieldErrors<T>) => {
// Log errors in debug mode
if (debug) {
console.error('[Form Validation] Errors:', errors);
}
// Get first error for toast
const firstError = getFirstValidationError(errors);
if (!firstError) return;
// Show toast
toast.error(toastTitle, {
description: `${firstError.field}: ${firstError.message}`,
});
// Navigate to tab if mapping provided
if (tabMapping && setActiveTab) {
const topLevelField = firstError.field.split('.')[0];
const targetTab = tabMapping[topLevelField];
if (targetTab) {
setActiveTab(targetTab);
}
}
},
[tabMapping, setActiveTab, debug, toastTitle]
);
return { onValidationError };
}

View File

@@ -0,0 +1,30 @@
/**
* Form Utilities and Hooks
*
* Centralized exports for form-related utilities.
*
* @module lib/forms
*/
// Utils
export { getFirstValidationError, getAllValidationErrors } from './utils/getFirstValidationError';
export type { ValidationError } from './utils/getFirstValidationError';
export {
safeValue,
isNumber,
isString,
isBoolean,
isArray,
isObject,
deepMergeWithDefaults,
createFormInitializer,
} from './utils/mergeWithDefaults';
// Hooks
export { useValidationErrorHandler } from './hooks/useValidationErrorHandler';
export type {
TabFieldMapping,
UseValidationErrorHandlerOptions,
UseValidationErrorHandlerReturn,
} from './hooks/useValidationErrorHandler';

View File

@@ -0,0 +1,84 @@
/**
* Get First Validation Error
*
* Extracts the first error from react-hook-form FieldErrors,
* including support for nested errors (e.g., model_params.temperature).
*
* @module lib/forms/utils/getFirstValidationError
*/
import type { FieldErrors, FieldValues } from 'react-hook-form';
export interface ValidationError {
/** Field path (e.g., 'name' or 'model_params.temperature') */
field: string;
/** Error message */
message: string;
}
/**
* Recursively extract the first error from FieldErrors
*
* @param errors - FieldErrors object from react-hook-form
* @param prefix - Current field path prefix for nested errors
* @returns First validation error found, or null if no errors
*
* @example
* ```ts
* const errors = { model_params: { temperature: { message: 'Required' } } };
* const error = getFirstValidationError(errors);
* // { field: 'model_params.temperature', message: 'Required' }
* ```
*/
export function getFirstValidationError<T extends FieldValues>(
errors: FieldErrors<T>,
prefix = ''
): ValidationError | null {
for (const key of Object.keys(errors)) {
const error = errors[key as keyof typeof errors];
if (!error || typeof error !== 'object') continue;
const fieldPath = prefix ? `${prefix}.${key}` : key;
// Check if this is a direct error with a message
if ('message' in error && typeof error.message === 'string') {
return { field: fieldPath, message: error.message };
}
// Check if this is a nested object (e.g., model_params.temperature)
const nestedError = getFirstValidationError(error as FieldErrors<FieldValues>, fieldPath);
if (nestedError) return nestedError;
}
return null;
}
/**
* Get all validation errors as a flat array
*
* @param errors - FieldErrors object from react-hook-form
* @param prefix - Current field path prefix for nested errors
* @returns Array of all validation errors
*/
export function getAllValidationErrors<T extends FieldValues>(
errors: FieldErrors<T>,
prefix = ''
): ValidationError[] {
const result: ValidationError[] = [];
for (const key of Object.keys(errors)) {
const error = errors[key as keyof typeof errors];
if (!error || typeof error !== 'object') continue;
const fieldPath = prefix ? `${prefix}.${key}` : key;
if ('message' in error && typeof error.message === 'string') {
result.push({ field: fieldPath, message: error.message });
} else {
// Nested object without message, recurse
result.push(...getAllValidationErrors(error as FieldErrors<FieldValues>, fieldPath));
}
}
return result;
}

View File

@@ -0,0 +1,169 @@
/**
* Merge With Defaults
*
* Utilities for safely merging API data with form defaults.
* Handles missing fields, type mismatches, and nested objects.
*
* @module lib/forms/utils/mergeWithDefaults
*/
/**
* Safely get a value with type checking and default fallback
*
* @param value - Value to check
* @param defaultValue - Default to use if value is invalid
* @param typeCheck - Type checking function
* @returns Valid value or default
*
* @example
* ```ts
* const temp = safeValue(apiData.temperature, 0.7, (v) => typeof v === 'number');
* ```
*/
export function safeValue<T>(
value: unknown,
defaultValue: T,
typeCheck: (v: unknown) => v is T
): T {
return typeCheck(value) ? value : defaultValue;
}
/**
* Type guard for numbers
*/
export function isNumber(v: unknown): v is number {
return typeof v === 'number' && !Number.isNaN(v);
}
/**
* Type guard for strings
*/
export function isString(v: unknown): v is string {
return typeof v === 'string';
}
/**
* Type guard for booleans
*/
export function isBoolean(v: unknown): v is boolean {
return typeof v === 'boolean';
}
/**
* Type guard for arrays
*/
export function isArray<T>(v: unknown, itemCheck?: (item: unknown) => item is T): v is T[] {
if (!Array.isArray(v)) return false;
if (itemCheck) return v.every(itemCheck);
return true;
}
/**
* Type guard for objects (non-null, non-array)
*/
export function isObject(v: unknown): v is Record<string, unknown> {
return typeof v === 'object' && v !== null && !Array.isArray(v);
}
/**
* Deep merge two objects, with source values taking precedence
* Only merges values that pass type checking against defaults
*
* @param defaults - Default values (used as type template)
* @param source - Source values to merge (from API)
* @returns Merged object with all fields from defaults
*
* @example
* ```ts
* const defaults = { temperature: 0.7, max_tokens: 8192, top_p: 0.95 };
* const apiData = { temperature: 0.5 }; // missing max_tokens and top_p
* const merged = deepMergeWithDefaults(defaults, apiData);
* // { temperature: 0.5, max_tokens: 8192, top_p: 0.95 }
* ```
*/
export function deepMergeWithDefaults<T extends Record<string, unknown>>(
defaults: T,
source: Partial<T> | null | undefined
): T {
if (!source) return { ...defaults };
const result = { ...defaults } as T;
for (const key of Object.keys(defaults) as Array<keyof T>) {
const defaultValue = defaults[key];
const sourceValue = source[key];
// Skip if source doesn't have this key
if (!(key in source) || sourceValue === undefined) {
continue;
}
// Handle nested objects recursively
if (isObject(defaultValue) && isObject(sourceValue)) {
result[key] = deepMergeWithDefaults(
defaultValue as Record<string, unknown>,
sourceValue as Record<string, unknown>
) as T[keyof T];
continue;
}
// For primitives and arrays, only use source if types match
if (typeof sourceValue === typeof defaultValue) {
result[key] = sourceValue as T[keyof T];
}
// Special case: default is null but source has a value (nullable fields)
else if (defaultValue === null && sourceValue !== null) {
result[key] = sourceValue as T[keyof T];
}
// Special case: allow null for nullable fields
else if (sourceValue === null && defaultValue === null) {
result[key] = null as T[keyof T];
}
}
return result;
}
/**
* Create a form values initializer from API data
*
* This is a higher-order function that creates a type-safe initializer
* for transforming API responses into form values with defaults.
*
* @param defaults - Default form values
* @param transform - Optional transform function for custom mapping
* @returns Function that takes API data and returns form values
*
* @example
* ```ts
* const initializeAgentForm = createFormInitializer(
* defaultAgentTypeValues,
* (apiData, defaults) => ({
* ...defaults,
* name: apiData?.name ?? defaults.name,
* model_params: deepMergeWithDefaults(
* defaults.model_params,
* apiData?.model_params
* ),
* })
* );
*
* // Usage
* const formValues = initializeAgentForm(apiResponse);
* ```
*/
export function createFormInitializer<TForm, TApi = Partial<TForm>>(
defaults: TForm,
transform?: (apiData: TApi | null | undefined, defaults: TForm) => TForm
): (apiData: TApi | null | undefined) => TForm {
return (apiData) => {
if (transform) {
return transform(apiData, defaults);
}
// Default behavior: deep merge
return deepMergeWithDefaults(
defaults as Record<string, unknown>,
apiData as Record<string, unknown> | null | undefined
) as TForm;
};
}

View File

@@ -6,12 +6,18 @@
*/
import { z } from 'zod';
import type { AgentTypeCategory } from '@/lib/api/types/agentTypes';
/**
* Slug validation regex: lowercase letters, numbers, and hyphens only
*/
const slugRegex = /^[a-z0-9-]+$/;
/**
* Hex color validation regex
*/
const hexColorRegex = /^#[0-9A-Fa-f]{6}$/;
/**
* Available AI models for agent types
*/
@@ -43,6 +49,84 @@ export const AGENT_TYPE_STATUS = [
{ value: false, label: 'Inactive' },
] as const;
/**
* Agent type categories for organizing agents
*/
/* istanbul ignore next -- constant declaration */
export const AGENT_TYPE_CATEGORIES: {
value: AgentTypeCategory;
label: string;
description: string;
}[] = [
{ value: 'development', label: 'Development', description: 'Product, project, and engineering' },
{ value: 'design', label: 'Design', description: 'UI/UX and design research' },
{ value: 'quality', label: 'Quality', description: 'QA and security assurance' },
{ value: 'operations', label: 'Operations', description: 'DevOps and MLOps engineering' },
{ value: 'ai_ml', label: 'AI & ML', description: 'Machine learning specialists' },
{ value: 'data', label: 'Data', description: 'Data science and engineering' },
{ value: 'leadership', label: 'Leadership', description: 'Technical leadership' },
{ value: 'domain_expert', label: 'Domain Experts', description: 'Industry specialists' },
];
/**
* Available Lucide icons for agent types
*/
/* istanbul ignore next -- constant declaration */
export const AVAILABLE_ICONS = [
// Development
{ value: 'clipboard-check', label: 'Clipboard Check', category: 'development' },
{ value: 'briefcase', label: 'Briefcase', category: 'development' },
{ value: 'file-text', label: 'File Text', category: 'development' },
{ value: 'git-branch', label: 'Git Branch', category: 'development' },
{ value: 'code', label: 'Code', category: 'development' },
{ value: 'server', label: 'Server', category: 'development' },
{ value: 'layout', label: 'Layout', category: 'development' },
{ value: 'smartphone', label: 'Smartphone', category: 'development' },
// Design
{ value: 'palette', label: 'Palette', category: 'design' },
{ value: 'search', label: 'Search', category: 'design' },
// Quality
{ value: 'shield', label: 'Shield', category: 'quality' },
{ value: 'shield-check', label: 'Shield Check', category: 'quality' },
// Operations
{ value: 'settings', label: 'Settings', category: 'operations' },
{ value: 'settings-2', label: 'Settings 2', category: 'operations' },
// AI/ML
{ value: 'brain', label: 'Brain', category: 'ai_ml' },
{ value: 'microscope', label: 'Microscope', category: 'ai_ml' },
{ value: 'eye', label: 'Eye', category: 'ai_ml' },
{ value: 'message-square', label: 'Message Square', category: 'ai_ml' },
// Data
{ value: 'bar-chart', label: 'Bar Chart', category: 'data' },
{ value: 'database', label: 'Database', category: 'data' },
// Leadership
{ value: 'users', label: 'Users', category: 'leadership' },
{ value: 'target', label: 'Target', category: 'leadership' },
// Domain Expert
{ value: 'calculator', label: 'Calculator', category: 'domain_expert' },
{ value: 'heart-pulse', label: 'Heart Pulse', category: 'domain_expert' },
{ value: 'flask-conical', label: 'Flask', category: 'domain_expert' },
{ value: 'lightbulb', label: 'Lightbulb', category: 'domain_expert' },
{ value: 'book-open', label: 'Book Open', category: 'domain_expert' },
// Generic
{ value: 'bot', label: 'Bot', category: 'generic' },
] as const;
/**
* Color palette for agent type visual distinction
*/
/* istanbul ignore next -- constant declaration */
export const COLOR_PALETTE = [
{ value: '#3B82F6', label: 'Blue', category: 'development' },
{ value: '#EC4899', label: 'Pink', category: 'design' },
{ value: '#10B981', label: 'Green', category: 'quality' },
{ value: '#F59E0B', label: 'Amber', category: 'operations' },
{ value: '#8B5CF6', label: 'Purple', category: 'ai_ml' },
{ value: '#06B6D4', label: 'Cyan', category: 'data' },
{ value: '#F97316', label: 'Orange', category: 'leadership' },
{ value: '#84CC16', label: 'Lime', category: 'domain_expert' },
] as const;
/**
* Model params schema
*/
@@ -52,6 +136,20 @@ const modelParamsSchema = z.object({
top_p: z.number().min(0).max(1),
});
/**
* Agent type category enum values
*/
const agentTypeCategoryValues = [
'development',
'design',
'quality',
'operations',
'ai_ml',
'data',
'leadership',
'domain_expert',
] as const;
/**
* Schema for agent type form fields
*/
@@ -96,6 +194,23 @@ export const agentTypeFormSchema = z.object({
tool_permissions: z.record(z.string(), z.unknown()),
is_active: z.boolean(),
// Category and display fields
category: z.enum(agentTypeCategoryValues).nullable().optional(),
icon: z.string().max(50, 'Icon must be less than 50 characters').nullable().optional(),
color: z
.string()
.regex(hexColorRegex, 'Color must be a valid hex code (e.g., #3B82F6)')
.nullable()
.optional(),
sort_order: z.number().int().min(0).max(1000),
typical_tasks: z.array(z.string()),
collaboration_hints: z.array(z.string()),
});
/**
@@ -138,6 +253,13 @@ export const defaultAgentTypeValues: AgentTypeCreateFormValues = {
mcp_servers: [],
tool_permissions: {},
is_active: false, // Start as draft
// Category and display fields
category: null,
icon: 'bot',
color: '#3B82F6',
sort_order: 0,
typical_tasks: [],
collaboration_hints: [],
};
/**

View File

@@ -21,6 +21,13 @@ Your approach is:
mcp_servers: ['gitea', 'knowledge', 'filesystem'],
tool_permissions: {},
is_active: true,
// Category and display fields
category: 'development',
icon: 'git-branch',
color: '#3B82F6',
sort_order: 40,
typical_tasks: ['Design system architecture', 'Create ADRs'],
collaboration_hints: ['backend-engineer', 'frontend-engineer'],
created_at: '2025-01-10T00:00:00Z',
updated_at: '2025-01-18T00:00:00Z',
instance_count: 2,
@@ -58,9 +65,8 @@ describe('AgentTypeDetail', () => {
expect(screen.getByText('Inactive')).toBeInTheDocument();
});
it('renders description card', () => {
it('renders description in hero header', () => {
render(<AgentTypeDetail {...defaultProps} />);
expect(screen.getByText('Description')).toBeInTheDocument();
expect(
screen.getByText('Designs system architecture and makes technology decisions')
).toBeInTheDocument();
@@ -130,7 +136,7 @@ describe('AgentTypeDetail', () => {
const user = userEvent.setup();
render(<AgentTypeDetail {...defaultProps} />);
await user.click(screen.getByRole('button', { name: /go back/i }));
await user.click(screen.getByRole('button', { name: /back to agent types/i }));
expect(defaultProps.onBack).toHaveBeenCalledTimes(1);
});
@@ -211,4 +217,146 @@ describe('AgentTypeDetail', () => {
);
expect(screen.getByText('None configured')).toBeInTheDocument();
});
describe('Hero Header', () => {
it('renders hero header with agent name', () => {
render(<AgentTypeDetail {...defaultProps} />);
expect(
screen.getByRole('heading', { level: 1, name: 'Software Architect' })
).toBeInTheDocument();
});
it('renders dynamic icon in hero header', () => {
const { container } = render(<AgentTypeDetail {...defaultProps} />);
expect(container.querySelector('svg.lucide-git-branch')).toBeInTheDocument();
});
it('applies agent color to hero header gradient', () => {
const { container } = render(<AgentTypeDetail {...defaultProps} />);
const heroHeader = container.querySelector('[style*="linear-gradient"]');
expect(heroHeader).toBeInTheDocument();
});
it('renders category badge in hero header', () => {
render(<AgentTypeDetail {...defaultProps} />);
expect(screen.getByText('Development')).toBeInTheDocument();
});
it('shows last updated date in hero header', () => {
render(<AgentTypeDetail {...defaultProps} />);
expect(screen.getByText(/Last updated:/)).toBeInTheDocument();
expect(screen.getByText(/Jan 18, 2025/)).toBeInTheDocument();
});
});
describe('Typical Tasks Card', () => {
it('renders "What This Agent Does Best" card', () => {
render(<AgentTypeDetail {...defaultProps} />);
expect(screen.getByText('What This Agent Does Best')).toBeInTheDocument();
});
it('displays all typical tasks', () => {
render(<AgentTypeDetail {...defaultProps} />);
expect(screen.getByText('Design system architecture')).toBeInTheDocument();
expect(screen.getByText('Create ADRs')).toBeInTheDocument();
});
it('does not render typical tasks card when empty', () => {
render(
<AgentTypeDetail {...defaultProps} agentType={{ ...mockAgentType, typical_tasks: [] }} />
);
expect(screen.queryByText('What This Agent Does Best')).not.toBeInTheDocument();
});
});
describe('Collaboration Hints Card', () => {
it('renders "Works Well With" card', () => {
render(<AgentTypeDetail {...defaultProps} />);
expect(screen.getByText('Works Well With')).toBeInTheDocument();
});
it('displays collaboration hints as badges', () => {
render(<AgentTypeDetail {...defaultProps} />);
expect(screen.getByText('backend-engineer')).toBeInTheDocument();
expect(screen.getByText('frontend-engineer')).toBeInTheDocument();
});
it('does not render collaboration hints card when empty', () => {
render(
<AgentTypeDetail
{...defaultProps}
agentType={{ ...mockAgentType, collaboration_hints: [] }}
/>
);
expect(screen.queryByText('Works Well With')).not.toBeInTheDocument();
});
});
describe('Category Badge', () => {
it('renders category badge with correct label', () => {
render(<AgentTypeDetail {...defaultProps} />);
expect(screen.getByText('Development')).toBeInTheDocument();
});
it('does not render category badge when category is null', () => {
render(
<AgentTypeDetail {...defaultProps} agentType={{ ...mockAgentType, category: null }} />
);
// Should not have a Development badge in the hero header area
// The word "Development" should not appear
expect(screen.queryByText('Development')).not.toBeInTheDocument();
});
});
describe('Details Card', () => {
it('renders details card with slug', () => {
render(<AgentTypeDetail {...defaultProps} />);
expect(screen.getByText('Slug')).toBeInTheDocument();
expect(screen.getByText('software-architect')).toBeInTheDocument();
});
it('renders details card with sort order', () => {
render(<AgentTypeDetail {...defaultProps} />);
expect(screen.getByText('Sort Order')).toBeInTheDocument();
expect(screen.getByText('40')).toBeInTheDocument();
});
it('renders details card with creation date', () => {
render(<AgentTypeDetail {...defaultProps} />);
expect(screen.getByText('Created')).toBeInTheDocument();
expect(screen.getByText(/Jan 10, 2025/)).toBeInTheDocument();
});
});
describe('Dynamic Icon', () => {
it('renders fallback icon when icon is null', () => {
const { container } = render(
<AgentTypeDetail {...defaultProps} agentType={{ ...mockAgentType, icon: null }} />
);
// Should fall back to 'bot' icon
expect(container.querySelector('svg.lucide-bot')).toBeInTheDocument();
});
it('renders correct icon based on agent type', () => {
const agentWithBrainIcon = { ...mockAgentType, icon: 'brain' };
const { container } = render(
<AgentTypeDetail {...defaultProps} agentType={agentWithBrainIcon} />
);
expect(container.querySelector('svg.lucide-brain')).toBeInTheDocument();
});
});
describe('Color Styling', () => {
it('applies custom color to instance count', () => {
render(<AgentTypeDetail {...defaultProps} />);
const instanceCount = screen.getByText('2');
expect(instanceCount).toHaveStyle({ color: 'rgb(59, 130, 246)' });
});
it('uses default color when color is null', () => {
render(<AgentTypeDetail {...defaultProps} agentType={{ ...mockAgentType, color: null }} />);
// Should still render without errors
expect(screen.getByText('Software Architect')).toBeInTheDocument();
});
});
});

View File

@@ -16,6 +16,13 @@ const mockAgentType: AgentTypeResponse = {
mcp_servers: ['gitea'],
tool_permissions: {},
is_active: true,
// Category and display fields
category: 'development',
icon: 'git-branch',
color: '#3B82F6',
sort_order: 40,
typical_tasks: ['Design system architecture'],
collaboration_hints: ['backend-engineer'],
created_at: '2025-01-10T00:00:00Z',
updated_at: '2025-01-18T00:00:00Z',
instance_count: 2,
@@ -192,7 +199,8 @@ describe('AgentTypeForm', () => {
const expertiseInput = screen.getByPlaceholderText(/e.g., system design/i);
await user.type(expertiseInput, 'new skill');
await user.click(screen.getByRole('button', { name: /^add$/i }));
// Click the first "Add" button (for expertise)
await user.click(screen.getAllByRole('button', { name: /^add$/i })[0]);
expect(screen.getByText('new skill')).toBeInTheDocument();
});
@@ -454,7 +462,8 @@ describe('AgentTypeForm', () => {
// Agent type already has 'system design'
const expertiseInput = screen.getByPlaceholderText(/e.g., system design/i);
await user.type(expertiseInput, 'system design');
await user.click(screen.getByRole('button', { name: /^add$/i }));
// Click the first "Add" button (for expertise)
await user.click(screen.getAllByRole('button', { name: /^add$/i })[0]);
// Should still only have one 'system design' badge
const badges = screen.getAllByText('system design');
@@ -465,7 +474,8 @@ describe('AgentTypeForm', () => {
const user = userEvent.setup();
render(<AgentTypeForm {...defaultProps} />);
const addButton = screen.getByRole('button', { name: /^add$/i });
// Click the first "Add" button (for expertise)
const addButton = screen.getAllByRole('button', { name: /^add$/i })[0];
await user.click(addButton);
// No badges should be added
@@ -478,7 +488,8 @@ describe('AgentTypeForm', () => {
const expertiseInput = screen.getByPlaceholderText(/e.g., system design/i);
await user.type(expertiseInput, 'API Design');
await user.click(screen.getByRole('button', { name: /^add$/i }));
// Click the first "Add" button (for expertise)
await user.click(screen.getAllByRole('button', { name: /^add$/i })[0]);
expect(screen.getByText('api design')).toBeInTheDocument();
});
@@ -489,7 +500,8 @@ describe('AgentTypeForm', () => {
const expertiseInput = screen.getByPlaceholderText(/e.g., system design/i);
await user.type(expertiseInput, ' testing ');
await user.click(screen.getByRole('button', { name: /^add$/i }));
// Click the first "Add" button (for expertise)
await user.click(screen.getAllByRole('button', { name: /^add$/i })[0]);
expect(screen.getByText('testing')).toBeInTheDocument();
});
@@ -502,7 +514,8 @@ describe('AgentTypeForm', () => {
/e.g., system design/i
) as HTMLInputElement;
await user.type(expertiseInput, 'new skill');
await user.click(screen.getByRole('button', { name: /^add$/i }));
// Click the first "Add" button (for expertise)
await user.click(screen.getAllByRole('button', { name: /^add$/i })[0]);
expect(expertiseInput.value).toBe('');
});
@@ -562,4 +575,213 @@ describe('AgentTypeForm', () => {
expect(screen.getByText('Edit Agent Type')).toBeInTheDocument();
});
});
describe('Category & Display Fields', () => {
it('renders category and display section', () => {
render(<AgentTypeForm {...defaultProps} />);
expect(screen.getByText('Category & Display')).toBeInTheDocument();
});
it('shows category select', () => {
render(<AgentTypeForm {...defaultProps} />);
expect(screen.getByLabelText(/category/i)).toBeInTheDocument();
});
it('shows sort order input', () => {
render(<AgentTypeForm {...defaultProps} />);
expect(screen.getByLabelText(/sort order/i)).toBeInTheDocument();
});
it('shows icon input', () => {
render(<AgentTypeForm {...defaultProps} />);
expect(screen.getByLabelText(/icon/i)).toBeInTheDocument();
});
it('shows color input', () => {
render(<AgentTypeForm {...defaultProps} />);
expect(screen.getByLabelText(/color/i)).toBeInTheDocument();
});
it('pre-fills category fields in edit mode', () => {
render(<AgentTypeForm {...defaultProps} agentType={mockAgentType} />);
const iconInput = screen.getByLabelText(/icon/i) as HTMLInputElement;
expect(iconInput.value).toBe('git-branch');
const sortOrderInput = screen.getByLabelText(/sort order/i) as HTMLInputElement;
expect(sortOrderInput.value).toBe('40');
});
});
describe('Typical Tasks Management', () => {
it('shows typical tasks section', () => {
render(<AgentTypeForm {...defaultProps} />);
expect(screen.getByText('Typical Tasks')).toBeInTheDocument();
});
it('adds typical task when add button is clicked', async () => {
const user = userEvent.setup();
render(<AgentTypeForm {...defaultProps} />);
const taskInput = screen.getByPlaceholderText(/e.g., design system architecture/i);
await user.type(taskInput, 'Write documentation');
// Click the second "Add" button (for typical tasks)
const addButtons = screen.getAllByRole('button', { name: /^add$/i });
await user.click(addButtons[1]);
expect(screen.getByText('Write documentation')).toBeInTheDocument();
});
it('adds typical task on enter key', async () => {
const user = userEvent.setup();
render(<AgentTypeForm {...defaultProps} />);
const taskInput = screen.getByPlaceholderText(/e.g., design system architecture/i);
await user.type(taskInput, 'Write documentation{Enter}');
expect(screen.getByText('Write documentation')).toBeInTheDocument();
});
it('removes typical task when X button is clicked', async () => {
const user = userEvent.setup();
render(<AgentTypeForm {...defaultProps} agentType={mockAgentType} />);
// Should have existing typical task
expect(screen.getByText('Design system architecture')).toBeInTheDocument();
// Click remove button
const removeButton = screen.getByRole('button', {
name: /remove design system architecture/i,
});
await user.click(removeButton);
expect(screen.queryByText('Design system architecture')).not.toBeInTheDocument();
});
it('does not add duplicate typical tasks', async () => {
const user = userEvent.setup();
render(<AgentTypeForm {...defaultProps} agentType={mockAgentType} />);
// Agent type already has 'Design system architecture'
const taskInput = screen.getByPlaceholderText(/e.g., design system architecture/i);
await user.type(taskInput, 'Design system architecture');
const addButtons = screen.getAllByRole('button', { name: /^add$/i });
await user.click(addButtons[1]);
// Should still only have one badge
const badges = screen.getAllByText('Design system architecture');
expect(badges).toHaveLength(1);
});
it('does not add empty typical task', async () => {
const user = userEvent.setup();
render(<AgentTypeForm {...defaultProps} />);
// Click the second "Add" button (for typical tasks) without typing
const addButtons = screen.getAllByRole('button', { name: /^add$/i });
await user.click(addButtons[1]);
// No badges should be added (check that there's no remove button for typical tasks)
expect(
screen.queryByRole('button', { name: /remove write documentation/i })
).not.toBeInTheDocument();
});
});
describe('Collaboration Hints Management', () => {
it('shows collaboration hints section', () => {
render(<AgentTypeForm {...defaultProps} />);
expect(screen.getByText('Collaboration Hints')).toBeInTheDocument();
});
it('adds collaboration hint when add button is clicked', async () => {
const user = userEvent.setup();
render(<AgentTypeForm {...defaultProps} />);
const hintInput = screen.getByPlaceholderText(/e.g., backend-engineer/i);
await user.type(hintInput, 'devops-engineer');
// Click the third "Add" button (for collaboration hints)
const addButtons = screen.getAllByRole('button', { name: /^add$/i });
await user.click(addButtons[2]);
expect(screen.getByText('devops-engineer')).toBeInTheDocument();
});
it('adds collaboration hint on enter key', async () => {
const user = userEvent.setup();
render(<AgentTypeForm {...defaultProps} />);
const hintInput = screen.getByPlaceholderText(/e.g., backend-engineer/i);
await user.type(hintInput, 'devops-engineer{Enter}');
expect(screen.getByText('devops-engineer')).toBeInTheDocument();
});
it('removes collaboration hint when X button is clicked', async () => {
const user = userEvent.setup();
render(<AgentTypeForm {...defaultProps} agentType={mockAgentType} />);
// Should have existing collaboration hint
expect(screen.getByText('backend-engineer')).toBeInTheDocument();
// Click remove button
const removeButton = screen.getByRole('button', { name: /remove backend-engineer/i });
await user.click(removeButton);
expect(screen.queryByText('backend-engineer')).not.toBeInTheDocument();
});
it('converts collaboration hints to lowercase', async () => {
const user = userEvent.setup();
render(<AgentTypeForm {...defaultProps} />);
const hintInput = screen.getByPlaceholderText(/e.g., backend-engineer/i);
await user.type(hintInput, 'DevOps-Engineer');
const addButtons = screen.getAllByRole('button', { name: /^add$/i });
await user.click(addButtons[2]);
expect(screen.getByText('devops-engineer')).toBeInTheDocument();
});
it('does not add duplicate collaboration hints', async () => {
const user = userEvent.setup();
render(<AgentTypeForm {...defaultProps} agentType={mockAgentType} />);
// Agent type already has 'backend-engineer'
const hintInput = screen.getByPlaceholderText(/e.g., backend-engineer/i);
await user.type(hintInput, 'backend-engineer');
const addButtons = screen.getAllByRole('button', { name: /^add$/i });
await user.click(addButtons[2]);
// Should still only have one badge
const badges = screen.getAllByText('backend-engineer');
expect(badges).toHaveLength(1);
});
it('does not add empty collaboration hint', async () => {
const user = userEvent.setup();
render(<AgentTypeForm {...defaultProps} />);
// Click the third "Add" button (for collaboration hints) without typing
const addButtons = screen.getAllByRole('button', { name: /^add$/i });
await user.click(addButtons[2]);
// No badges should be added
expect(
screen.queryByRole('button', { name: /remove devops-engineer/i })
).not.toBeInTheDocument();
});
it('clears input after adding collaboration hint', async () => {
const user = userEvent.setup();
render(<AgentTypeForm {...defaultProps} />);
const hintInput = screen.getByPlaceholderText(/e.g., backend-engineer/i) as HTMLInputElement;
await user.type(hintInput, 'devops-engineer');
const addButtons = screen.getAllByRole('button', { name: /^add$/i });
await user.click(addButtons[2]);
expect(hintInput.value).toBe('');
});
});
});

View File

@@ -17,6 +17,13 @@ const mockAgentTypes: AgentTypeResponse[] = [
mcp_servers: ['gitea', 'knowledge'],
tool_permissions: {},
is_active: true,
// Category and display fields
category: 'development',
icon: 'clipboard-check',
color: '#3B82F6',
sort_order: 10,
typical_tasks: ['Manage backlog', 'Write user stories'],
collaboration_hints: ['business-analyst', 'scrum-master'],
created_at: '2025-01-15T00:00:00Z',
updated_at: '2025-01-20T00:00:00Z',
instance_count: 3,
@@ -34,6 +41,13 @@ const mockAgentTypes: AgentTypeResponse[] = [
mcp_servers: ['gitea'],
tool_permissions: {},
is_active: false,
// Category and display fields
category: 'development',
icon: 'git-branch',
color: '#3B82F6',
sort_order: 40,
typical_tasks: ['Design architecture', 'Create ADRs'],
collaboration_hints: ['backend-engineer', 'devops-engineer'],
created_at: '2025-01-10T00:00:00Z',
updated_at: '2025-01-18T00:00:00Z',
instance_count: 0,
@@ -48,6 +62,10 @@ describe('AgentTypeList', () => {
onSearchChange: jest.fn(),
statusFilter: 'all',
onStatusFilterChange: jest.fn(),
categoryFilter: 'all',
onCategoryFilterChange: jest.fn(),
viewMode: 'grid' as const,
onViewModeChange: jest.fn(),
onSelect: jest.fn(),
onCreate: jest.fn(),
};
@@ -194,4 +212,158 @@ describe('AgentTypeList', () => {
const { container } = render(<AgentTypeList {...defaultProps} className="custom-class" />);
expect(container.firstChild).toHaveClass('custom-class');
});
describe('Category Filter', () => {
it('renders category filter dropdown', () => {
render(<AgentTypeList {...defaultProps} />);
expect(screen.getByRole('combobox', { name: /filter by category/i })).toBeInTheDocument();
});
it('shows "All Categories" as default option', () => {
render(<AgentTypeList {...defaultProps} categoryFilter="all" />);
expect(screen.getByText('All Categories')).toBeInTheDocument();
});
it('displays category badge on agent cards', () => {
render(<AgentTypeList {...defaultProps} />);
// Both agents have 'development' category
const developmentBadges = screen.getAllByText('Development');
expect(developmentBadges.length).toBe(2);
});
it('shows filter hint in empty state when category filter is applied', () => {
render(<AgentTypeList {...defaultProps} agentTypes={[]} categoryFilter="design" />);
expect(screen.getByText('Try adjusting your search or filters')).toBeInTheDocument();
});
});
describe('View Mode Toggle', () => {
it('renders view mode toggle buttons', () => {
render(<AgentTypeList {...defaultProps} />);
expect(screen.getByRole('radio', { name: /grid view/i })).toBeInTheDocument();
expect(screen.getByRole('radio', { name: /list view/i })).toBeInTheDocument();
});
it('renders grid view by default', () => {
const { container } = render(<AgentTypeList {...defaultProps} viewMode="grid" />);
// Grid view uses CSS grid
expect(container.querySelector('.grid')).toBeInTheDocument();
});
it('renders list view when viewMode is list', () => {
const { container } = render(<AgentTypeList {...defaultProps} viewMode="list" />);
// List view uses space-y-3 for vertical stacking
expect(container.querySelector('.space-y-3')).toBeInTheDocument();
});
it('calls onViewModeChange when grid toggle is clicked', async () => {
const user = userEvent.setup();
const onViewModeChange = jest.fn();
render(
<AgentTypeList {...defaultProps} viewMode="list" onViewModeChange={onViewModeChange} />
);
await user.click(screen.getByRole('radio', { name: /grid view/i }));
expect(onViewModeChange).toHaveBeenCalledWith('grid');
});
it('calls onViewModeChange when list toggle is clicked', async () => {
const user = userEvent.setup();
const onViewModeChange = jest.fn();
render(
<AgentTypeList {...defaultProps} viewMode="grid" onViewModeChange={onViewModeChange} />
);
await user.click(screen.getByRole('radio', { name: /list view/i }));
expect(onViewModeChange).toHaveBeenCalledWith('list');
});
it('shows list-specific loading skeletons when viewMode is list', () => {
const { container } = render(
<AgentTypeList {...defaultProps} agentTypes={[]} isLoading={true} viewMode="list" />
);
expect(container.querySelectorAll('.animate-pulse').length).toBeGreaterThan(0);
});
});
describe('List View', () => {
it('shows agent info in list rows', () => {
render(<AgentTypeList {...defaultProps} viewMode="list" />);
expect(screen.getByText('Product Owner')).toBeInTheDocument();
expect(screen.getByText('Software Architect')).toBeInTheDocument();
});
it('shows category badge in list view', () => {
render(<AgentTypeList {...defaultProps} viewMode="list" />);
const developmentBadges = screen.getAllByText('Development');
expect(developmentBadges.length).toBe(2);
});
it('shows expertise count in list view', () => {
render(<AgentTypeList {...defaultProps} viewMode="list" />);
// Both agents have 3 expertise areas
const expertiseTexts = screen.getAllByText('3 expertise areas');
expect(expertiseTexts.length).toBe(2);
});
it('calls onSelect when list row is clicked', async () => {
const user = userEvent.setup();
const onSelect = jest.fn();
render(<AgentTypeList {...defaultProps} viewMode="list" onSelect={onSelect} />);
await user.click(screen.getByText('Product Owner'));
expect(onSelect).toHaveBeenCalledWith('type-001');
});
it('supports keyboard navigation on list rows', async () => {
const user = userEvent.setup();
const onSelect = jest.fn();
render(<AgentTypeList {...defaultProps} viewMode="list" onSelect={onSelect} />);
const rows = screen.getAllByRole('button', { name: /view .* agent type/i });
rows[0].focus();
await user.keyboard('{Enter}');
expect(onSelect).toHaveBeenCalledWith('type-001');
});
});
describe('Dynamic Icons', () => {
it('renders agent icon in grid view', () => {
const { container } = render(<AgentTypeList {...defaultProps} viewMode="grid" />);
// Check for svg icons with lucide classes
const icons = container.querySelectorAll('svg.lucide-clipboard-check, svg.lucide-git-branch');
expect(icons.length).toBeGreaterThan(0);
});
it('renders agent icon in list view', () => {
const { container } = render(<AgentTypeList {...defaultProps} viewMode="list" />);
const icons = container.querySelectorAll('svg.lucide-clipboard-check, svg.lucide-git-branch');
expect(icons.length).toBeGreaterThan(0);
});
});
describe('Color Accent', () => {
it('applies color to card border in grid view', () => {
const { container } = render(<AgentTypeList {...defaultProps} viewMode="grid" />);
const card = container.querySelector('[style*="border-top-color"]');
expect(card).toBeInTheDocument();
});
it('applies color to row border in list view', () => {
const { container } = render(<AgentTypeList {...defaultProps} viewMode="list" />);
const row = container.querySelector('[style*="border-left-color"]');
expect(row).toBeInTheDocument();
});
});
describe('Category Badge Component', () => {
it('does not render category badge when category is null', () => {
const agentWithNoCategory: AgentTypeResponse = {
...mockAgentTypes[0],
category: null,
};
render(<AgentTypeList {...defaultProps} agentTypes={[agentWithNoCategory]} />);
expect(screen.queryByText('Development')).not.toBeInTheDocument();
});
});
});

View File

@@ -0,0 +1,448 @@
/**
* Tests for FormSelect Component
* Verifies select field rendering, accessibility, and error handling
*/
import React from 'react';
import { render, screen, fireEvent, waitFor } from '@testing-library/react';
import { useForm, FormProvider } from 'react-hook-form';
import { FormSelect, type SelectOption } from '@/components/forms/FormSelect';
// Polyfill for Radix UI Select - jsdom doesn't support these browser APIs
beforeAll(() => {
Element.prototype.hasPointerCapture = jest.fn(() => false);
Element.prototype.setPointerCapture = jest.fn();
Element.prototype.releasePointerCapture = jest.fn();
Element.prototype.scrollIntoView = jest.fn();
window.HTMLElement.prototype.scrollIntoView = jest.fn();
});
// Helper wrapper component to provide form context
interface TestFormValues {
model: string;
category: string;
}
function TestWrapper({
children,
defaultValues = { model: '', category: '' },
}: {
children: (props: {
control: ReturnType<typeof useForm<TestFormValues>>['control'];
}) => React.ReactNode;
defaultValues?: Partial<TestFormValues>;
}) {
const form = useForm<TestFormValues>({
defaultValues: { model: '', category: '', ...defaultValues },
});
return <FormProvider {...form}>{children({ control: form.control })}</FormProvider>;
}
const mockOptions: SelectOption[] = [
{ value: 'claude-opus', label: 'Claude Opus' },
{ value: 'claude-sonnet', label: 'Claude Sonnet' },
{ value: 'claude-haiku', label: 'Claude Haiku' },
];
describe('FormSelect', () => {
describe('Basic Rendering', () => {
it('renders with label and select trigger', () => {
render(
<TestWrapper>
{({ control }) => (
<FormSelect
name="model"
control={control}
label="Primary Model"
options={mockOptions}
/>
)}
</TestWrapper>
);
expect(screen.getByText('Primary Model')).toBeInTheDocument();
expect(screen.getByRole('combobox')).toBeInTheDocument();
});
it('renders with description', () => {
render(
<TestWrapper>
{({ control }) => (
<FormSelect
name="model"
control={control}
label="Primary Model"
options={mockOptions}
description="Main model used for this agent"
/>
)}
</TestWrapper>
);
expect(screen.getByText('Main model used for this agent')).toBeInTheDocument();
});
it('renders with custom placeholder', () => {
render(
<TestWrapper>
{({ control }) => (
<FormSelect
name="model"
control={control}
label="Primary Model"
options={mockOptions}
placeholder="Choose a model"
/>
)}
</TestWrapper>
);
expect(screen.getByText('Choose a model')).toBeInTheDocument();
});
it('renders default placeholder when none provided', () => {
render(
<TestWrapper>
{({ control }) => (
<FormSelect
name="model"
control={control}
label="Primary Model"
options={mockOptions}
/>
)}
</TestWrapper>
);
expect(screen.getByText('Select primary model')).toBeInTheDocument();
});
});
describe('Required Field', () => {
it('shows asterisk when required is true', () => {
render(
<TestWrapper>
{({ control }) => (
<FormSelect
name="model"
control={control}
label="Primary Model"
options={mockOptions}
required
/>
)}
</TestWrapper>
);
expect(screen.getByText('*')).toBeInTheDocument();
});
it('does not show asterisk when required is false', () => {
render(
<TestWrapper>
{({ control }) => (
<FormSelect
name="model"
control={control}
label="Primary Model"
options={mockOptions}
required={false}
/>
)}
</TestWrapper>
);
expect(screen.queryByText('*')).not.toBeInTheDocument();
});
});
describe('Options Rendering', () => {
it('renders all options when opened', async () => {
render(
<TestWrapper>
{({ control }) => (
<FormSelect
name="model"
control={control}
label="Primary Model"
options={mockOptions}
/>
)}
</TestWrapper>
);
// Open the select using fireEvent (works better with Radix UI)
fireEvent.click(screen.getByRole('combobox'));
// Check all options are rendered
await waitFor(() => {
expect(screen.getByRole('option', { name: 'Claude Opus' })).toBeInTheDocument();
});
expect(screen.getByRole('option', { name: 'Claude Sonnet' })).toBeInTheDocument();
expect(screen.getByRole('option', { name: 'Claude Haiku' })).toBeInTheDocument();
});
it('selects option when clicked', async () => {
render(
<TestWrapper>
{({ control }) => (
<FormSelect
name="model"
control={control}
label="Primary Model"
options={mockOptions}
/>
)}
</TestWrapper>
);
// Open the select and choose an option
fireEvent.click(screen.getByRole('combobox'));
await waitFor(() => {
expect(screen.getByRole('option', { name: 'Claude Sonnet' })).toBeInTheDocument();
});
fireEvent.click(screen.getByRole('option', { name: 'Claude Sonnet' }));
// The selected value should now be displayed
await waitFor(() => {
expect(screen.getByRole('combobox')).toHaveTextContent('Claude Sonnet');
});
});
});
describe('Disabled State', () => {
it('disables select when disabled prop is true', () => {
render(
<TestWrapper>
{({ control }) => (
<FormSelect
name="model"
control={control}
label="Primary Model"
options={mockOptions}
disabled
/>
)}
</TestWrapper>
);
expect(screen.getByRole('combobox')).toBeDisabled();
});
it('enables select when disabled prop is false', () => {
render(
<TestWrapper>
{({ control }) => (
<FormSelect
name="model"
control={control}
label="Primary Model"
options={mockOptions}
disabled={false}
/>
)}
</TestWrapper>
);
expect(screen.getByRole('combobox')).not.toBeDisabled();
});
});
describe('Pre-selected Value', () => {
it('displays pre-selected value', () => {
render(
<TestWrapper defaultValues={{ model: 'claude-opus' }}>
{({ control }) => (
<FormSelect
name="model"
control={control}
label="Primary Model"
options={mockOptions}
/>
)}
</TestWrapper>
);
expect(screen.getByRole('combobox')).toHaveTextContent('Claude Opus');
});
});
describe('Accessibility', () => {
it('links label to select via htmlFor/id', () => {
render(
<TestWrapper>
{({ control }) => (
<FormSelect
name="model"
control={control}
label="Primary Model"
options={mockOptions}
/>
)}
</TestWrapper>
);
const label = screen.getByText('Primary Model');
const select = screen.getByRole('combobox');
expect(label).toHaveAttribute('for', 'model');
expect(select).toHaveAttribute('id', 'model');
});
it('sets aria-describedby with description ID when description exists', () => {
render(
<TestWrapper>
{({ control }) => (
<FormSelect
name="model"
control={control}
label="Primary Model"
options={mockOptions}
description="Choose the main model"
/>
)}
</TestWrapper>
);
const select = screen.getByRole('combobox');
expect(select).toHaveAttribute('aria-describedby', 'model-description');
});
});
describe('Custom ClassName', () => {
it('applies custom className to wrapper', () => {
const { container } = render(
<TestWrapper>
{({ control }) => (
<FormSelect
name="model"
control={control}
label="Primary Model"
options={mockOptions}
className="custom-class"
/>
)}
</TestWrapper>
);
expect(container.querySelector('.custom-class')).toBeInTheDocument();
});
});
describe('Error Handling', () => {
it('displays error message when field has error', () => {
function TestComponent() {
const form = useForm<TestFormValues>({
defaultValues: { model: '', category: '' },
});
React.useEffect(() => {
form.setError('model', { type: 'required', message: 'Model is required' });
}, [form]);
return (
<FormProvider {...form}>
<FormSelect
name="model"
control={form.control}
label="Primary Model"
options={mockOptions}
/>
</FormProvider>
);
}
render(<TestComponent />);
expect(screen.getByRole('alert')).toHaveTextContent('Model is required');
});
it('sets aria-invalid when error exists', () => {
function TestComponent() {
const form = useForm<TestFormValues>({
defaultValues: { model: '', category: '' },
});
React.useEffect(() => {
form.setError('model', { type: 'required', message: 'Model is required' });
}, [form]);
return (
<FormProvider {...form}>
<FormSelect
name="model"
control={form.control}
label="Primary Model"
options={mockOptions}
/>
</FormProvider>
);
}
render(<TestComponent />);
expect(screen.getByRole('combobox')).toHaveAttribute('aria-invalid', 'true');
});
it('sets aria-describedby with error ID when error exists', () => {
function TestComponent() {
const form = useForm<TestFormValues>({
defaultValues: { model: '', category: '' },
});
React.useEffect(() => {
form.setError('model', { type: 'required', message: 'Model is required' });
}, [form]);
return (
<FormProvider {...form}>
<FormSelect
name="model"
control={form.control}
label="Primary Model"
options={mockOptions}
/>
</FormProvider>
);
}
render(<TestComponent />);
expect(screen.getByRole('combobox')).toHaveAttribute('aria-describedby', 'model-error');
});
it('combines error and description IDs in aria-describedby', () => {
function TestComponent() {
const form = useForm<TestFormValues>({
defaultValues: { model: '', category: '' },
});
React.useEffect(() => {
form.setError('model', { type: 'required', message: 'Model is required' });
}, [form]);
return (
<FormProvider {...form}>
<FormSelect
name="model"
control={form.control}
label="Primary Model"
options={mockOptions}
description="Choose the main model"
/>
</FormProvider>
);
}
render(<TestComponent />);
expect(screen.getByRole('combobox')).toHaveAttribute(
'aria-describedby',
'model-error model-description'
);
});
});
});

View File

@@ -0,0 +1,281 @@
/**
* Tests for FormTextarea Component
* Verifies textarea field rendering, accessibility, and error handling
*/
import { render, screen } from '@testing-library/react';
import { FormTextarea } from '@/components/forms/FormTextarea';
import type { FieldError } from 'react-hook-form';
describe('FormTextarea', () => {
describe('Basic Rendering', () => {
it('renders with label and textarea', () => {
render(<FormTextarea label="Description" name="description" />);
expect(screen.getByLabelText('Description')).toBeInTheDocument();
expect(screen.getByRole('textbox')).toBeInTheDocument();
});
it('renders with description', () => {
render(
<FormTextarea
label="Personality Prompt"
name="personality"
description="Define the agent's personality and behavior"
/>
);
expect(screen.getByText("Define the agent's personality and behavior")).toBeInTheDocument();
});
it('renders description before textarea', () => {
const { container } = render(
<FormTextarea label="Description" name="description" description="Helper text" />
);
const description = container.querySelector('#description-description');
const textarea = container.querySelector('textarea');
// Get positions
const descriptionRect = description?.getBoundingClientRect();
const textareaRect = textarea?.getBoundingClientRect();
// Description should appear (both should exist)
expect(description).toBeInTheDocument();
expect(textarea).toBeInTheDocument();
// In the DOM order, description comes before textarea
expect(descriptionRect).toBeDefined();
expect(textareaRect).toBeDefined();
});
});
describe('Required Field', () => {
it('shows asterisk when required is true', () => {
render(<FormTextarea label="Description" name="description" required />);
expect(screen.getByText('*')).toBeInTheDocument();
});
it('does not show asterisk when required is false', () => {
render(<FormTextarea label="Description" name="description" required={false} />);
expect(screen.queryByText('*')).not.toBeInTheDocument();
});
});
describe('Error Handling', () => {
it('displays error message when error prop is provided', () => {
const error: FieldError = {
type: 'required',
message: 'Description is required',
};
render(<FormTextarea label="Description" name="description" error={error} />);
expect(screen.getByText('Description is required')).toBeInTheDocument();
});
it('sets aria-invalid when error exists', () => {
const error: FieldError = {
type: 'required',
message: 'Description is required',
};
render(<FormTextarea label="Description" name="description" error={error} />);
const textarea = screen.getByRole('textbox');
expect(textarea).toHaveAttribute('aria-invalid', 'true');
});
it('sets aria-describedby with error ID when error exists', () => {
const error: FieldError = {
type: 'required',
message: 'Description is required',
};
render(<FormTextarea label="Description" name="description" error={error} />);
const textarea = screen.getByRole('textbox');
expect(textarea).toHaveAttribute('aria-describedby', 'description-error');
});
it('renders error with role="alert"', () => {
const error: FieldError = {
type: 'required',
message: 'Description is required',
};
render(<FormTextarea label="Description" name="description" error={error} />);
const errorElement = screen.getByRole('alert');
expect(errorElement).toHaveTextContent('Description is required');
});
});
describe('Accessibility', () => {
it('links label to textarea via htmlFor/id', () => {
render(<FormTextarea label="Description" name="description" />);
const label = screen.getByText('Description');
const textarea = screen.getByRole('textbox');
expect(label).toHaveAttribute('for', 'description');
expect(textarea).toHaveAttribute('id', 'description');
});
it('sets aria-describedby with description ID when description exists', () => {
render(
<FormTextarea
label="Description"
name="description"
description="Enter a detailed description"
/>
);
const textarea = screen.getByRole('textbox');
expect(textarea).toHaveAttribute('aria-describedby', 'description-description');
});
it('combines error and description IDs in aria-describedby', () => {
const error: FieldError = {
type: 'required',
message: 'Description is required',
};
render(
<FormTextarea
label="Description"
name="description"
description="Enter a detailed description"
error={error}
/>
);
const textarea = screen.getByRole('textbox');
expect(textarea).toHaveAttribute(
'aria-describedby',
'description-error description-description'
);
});
});
describe('Textarea Props Forwarding', () => {
it('forwards textarea props correctly', () => {
render(
<FormTextarea
label="Description"
name="description"
placeholder="Enter description"
rows={5}
disabled
/>
);
const textarea = screen.getByRole('textbox');
expect(textarea).toHaveAttribute('placeholder', 'Enter description');
expect(textarea).toHaveAttribute('rows', '5');
expect(textarea).toBeDisabled();
});
it('accepts register() props via registration', () => {
const registerProps = {
name: 'description',
onChange: jest.fn(),
onBlur: jest.fn(),
ref: jest.fn(),
};
render(<FormTextarea label="Description" registration={registerProps} />);
const textarea = screen.getByRole('textbox');
expect(textarea).toBeInTheDocument();
expect(textarea).toHaveAttribute('id', 'description');
});
it('extracts name from spread props', () => {
const spreadProps = {
name: 'content',
onChange: jest.fn(),
};
render(<FormTextarea label="Content" {...spreadProps} />);
const textarea = screen.getByRole('textbox');
expect(textarea).toHaveAttribute('id', 'content');
});
});
describe('Error Cases', () => {
it('throws error when name is not provided', () => {
// Suppress console.error for this test
const consoleError = jest.spyOn(console, 'error').mockImplementation(() => {});
expect(() => {
render(<FormTextarea label="Description" />);
}).toThrow('FormTextarea: name must be provided either explicitly or via register()');
consoleError.mockRestore();
});
});
describe('Layout and Styling', () => {
it('applies correct spacing classes', () => {
const { container } = render(<FormTextarea label="Description" name="description" />);
const wrapper = container.firstChild as HTMLElement;
expect(wrapper).toHaveClass('space-y-2');
});
it('applies correct error styling', () => {
const error: FieldError = {
type: 'required',
message: 'Description is required',
};
render(<FormTextarea label="Description" name="description" error={error} />);
const errorElement = screen.getByRole('alert');
expect(errorElement).toHaveClass('text-sm', 'text-destructive');
});
it('applies correct description styling', () => {
const { container } = render(
<FormTextarea label="Description" name="description" description="Helper text" />
);
const description = container.querySelector('#description-description');
expect(description).toHaveClass('text-sm', 'text-muted-foreground');
});
});
describe('Name Priority', () => {
it('uses explicit name over registration name', () => {
const registerProps = {
name: 'fromRegister',
onChange: jest.fn(),
onBlur: jest.fn(),
ref: jest.fn(),
};
render(<FormTextarea label="Content" name="explicit" registration={registerProps} />);
const textarea = screen.getByRole('textbox');
expect(textarea).toHaveAttribute('id', 'explicit');
});
it('uses registration name when explicit name not provided', () => {
const registerProps = {
name: 'fromRegister',
onChange: jest.fn(),
onBlur: jest.fn(),
ref: jest.fn(),
};
render(<FormTextarea label="Content" registration={registerProps} />);
const textarea = screen.getByRole('textbox');
expect(textarea).toHaveAttribute('id', 'fromRegister');
});
});
});

View File

@@ -0,0 +1,158 @@
/**
* Tests for DynamicIcon Component
* Verifies dynamic icon rendering by name string
*/
import { render, screen } from '@testing-library/react';
import { DynamicIcon, getAvailableIconNames } from '@/components/ui/dynamic-icon';
describe('DynamicIcon', () => {
describe('Basic Rendering', () => {
it('renders an icon by name', () => {
render(<DynamicIcon name="bot" data-testid="icon" />);
const icon = screen.getByTestId('icon');
expect(icon).toBeInTheDocument();
expect(icon.tagName).toBe('svg');
});
it('renders different icons by name', () => {
const { rerender } = render(<DynamicIcon name="code" data-testid="icon" />);
expect(screen.getByTestId('icon')).toHaveClass('lucide-code');
rerender(<DynamicIcon name="brain" data-testid="icon" />);
expect(screen.getByTestId('icon')).toHaveClass('lucide-brain');
rerender(<DynamicIcon name="shield" data-testid="icon" />);
expect(screen.getByTestId('icon')).toHaveClass('lucide-shield');
});
it('renders kebab-case icon names correctly', () => {
render(<DynamicIcon name="clipboard-check" data-testid="icon" />);
expect(screen.getByTestId('icon')).toHaveClass('lucide-clipboard-check');
});
});
describe('Fallback Behavior', () => {
it('renders fallback icon when name is null', () => {
render(<DynamicIcon name={null} data-testid="icon" />);
expect(screen.getByTestId('icon')).toHaveClass('lucide-bot');
});
it('renders fallback icon when name is undefined', () => {
render(<DynamicIcon name={undefined} data-testid="icon" />);
expect(screen.getByTestId('icon')).toHaveClass('lucide-bot');
});
it('renders fallback icon when name is not found', () => {
render(<DynamicIcon name="nonexistent-icon" data-testid="icon" />);
expect(screen.getByTestId('icon')).toHaveClass('lucide-bot');
});
it('uses custom fallback when specified', () => {
render(<DynamicIcon name={null} fallback="code" data-testid="icon" />);
expect(screen.getByTestId('icon')).toHaveClass('lucide-code');
});
it('falls back to bot when custom fallback is also invalid', () => {
render(<DynamicIcon name="invalid" fallback="also-invalid" data-testid="icon" />);
expect(screen.getByTestId('icon')).toHaveClass('lucide-bot');
});
});
describe('Props Forwarding', () => {
it('forwards className to icon', () => {
render(<DynamicIcon name="bot" className="h-5 w-5 text-primary" data-testid="icon" />);
const icon = screen.getByTestId('icon');
expect(icon).toHaveClass('h-5');
expect(icon).toHaveClass('w-5');
expect(icon).toHaveClass('text-primary');
});
it('forwards style to icon', () => {
render(<DynamicIcon name="bot" style={{ color: 'red' }} data-testid="icon" />);
const icon = screen.getByTestId('icon');
expect(icon).toHaveStyle({ color: 'rgb(255, 0, 0)' });
});
it('forwards aria-hidden to icon', () => {
render(<DynamicIcon name="bot" aria-hidden="true" data-testid="icon" />);
const icon = screen.getByTestId('icon');
expect(icon).toHaveAttribute('aria-hidden', 'true');
});
});
describe('Available Icons', () => {
it('includes development icons', () => {
const icons = getAvailableIconNames();
expect(icons).toContain('clipboard-check');
expect(icons).toContain('briefcase');
expect(icons).toContain('code');
expect(icons).toContain('server');
});
it('includes design icons', () => {
const icons = getAvailableIconNames();
expect(icons).toContain('palette');
expect(icons).toContain('search');
});
it('includes quality icons', () => {
const icons = getAvailableIconNames();
expect(icons).toContain('shield');
expect(icons).toContain('shield-check');
});
it('includes ai_ml icons', () => {
const icons = getAvailableIconNames();
expect(icons).toContain('brain');
expect(icons).toContain('microscope');
expect(icons).toContain('eye');
});
it('includes data icons', () => {
const icons = getAvailableIconNames();
expect(icons).toContain('bar-chart');
expect(icons).toContain('database');
});
it('includes domain expert icons', () => {
const icons = getAvailableIconNames();
expect(icons).toContain('calculator');
expect(icons).toContain('heart-pulse');
expect(icons).toContain('flask-conical');
expect(icons).toContain('lightbulb');
expect(icons).toContain('book-open');
});
it('includes generic icons', () => {
const icons = getAvailableIconNames();
expect(icons).toContain('bot');
expect(icons).toContain('cpu');
});
});
describe('Icon Categories Coverage', () => {
const iconTestCases = [
// Development
{ name: 'clipboard-check', expectedClass: 'lucide-clipboard-check' },
{ name: 'briefcase', expectedClass: 'lucide-briefcase' },
{ name: 'file-text', expectedClass: 'lucide-file-text' },
{ name: 'git-branch', expectedClass: 'lucide-git-branch' },
{ name: 'layout', expectedClass: 'lucide-panels-top-left' },
{ name: 'smartphone', expectedClass: 'lucide-smartphone' },
// Operations
{ name: 'settings', expectedClass: 'lucide-settings' },
{ name: 'settings-2', expectedClass: 'lucide-settings-2' },
// AI/ML
{ name: 'message-square', expectedClass: 'lucide-message-square' },
// Leadership
{ name: 'users', expectedClass: 'lucide-users' },
{ name: 'target', expectedClass: 'lucide-target' },
];
it.each(iconTestCases)('renders $name icon correctly', ({ name, expectedClass }) => {
render(<DynamicIcon name={name} data-testid="icon" />);
expect(screen.getByTestId('icon')).toHaveClass(expectedClass);
});
});
});

View File

@@ -0,0 +1,212 @@
/**
* Tests for useValidationErrorHandler hook
*/
import { renderHook } from '@testing-library/react';
import { toast } from 'sonner';
import { useValidationErrorHandler } from '@/lib/forms/hooks/useValidationErrorHandler';
import type { FieldErrors } from 'react-hook-form';
// Mock sonner toast
jest.mock('sonner', () => ({
toast: {
error: jest.fn(),
},
}));
// Mock console.error to track debug logging
const originalConsoleError = console.error;
let consoleErrorMock: jest.SpyInstance;
beforeEach(() => {
jest.clearAllMocks();
consoleErrorMock = jest.spyOn(console, 'error').mockImplementation(() => {});
});
afterEach(() => {
consoleErrorMock.mockRestore();
console.error = originalConsoleError;
});
describe('useValidationErrorHandler', () => {
describe('basic functionality', () => {
it('shows toast with first error message', () => {
const { result } = renderHook(() => useValidationErrorHandler({ debug: false }));
const errors: FieldErrors = {
name: { message: 'Name is required', type: 'required' },
};
result.current.onValidationError(errors);
expect(toast.error).toHaveBeenCalledWith('Please fix form errors', {
description: 'name: Name is required',
});
});
it('uses custom toast title when provided', () => {
const { result } = renderHook(() =>
useValidationErrorHandler({
toastTitle: 'Validation Failed',
debug: false,
})
);
const errors: FieldErrors = {
email: { message: 'Invalid email', type: 'pattern' },
};
result.current.onValidationError(errors);
expect(toast.error).toHaveBeenCalledWith('Validation Failed', {
description: 'email: Invalid email',
});
});
it('does nothing when no errors', () => {
const { result } = renderHook(() => useValidationErrorHandler({ debug: false }));
result.current.onValidationError({});
expect(toast.error).not.toHaveBeenCalled();
});
});
describe('nested errors', () => {
it('handles nested field errors', () => {
const { result } = renderHook(() => useValidationErrorHandler({ debug: false }));
const errors: FieldErrors = {
model_params: {
temperature: { message: 'Temperature must be between 0 and 2', type: 'max' },
},
};
result.current.onValidationError(errors);
expect(toast.error).toHaveBeenCalledWith('Please fix form errors', {
description: 'model_params.temperature: Temperature must be between 0 and 2',
});
});
});
describe('tab navigation', () => {
it('navigates to correct tab when mapping provided', () => {
const setActiveTab = jest.fn();
const tabMapping = {
name: 'basic',
model_params: 'model',
};
const { result } = renderHook(() =>
useValidationErrorHandler({
tabMapping,
setActiveTab,
debug: false,
})
);
const errors: FieldErrors = {
model_params: {
temperature: { message: 'Invalid', type: 'type' },
},
};
result.current.onValidationError(errors);
expect(setActiveTab).toHaveBeenCalledWith('model');
});
it('does not navigate if field not in mapping', () => {
const setActiveTab = jest.fn();
const tabMapping = {
name: 'basic',
};
const { result } = renderHook(() =>
useValidationErrorHandler({
tabMapping,
setActiveTab,
debug: false,
})
);
const errors: FieldErrors = {
unknown_field: { message: 'Error', type: 'validation' },
};
result.current.onValidationError(errors);
expect(setActiveTab).not.toHaveBeenCalled();
});
it('does not crash when setActiveTab not provided', () => {
const tabMapping = { name: 'basic' };
const { result } = renderHook(() =>
useValidationErrorHandler({
tabMapping,
// setActiveTab not provided
debug: false,
})
);
const errors: FieldErrors = {
name: { message: 'Required', type: 'required' },
};
expect(() => result.current.onValidationError(errors)).not.toThrow();
});
});
describe('debug logging', () => {
it('logs errors when debug is true', () => {
const { result } = renderHook(() => useValidationErrorHandler({ debug: true }));
const errors: FieldErrors = {
name: { message: 'Required', type: 'required' },
};
result.current.onValidationError(errors);
expect(consoleErrorMock).toHaveBeenCalledWith('[Form Validation] Errors:', errors);
});
it('does not log errors when debug is false', () => {
const { result } = renderHook(() => useValidationErrorHandler({ debug: false }));
const errors: FieldErrors = {
name: { message: 'Required', type: 'required' },
};
result.current.onValidationError(errors);
expect(consoleErrorMock).not.toHaveBeenCalled();
});
});
describe('memoization', () => {
it('returns stable callback reference', () => {
const { result, rerender } = renderHook(() => useValidationErrorHandler({ debug: false }));
const firstCallback = result.current.onValidationError;
rerender();
const secondCallback = result.current.onValidationError;
expect(firstCallback).toBe(secondCallback);
});
it('returns new callback when options change', () => {
const { result, rerender } = renderHook(
({ title }) => useValidationErrorHandler({ toastTitle: title, debug: false }),
{ initialProps: { title: 'Error A' } }
);
const firstCallback = result.current.onValidationError;
rerender({ title: 'Error B' });
const secondCallback = result.current.onValidationError;
expect(firstCallback).not.toBe(secondCallback);
});
});
});

View File

@@ -0,0 +1,134 @@
/**
* Tests for getFirstValidationError utility
*/
import {
getFirstValidationError,
getAllValidationErrors,
} from '@/lib/forms/utils/getFirstValidationError';
import type { FieldErrors } from 'react-hook-form';
describe('getFirstValidationError', () => {
it('returns null for empty errors object', () => {
const result = getFirstValidationError({});
expect(result).toBeNull();
});
it('extracts direct error message', () => {
const errors: FieldErrors = {
name: { message: 'Name is required', type: 'required' },
};
const result = getFirstValidationError(errors);
expect(result).toEqual({ field: 'name', message: 'Name is required' });
});
it('extracts nested error message', () => {
const errors: FieldErrors = {
model_params: {
temperature: { message: 'Temperature must be a number', type: 'type' },
},
};
const result = getFirstValidationError(errors);
expect(result).toEqual({
field: 'model_params.temperature',
message: 'Temperature must be a number',
});
});
it('returns first error when multiple fields have errors', () => {
const errors: FieldErrors = {
name: { message: 'Name is required', type: 'required' },
slug: { message: 'Slug is required', type: 'required' },
};
const result = getFirstValidationError(errors);
// Object.keys order is insertion order, so 'name' comes first
expect(result?.field).toBe('name');
expect(result?.message).toBe('Name is required');
});
it('handles deeply nested errors', () => {
const errors: FieldErrors = {
config: {
nested: {
deep: { message: 'Deep error', type: 'validation' },
},
},
};
const result = getFirstValidationError(errors);
expect(result).toEqual({ field: 'config.nested.deep', message: 'Deep error' });
});
it('skips null error entries', () => {
const errors: FieldErrors = {
name: null as unknown as undefined,
slug: { message: 'Slug is required', type: 'required' },
};
const result = getFirstValidationError(errors);
expect(result).toEqual({ field: 'slug', message: 'Slug is required' });
});
it('handles error object with ref but no message', () => {
// react-hook-form errors may have 'ref' property but no 'message'
// We cast to FieldErrors to simulate edge cases
const errors = {
name: { type: 'required', ref: { current: null } },
slug: { message: 'Slug is required', type: 'required' },
} as unknown as FieldErrors;
const result = getFirstValidationError(errors);
// Should skip name (no message) and find slug
expect(result).toEqual({ field: 'slug', message: 'Slug is required' });
});
});
describe('getAllValidationErrors', () => {
it('returns empty array for empty errors object', () => {
const result = getAllValidationErrors({});
expect(result).toEqual([]);
});
it('returns all errors as flat array', () => {
const errors: FieldErrors = {
name: { message: 'Name is required', type: 'required' },
slug: { message: 'Slug is required', type: 'required' },
};
const result = getAllValidationErrors(errors);
expect(result).toHaveLength(2);
expect(result).toContainEqual({ field: 'name', message: 'Name is required' });
expect(result).toContainEqual({ field: 'slug', message: 'Slug is required' });
});
it('flattens nested errors', () => {
const errors: FieldErrors = {
model_params: {
temperature: { message: 'Invalid temperature', type: 'type' },
max_tokens: { message: 'Invalid max tokens', type: 'type' },
},
};
const result = getAllValidationErrors(errors);
expect(result).toHaveLength(2);
expect(result).toContainEqual({
field: 'model_params.temperature',
message: 'Invalid temperature',
});
expect(result).toContainEqual({
field: 'model_params.max_tokens',
message: 'Invalid max tokens',
});
});
it('combines direct and nested errors', () => {
const errors: FieldErrors = {
name: { message: 'Name is required', type: 'required' },
model_params: {
temperature: { message: 'Invalid temperature', type: 'type' },
},
};
const result = getAllValidationErrors(errors);
expect(result).toHaveLength(2);
expect(result).toContainEqual({ field: 'name', message: 'Name is required' });
expect(result).toContainEqual({
field: 'model_params.temperature',
message: 'Invalid temperature',
});
});
});

View File

@@ -0,0 +1,256 @@
/**
* Tests for mergeWithDefaults utilities
*/
import {
safeValue,
isNumber,
isString,
isBoolean,
isArray,
isObject,
deepMergeWithDefaults,
createFormInitializer,
} from '@/lib/forms/utils/mergeWithDefaults';
describe('Type Guards', () => {
describe('isNumber', () => {
it('returns true for valid numbers', () => {
expect(isNumber(0)).toBe(true);
expect(isNumber(42)).toBe(true);
expect(isNumber(-10)).toBe(true);
expect(isNumber(3.14)).toBe(true);
});
it('returns false for NaN', () => {
expect(isNumber(NaN)).toBe(false);
});
it('returns false for non-numbers', () => {
expect(isNumber('42')).toBe(false);
expect(isNumber(null)).toBe(false);
expect(isNumber(undefined)).toBe(false);
expect(isNumber({})).toBe(false);
});
});
describe('isString', () => {
it('returns true for strings', () => {
expect(isString('')).toBe(true);
expect(isString('hello')).toBe(true);
});
it('returns false for non-strings', () => {
expect(isString(42)).toBe(false);
expect(isString(null)).toBe(false);
expect(isString(undefined)).toBe(false);
});
});
describe('isBoolean', () => {
it('returns true for booleans', () => {
expect(isBoolean(true)).toBe(true);
expect(isBoolean(false)).toBe(true);
});
it('returns false for non-booleans', () => {
expect(isBoolean(0)).toBe(false);
expect(isBoolean(1)).toBe(false);
expect(isBoolean('true')).toBe(false);
});
});
describe('isArray', () => {
it('returns true for arrays', () => {
expect(isArray([])).toBe(true);
expect(isArray([1, 2, 3])).toBe(true);
});
it('returns false for non-arrays', () => {
expect(isArray({})).toBe(false);
expect(isArray('array')).toBe(false);
expect(isArray(null)).toBe(false);
});
it('validates item types when itemCheck provided', () => {
expect(isArray([1, 2, 3], isNumber)).toBe(true);
expect(isArray(['a', 'b'], isString)).toBe(true);
expect(isArray([1, 'two', 3], isNumber)).toBe(false);
});
});
describe('isObject', () => {
it('returns true for plain objects', () => {
expect(isObject({})).toBe(true);
expect(isObject({ key: 'value' })).toBe(true);
});
it('returns false for null', () => {
expect(isObject(null)).toBe(false);
});
it('returns false for arrays', () => {
expect(isObject([])).toBe(false);
});
it('returns false for primitives', () => {
expect(isObject('string')).toBe(false);
expect(isObject(42)).toBe(false);
});
});
});
describe('safeValue', () => {
it('returns value when type check passes', () => {
expect(safeValue(42, 0, isNumber)).toBe(42);
expect(safeValue('hello', '', isString)).toBe('hello');
});
it('returns default when type check fails', () => {
expect(safeValue('not a number', 0, isNumber)).toBe(0);
expect(safeValue(42, '', isString)).toBe('');
});
it('returns default for null/undefined', () => {
expect(safeValue(null, 0, isNumber)).toBe(0);
expect(safeValue(undefined, 'default', isString)).toBe('default');
});
});
describe('deepMergeWithDefaults', () => {
it('returns defaults when source is null', () => {
const defaults = { name: 'default', value: 10 };
expect(deepMergeWithDefaults(defaults, null)).toEqual(defaults);
});
it('returns defaults when source is undefined', () => {
const defaults = { name: 'default', value: 10 };
expect(deepMergeWithDefaults(defaults, undefined)).toEqual(defaults);
});
it('merges source values over defaults', () => {
const defaults = { name: 'default', value: 10 };
const source = { name: 'custom' };
expect(deepMergeWithDefaults(defaults, source)).toEqual({
name: 'custom',
value: 10,
});
});
it('preserves default for missing source keys', () => {
const defaults = { a: 1, b: 2, c: 3 };
const source = { a: 10 };
expect(deepMergeWithDefaults(defaults, source)).toEqual({
a: 10,
b: 2,
c: 3,
});
});
it('recursively merges nested objects', () => {
const defaults = {
config: { temperature: 0.7, max_tokens: 8192 },
};
// Source has partial nested config - deepMerge fills in missing fields
const source = {
config: { temperature: 0.5 },
} as unknown as Partial<typeof defaults>;
expect(deepMergeWithDefaults(defaults, source)).toEqual({
config: { temperature: 0.5, max_tokens: 8192 },
});
});
it('only uses source values if types match', () => {
const defaults = { value: 10, name: 'default' };
const source = { value: 'not a number' as unknown as number };
expect(deepMergeWithDefaults(defaults, source)).toEqual({
value: 10,
name: 'default',
});
});
it('handles arrays - uses source array if types match', () => {
const defaults = { items: ['a', 'b'] };
const source = { items: ['c', 'd', 'e'] };
expect(deepMergeWithDefaults(defaults, source)).toEqual({
items: ['c', 'd', 'e'],
});
});
it('skips undefined source values', () => {
const defaults = { name: 'default' };
const source = { name: undefined };
expect(deepMergeWithDefaults(defaults, source)).toEqual({
name: 'default',
});
});
it('handles null values when defaults are null', () => {
const defaults = { value: null };
const source = { value: null };
expect(deepMergeWithDefaults(defaults, source)).toEqual({ value: null });
});
it('uses source value when default is null but source has value', () => {
const defaults = { description: null as string | null, name: '' };
const source = { description: 'A real description', name: 'Test' };
expect(deepMergeWithDefaults(defaults, source)).toEqual({
description: 'A real description',
name: 'Test',
});
});
});
describe('createFormInitializer', () => {
it('returns defaults when called with null', () => {
const defaults = { name: '', age: 0 };
const initializer = createFormInitializer(defaults);
expect(initializer(null)).toEqual(defaults);
});
it('returns defaults when called with undefined', () => {
const defaults = { name: '', age: 0 };
const initializer = createFormInitializer(defaults);
expect(initializer(undefined)).toEqual(defaults);
});
it('merges API data with defaults', () => {
const defaults = { name: '', age: 0, active: false };
const initializer = createFormInitializer(defaults);
const result = initializer({ name: 'John', age: 25 });
expect(result).toEqual({ name: 'John', age: 25, active: false });
});
it('uses custom transform function when provided', () => {
interface Form {
fullName: string;
isActive: boolean;
}
interface Api {
first_name: string;
last_name: string;
active: boolean;
}
const defaults: Form = { fullName: '', isActive: false };
const initializer = createFormInitializer<Form, Api>(defaults, (apiData, defs) => ({
fullName: apiData ? `${apiData.first_name} ${apiData.last_name}` : defs.fullName,
isActive: apiData?.active ?? defs.isActive,
}));
const result = initializer({
first_name: 'John',
last_name: 'Doe',
active: true,
});
expect(result).toEqual({ fullName: 'John Doe', isActive: true });
});
it('transform receives defaults when apiData is null', () => {
const defaults = { name: 'default' };
const initializer = createFormInitializer(defaults, (_, defs) => ({
name: defs.name.toUpperCase(),
}));
expect(initializer(null)).toEqual({ name: 'DEFAULT' });
});
});

View File

@@ -27,6 +27,9 @@ jest.mock('@/config/app.config', () => ({
debug: {
api: false,
},
demo: {
enabled: false,
},
},
}));
@@ -649,6 +652,9 @@ describe('useProjectEvents', () => {
debug: {
api: true,
},
demo: {
enabled: false,
},
},
}));

View File

@@ -0,0 +1,67 @@
# Git Operations MCP Server Dockerfile
# Multi-stage build for smaller production image
FROM python:3.12-slim AS builder
# Install build dependencies
RUN apt-get update && apt-get install -y --no-install-recommends \
build-essential \
git \
&& rm -rf /var/lib/apt/lists/*
# Install uv for fast package management
RUN pip install --no-cache-dir uv
# Create app directory
WORKDIR /app
# Copy dependency files
COPY pyproject.toml .
# Install dependencies with uv
RUN uv pip install --system --no-cache .
# Production stage
FROM python:3.12-slim
# Install runtime dependencies
RUN apt-get update && apt-get install -y --no-install-recommends \
git \
openssh-client \
&& rm -rf /var/lib/apt/lists/*
# Create non-root user
RUN useradd --create-home --shell /bin/bash syndarix
# Create workspace directory
RUN mkdir -p /var/syndarix/workspaces && chown -R syndarix:syndarix /var/syndarix
# Create app directory
WORKDIR /app
# Copy installed packages from builder
COPY --from=builder /usr/local/lib/python3.12/site-packages /usr/local/lib/python3.12/site-packages
COPY --from=builder /usr/local/bin /usr/local/bin
# Copy application code
COPY --chown=syndarix:syndarix . .
# Set Python path
ENV PYTHONPATH=/app
ENV PYTHONUNBUFFERED=1
# Configure git for the container
RUN git config --global --add safe.directory '*'
# Switch to non-root user
USER syndarix
# Expose port
EXPOSE 8003
# Health check
HEALTHCHECK --interval=30s --timeout=10s --start-period=5s --retries=3 \
CMD python -c "import httpx; httpx.get('http://localhost:8003/health').raise_for_status()" || exit 1
# Run the server
CMD ["python", "server.py"]

View File

@@ -0,0 +1,88 @@
.PHONY: help install install-dev lint lint-fix format format-check type-check test test-cov validate clean run
# Ensure commands in this project don't inherit an external Python virtualenv
# (prevents uv warnings about mismatched VIRTUAL_ENV when running from repo root)
unexport VIRTUAL_ENV
# Default target
help:
@echo "Git Operations MCP Server - Development Commands"
@echo ""
@echo "Setup:"
@echo " make install - Install production dependencies"
@echo " make install-dev - Install development dependencies"
@echo ""
@echo "Quality Checks:"
@echo " make lint - Run Ruff linter"
@echo " make lint-fix - Run Ruff linter with auto-fix"
@echo " make format - Format code with Ruff"
@echo " make format-check - Check if code is formatted"
@echo " make type-check - Run mypy type checker"
@echo ""
@echo "Testing:"
@echo " make test - Run pytest"
@echo " make test-cov - Run pytest with coverage"
@echo ""
@echo "All-in-one:"
@echo " make validate - Run all checks (lint + format + types)"
@echo ""
@echo "Running:"
@echo " make run - Run the server locally"
@echo ""
@echo "Cleanup:"
@echo " make clean - Remove cache and build artifacts"
# Setup
install:
@echo "Installing production dependencies..."
@uv pip install -e .
install-dev:
@echo "Installing development dependencies..."
@uv pip install -e ".[dev]"
# Quality checks
lint:
@echo "Running Ruff linter..."
@uv run ruff check .
lint-fix:
@echo "Running Ruff linter with auto-fix..."
@uv run ruff check --fix .
format:
@echo "Formatting code..."
@uv run ruff format .
format-check:
@echo "Checking code formatting..."
@uv run ruff format --check .
type-check:
@echo "Running mypy..."
@uv run python -m mypy server.py config.py models.py exceptions.py git_wrapper.py workspace.py providers/ --explicit-package-bases
# Testing
test:
@echo "Running tests..."
@IS_TEST=True uv run pytest tests/ -v
test-cov:
@echo "Running tests with coverage..."
@IS_TEST=True uv run pytest tests/ -v --cov=. --cov-report=term-missing --cov-report=html
# All-in-one validation
validate: lint format-check type-check
@echo "All validations passed!"
# Running
run:
@echo "Starting Git Operations server..."
@uv run python server.py
# Cleanup
clean:
@echo "Cleaning up..."
@rm -rf __pycache__ .pytest_cache .mypy_cache .ruff_cache .coverage htmlcov
@find . -type d -name "__pycache__" -exec rm -rf {} + 2>/dev/null || true
@find . -type f -name "*.pyc" -delete 2>/dev/null || true

View File

@@ -0,0 +1,179 @@
"""
Git Operations MCP Server.
Provides git repository management, branching, commits, and PR workflows
for Syndarix AI agents.
"""
__version__ = "0.1.0"
from config import Settings, get_settings, is_test_mode, reset_settings
from exceptions import (
APIError,
AuthenticationError,
BranchExistsError,
BranchNotFoundError,
CheckoutError,
CloneError,
CommitError,
CredentialError,
CredentialNotFoundError,
DirtyWorkspaceError,
ErrorCode,
GitError,
GitOpsError,
InvalidRefError,
MergeConflictError,
PRError,
PRNotFoundError,
ProviderError,
ProviderNotFoundError,
PullError,
PushError,
WorkspaceError,
WorkspaceLockedError,
WorkspaceNotFoundError,
WorkspaceSizeExceededError,
)
from models import (
BranchInfo,
BranchRequest,
BranchResult,
CheckoutRequest,
CheckoutResult,
CloneRequest,
CloneResult,
CommitInfo,
CommitRequest,
CommitResult,
CreatePRRequest,
CreatePRResult,
DiffHunk,
DiffRequest,
DiffResult,
FileChange,
FileChangeType,
FileDiff,
GetPRRequest,
GetPRResult,
GetWorkspaceRequest,
GetWorkspaceResult,
HealthStatus,
ListBranchesRequest,
ListBranchesResult,
ListPRsRequest,
ListPRsResult,
LockWorkspaceRequest,
LockWorkspaceResult,
LogRequest,
LogResult,
MergePRRequest,
MergePRResult,
MergeStrategy,
PRInfo,
ProviderStatus,
ProviderType,
PRState,
PullRequest,
PullResult,
PushRequest,
PushResult,
StatusRequest,
StatusResult,
UnlockWorkspaceRequest,
UnlockWorkspaceResult,
UpdatePRRequest,
UpdatePRResult,
WorkspaceInfo,
WorkspaceState,
)
__all__ = [
# Version
"__version__",
# Config
"Settings",
"get_settings",
"reset_settings",
"is_test_mode",
# Error codes
"ErrorCode",
# Exceptions
"GitOpsError",
"WorkspaceError",
"WorkspaceNotFoundError",
"WorkspaceLockedError",
"WorkspaceSizeExceededError",
"GitError",
"CloneError",
"CheckoutError",
"CommitError",
"PushError",
"PullError",
"MergeConflictError",
"BranchExistsError",
"BranchNotFoundError",
"InvalidRefError",
"DirtyWorkspaceError",
"ProviderError",
"AuthenticationError",
"ProviderNotFoundError",
"PRError",
"PRNotFoundError",
"APIError",
"CredentialError",
"CredentialNotFoundError",
# Enums
"FileChangeType",
"MergeStrategy",
"PRState",
"ProviderType",
"WorkspaceState",
# Dataclasses
"FileChange",
"BranchInfo",
"CommitInfo",
"DiffHunk",
"FileDiff",
"PRInfo",
"WorkspaceInfo",
# Request/Response models
"CloneRequest",
"CloneResult",
"StatusRequest",
"StatusResult",
"BranchRequest",
"BranchResult",
"ListBranchesRequest",
"ListBranchesResult",
"CheckoutRequest",
"CheckoutResult",
"CommitRequest",
"CommitResult",
"PushRequest",
"PushResult",
"PullRequest",
"PullResult",
"DiffRequest",
"DiffResult",
"LogRequest",
"LogResult",
"CreatePRRequest",
"CreatePRResult",
"GetPRRequest",
"GetPRResult",
"ListPRsRequest",
"ListPRsResult",
"MergePRRequest",
"MergePRResult",
"UpdatePRRequest",
"UpdatePRResult",
"GetWorkspaceRequest",
"GetWorkspaceResult",
"LockWorkspaceRequest",
"LockWorkspaceResult",
"UnlockWorkspaceRequest",
"UnlockWorkspaceResult",
"HealthStatus",
"ProviderStatus",
]

View File

@@ -0,0 +1,155 @@
"""
Configuration for Git Operations MCP Server.
Uses pydantic-settings for environment variable loading.
"""
import os
from pathlib import Path
from pydantic import Field
from pydantic_settings import BaseSettings
class Settings(BaseSettings):
"""Application settings loaded from environment."""
# Server settings
host: str = Field(default="0.0.0.0", description="Server host")
port: int = Field(default=8003, description="Server port")
debug: bool = Field(default=False, description="Debug mode")
# Workspace settings
workspace_base_path: Path = Field(
default=Path("/var/syndarix/workspaces"),
description="Base path for git workspaces",
)
workspace_max_size_gb: float = Field(
default=10.0,
description="Maximum size per workspace in GB",
)
workspace_stale_days: int = Field(
default=7,
description="Days after which unused workspace is considered stale",
)
workspace_lock_timeout: int = Field(
default=300,
description="Workspace lock timeout in seconds",
)
# Git settings
git_timeout: int = Field(
default=120,
description="Default timeout for git operations in seconds",
)
git_clone_timeout: int = Field(
default=600,
description="Timeout for clone operations in seconds",
)
git_author_name: str = Field(
default="Syndarix Agent",
description="Default author name for commits",
)
git_author_email: str = Field(
default="agent@syndarix.ai",
description="Default author email for commits",
)
git_max_diff_lines: int = Field(
default=10000,
description="Maximum lines in diff output",
)
# Redis settings (for distributed locking)
redis_url: str = Field(
default="redis://localhost:6379/0",
description="Redis connection URL",
)
# Provider settings
gitea_base_url: str = Field(
default="",
description="Gitea API base URL (e.g., https://gitea.example.com)",
)
gitea_token: str = Field(
default="",
description="Gitea API token",
)
github_token: str = Field(
default="",
description="GitHub API token",
)
github_api_url: str = Field(
default="https://api.github.com",
description="GitHub API URL (for Enterprise)",
)
gitlab_token: str = Field(
default="",
description="GitLab API token",
)
gitlab_url: str = Field(
default="https://gitlab.com",
description="GitLab URL (for self-hosted)",
)
# Rate limiting
rate_limit_requests: int = Field(
default=100,
description="Max API requests per minute per provider",
)
rate_limit_window: int = Field(
default=60,
description="Rate limit window in seconds",
)
# Retry settings
retry_attempts: int = Field(
default=3,
description="Number of retry attempts for failed operations",
)
retry_delay: float = Field(
default=1.0,
description="Initial retry delay in seconds",
)
retry_max_delay: float = Field(
default=30.0,
description="Maximum retry delay in seconds",
)
# Security settings
allowed_hosts: list[str] = Field(
default_factory=list,
description="Allowed git host domains (empty = all)",
)
max_clone_size_mb: int = Field(
default=500,
description="Maximum repository size for clone in MB",
)
enable_force_push: bool = Field(
default=False,
description="Allow force push operations",
)
model_config = {"env_prefix": "GIT_OPS_", "env_file": ".env", "extra": "ignore"}
# Global settings instance (lazy initialization)
_settings: Settings | None = None
def get_settings() -> Settings:
"""Get the global settings instance."""
global _settings
if _settings is None:
_settings = Settings()
return _settings
def reset_settings() -> None:
"""Reset the global settings (for testing)."""
global _settings
_settings = None
def is_test_mode() -> bool:
"""Check if running in test mode."""
return os.getenv("IS_TEST", "").lower() in ("true", "1", "yes")

View File

@@ -0,0 +1,359 @@
"""
Exception hierarchy for Git Operations MCP Server.
Provides structured error handling with error codes for MCP responses.
"""
from enum import Enum
from typing import Any
class ErrorCode(str, Enum):
"""Error codes for Git Operations errors."""
# General errors (1xxx)
INTERNAL_ERROR = "GIT_1000"
INVALID_REQUEST = "GIT_1001"
NOT_FOUND = "GIT_1002"
PERMISSION_DENIED = "GIT_1003"
TIMEOUT = "GIT_1004"
RATE_LIMITED = "GIT_1005"
# Workspace errors (2xxx)
WORKSPACE_NOT_FOUND = "GIT_2000"
WORKSPACE_LOCKED = "GIT_2001"
WORKSPACE_SIZE_EXCEEDED = "GIT_2002"
WORKSPACE_CREATE_FAILED = "GIT_2003"
WORKSPACE_DELETE_FAILED = "GIT_2004"
# Git operation errors (3xxx)
CLONE_FAILED = "GIT_3000"
CHECKOUT_FAILED = "GIT_3001"
COMMIT_FAILED = "GIT_3002"
PUSH_FAILED = "GIT_3003"
PULL_FAILED = "GIT_3004"
MERGE_CONFLICT = "GIT_3005"
BRANCH_EXISTS = "GIT_3006"
BRANCH_NOT_FOUND = "GIT_3007"
INVALID_REF = "GIT_3008"
DIRTY_WORKSPACE = "GIT_3009"
UNCOMMITTED_CHANGES = "GIT_3010"
FETCH_FAILED = "GIT_3011"
RESET_FAILED = "GIT_3012"
# Provider errors (4xxx)
PROVIDER_ERROR = "GIT_4000"
PROVIDER_AUTH_FAILED = "GIT_4001"
PROVIDER_NOT_FOUND = "GIT_4002"
PR_CREATE_FAILED = "GIT_4003"
PR_MERGE_FAILED = "GIT_4004"
PR_NOT_FOUND = "GIT_4005"
API_ERROR = "GIT_4006"
# Credential errors (5xxx)
CREDENTIAL_ERROR = "GIT_5000"
CREDENTIAL_NOT_FOUND = "GIT_5001"
CREDENTIAL_INVALID = "GIT_5002"
SSH_KEY_ERROR = "GIT_5003"
class GitOpsError(Exception):
"""Base exception for Git Operations errors."""
def __init__(
self,
message: str,
code: ErrorCode = ErrorCode.INTERNAL_ERROR,
details: dict[str, Any] | None = None,
) -> None:
super().__init__(message)
self.message = message
self.code = code
self.details = details or {}
def to_dict(self) -> dict[str, Any]:
"""Convert to dictionary for MCP response."""
result: dict[str, Any] = {
"error": self.message,
"code": self.code.value,
}
if self.details:
result["details"] = self.details
return result
# Workspace Errors
class WorkspaceError(GitOpsError):
"""Base exception for workspace-related errors."""
def __init__(
self,
message: str,
code: ErrorCode = ErrorCode.WORKSPACE_NOT_FOUND,
details: dict[str, Any] | None = None,
) -> None:
super().__init__(message, code, details)
class WorkspaceNotFoundError(WorkspaceError):
"""Workspace does not exist."""
def __init__(self, project_id: str) -> None:
super().__init__(
f"Workspace not found for project: {project_id}",
ErrorCode.WORKSPACE_NOT_FOUND,
{"project_id": project_id},
)
class WorkspaceLockedError(WorkspaceError):
"""Workspace is locked by another operation."""
def __init__(self, project_id: str, holder: str | None = None) -> None:
details: dict[str, Any] = {"project_id": project_id}
if holder:
details["locked_by"] = holder
super().__init__(
f"Workspace is locked for project: {project_id}",
ErrorCode.WORKSPACE_LOCKED,
details,
)
class WorkspaceSizeExceededError(WorkspaceError):
"""Workspace size limit exceeded."""
def __init__(self, project_id: str, current_size: float, max_size: float) -> None:
super().__init__(
f"Workspace size limit exceeded for project: {project_id}",
ErrorCode.WORKSPACE_SIZE_EXCEEDED,
{
"project_id": project_id,
"current_size_gb": current_size,
"max_size_gb": max_size,
},
)
# Git Operation Errors
class GitError(GitOpsError):
"""Base exception for git operation errors."""
def __init__(
self,
message: str,
code: ErrorCode = ErrorCode.INTERNAL_ERROR,
details: dict[str, Any] | None = None,
) -> None:
super().__init__(message, code, details)
class CloneError(GitError):
"""Failed to clone repository."""
def __init__(self, repo_url: str, reason: str) -> None:
super().__init__(
f"Failed to clone repository: {reason}",
ErrorCode.CLONE_FAILED,
{"repo_url": repo_url, "reason": reason},
)
class CheckoutError(GitError):
"""Failed to checkout branch or ref."""
def __init__(self, ref: str, reason: str) -> None:
super().__init__(
f"Failed to checkout '{ref}': {reason}",
ErrorCode.CHECKOUT_FAILED,
{"ref": ref, "reason": reason},
)
class CommitError(GitError):
"""Failed to commit changes."""
def __init__(self, reason: str) -> None:
super().__init__(
f"Failed to commit: {reason}",
ErrorCode.COMMIT_FAILED,
{"reason": reason},
)
class PushError(GitError):
"""Failed to push to remote."""
def __init__(self, branch: str, reason: str) -> None:
super().__init__(
f"Failed to push branch '{branch}': {reason}",
ErrorCode.PUSH_FAILED,
{"branch": branch, "reason": reason},
)
class PullError(GitError):
"""Failed to pull from remote."""
def __init__(self, branch: str, reason: str) -> None:
super().__init__(
f"Failed to pull branch '{branch}': {reason}",
ErrorCode.PULL_FAILED,
{"branch": branch, "reason": reason},
)
class MergeConflictError(GitError):
"""Merge conflict detected."""
def __init__(self, conflicting_files: list[str]) -> None:
super().__init__(
f"Merge conflict detected in {len(conflicting_files)} files",
ErrorCode.MERGE_CONFLICT,
{"conflicting_files": conflicting_files},
)
class BranchExistsError(GitError):
"""Branch already exists."""
def __init__(self, branch_name: str) -> None:
super().__init__(
f"Branch already exists: {branch_name}",
ErrorCode.BRANCH_EXISTS,
{"branch": branch_name},
)
class BranchNotFoundError(GitError):
"""Branch does not exist."""
def __init__(self, branch_name: str) -> None:
super().__init__(
f"Branch not found: {branch_name}",
ErrorCode.BRANCH_NOT_FOUND,
{"branch": branch_name},
)
class InvalidRefError(GitError):
"""Invalid git reference."""
def __init__(self, ref: str) -> None:
super().__init__(
f"Invalid git reference: {ref}",
ErrorCode.INVALID_REF,
{"ref": ref},
)
class DirtyWorkspaceError(GitError):
"""Workspace has uncommitted changes."""
def __init__(self, modified_files: list[str]) -> None:
super().__init__(
f"Workspace has {len(modified_files)} uncommitted changes",
ErrorCode.DIRTY_WORKSPACE,
{"modified_files": modified_files[:10]}, # Limit to first 10
)
# Provider Errors
class ProviderError(GitOpsError):
"""Base exception for provider-related errors."""
def __init__(
self,
message: str,
code: ErrorCode = ErrorCode.PROVIDER_ERROR,
details: dict[str, Any] | None = None,
) -> None:
super().__init__(message, code, details)
class AuthenticationError(ProviderError):
"""Authentication with provider failed."""
def __init__(self, provider: str, reason: str) -> None:
super().__init__(
f"Authentication failed with {provider}: {reason}",
ErrorCode.PROVIDER_AUTH_FAILED,
{"provider": provider, "reason": reason},
)
class ProviderNotFoundError(ProviderError):
"""Provider not configured or recognized."""
def __init__(self, provider: str) -> None:
super().__init__(
f"Provider not found or not configured: {provider}",
ErrorCode.PROVIDER_NOT_FOUND,
{"provider": provider},
)
class PRError(ProviderError):
"""Pull request operation failed."""
def __init__(
self,
message: str,
code: ErrorCode = ErrorCode.PR_CREATE_FAILED,
details: dict[str, Any] | None = None,
) -> None:
super().__init__(message, code, details)
class PRNotFoundError(PRError):
"""Pull request not found."""
def __init__(self, pr_number: int, repo: str) -> None:
super().__init__(
f"Pull request #{pr_number} not found in {repo}",
ErrorCode.PR_NOT_FOUND,
{"pr_number": pr_number, "repo": repo},
)
class APIError(ProviderError):
"""Provider API error."""
def __init__(self, provider: str, status_code: int, message: str) -> None:
super().__init__(
f"{provider} API error ({status_code}): {message}",
ErrorCode.API_ERROR,
{"provider": provider, "status_code": status_code, "message": message},
)
# Credential Errors
class CredentialError(GitOpsError):
"""Base exception for credential-related errors."""
def __init__(
self,
message: str,
code: ErrorCode = ErrorCode.CREDENTIAL_ERROR,
details: dict[str, Any] | None = None,
) -> None:
super().__init__(message, code, details)
class CredentialNotFoundError(CredentialError):
"""Credential not found."""
def __init__(self, credential_type: str, identifier: str) -> None:
super().__init__(
f"{credential_type} credential not found: {identifier}",
ErrorCode.CREDENTIAL_NOT_FOUND,
{"type": credential_type, "identifier": identifier},
)

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,690 @@
"""
Data models for Git Operations MCP Server.
Defines data structures for git operations, workspace management,
and provider interactions.
"""
from dataclasses import dataclass, field
from datetime import UTC, datetime
from enum import Enum
from typing import Any
from pydantic import BaseModel, Field
class FileChangeType(str, Enum):
"""Types of file changes in git."""
ADDED = "added"
MODIFIED = "modified"
DELETED = "deleted"
RENAMED = "renamed"
COPIED = "copied"
UNTRACKED = "untracked"
IGNORED = "ignored"
class MergeStrategy(str, Enum):
"""Merge strategies for pull requests."""
MERGE = "merge" # Create a merge commit
SQUASH = "squash" # Squash and merge
REBASE = "rebase" # Rebase and merge
class PRState(str, Enum):
"""Pull request states."""
OPEN = "open"
CLOSED = "closed"
MERGED = "merged"
class ProviderType(str, Enum):
"""Supported git providers."""
GITEA = "gitea"
GITHUB = "github"
GITLAB = "gitlab"
class WorkspaceState(str, Enum):
"""Workspace lifecycle states."""
INITIALIZING = "initializing"
READY = "ready"
LOCKED = "locked"
STALE = "stale"
DELETED = "deleted"
# Dataclasses for internal data structures
@dataclass
class FileChange:
"""A file change in git status."""
path: str
change_type: FileChangeType
old_path: str | None = None # For renames
additions: int = 0
deletions: int = 0
def to_dict(self) -> dict[str, Any]:
"""Convert to dictionary."""
return {
"path": self.path,
"change_type": self.change_type.value,
"old_path": self.old_path,
"additions": self.additions,
"deletions": self.deletions,
}
@dataclass
class BranchInfo:
"""Information about a git branch."""
name: str
is_current: bool = False
is_remote: bool = False
tracking_branch: str | None = None
commit_sha: str | None = None
commit_message: str | None = None
ahead: int = 0
behind: int = 0
def to_dict(self) -> dict[str, Any]:
"""Convert to dictionary."""
return {
"name": self.name,
"is_current": self.is_current,
"is_remote": self.is_remote,
"tracking_branch": self.tracking_branch,
"commit_sha": self.commit_sha,
"commit_message": self.commit_message,
"ahead": self.ahead,
"behind": self.behind,
}
@dataclass
class CommitInfo:
"""Information about a git commit."""
sha: str
short_sha: str
message: str
author_name: str
author_email: str
authored_date: datetime
committer_name: str
committer_email: str
committed_date: datetime
parents: list[str] = field(default_factory=list)
def to_dict(self) -> dict[str, Any]:
"""Convert to dictionary."""
return {
"sha": self.sha,
"short_sha": self.short_sha,
"message": self.message,
"author_name": self.author_name,
"author_email": self.author_email,
"authored_date": self.authored_date.isoformat(),
"committer_name": self.committer_name,
"committer_email": self.committer_email,
"committed_date": self.committed_date.isoformat(),
"parents": self.parents,
}
@dataclass
class DiffHunk:
"""A hunk of diff content."""
old_start: int
old_lines: int
new_start: int
new_lines: int
content: str
def to_dict(self) -> dict[str, Any]:
"""Convert to dictionary."""
return {
"old_start": self.old_start,
"old_lines": self.old_lines,
"new_start": self.new_start,
"new_lines": self.new_lines,
"content": self.content,
}
@dataclass
class FileDiff:
"""Diff for a single file."""
path: str
change_type: FileChangeType
old_path: str | None = None
hunks: list[DiffHunk] = field(default_factory=list)
additions: int = 0
deletions: int = 0
is_binary: bool = False
def to_dict(self) -> dict[str, Any]:
"""Convert to dictionary."""
return {
"path": self.path,
"change_type": self.change_type.value,
"old_path": self.old_path,
"hunks": [h.to_dict() for h in self.hunks],
"additions": self.additions,
"deletions": self.deletions,
"is_binary": self.is_binary,
}
@dataclass
class PRInfo:
"""Information about a pull request."""
number: int
title: str
body: str
state: PRState
source_branch: str
target_branch: str
author: str
created_at: datetime
updated_at: datetime
merged_at: datetime | None = None
closed_at: datetime | None = None
url: str | None = None
labels: list[str] = field(default_factory=list)
assignees: list[str] = field(default_factory=list)
reviewers: list[str] = field(default_factory=list)
mergeable: bool | None = None
draft: bool = False
def to_dict(self) -> dict[str, Any]:
"""Convert to dictionary."""
return {
"number": self.number,
"title": self.title,
"body": self.body,
"state": self.state.value,
"source_branch": self.source_branch,
"target_branch": self.target_branch,
"author": self.author,
"created_at": self.created_at.isoformat(),
"updated_at": self.updated_at.isoformat(),
"merged_at": self.merged_at.isoformat() if self.merged_at else None,
"closed_at": self.closed_at.isoformat() if self.closed_at else None,
"url": self.url,
"labels": self.labels,
"assignees": self.assignees,
"reviewers": self.reviewers,
"mergeable": self.mergeable,
"draft": self.draft,
}
@dataclass
class WorkspaceInfo:
"""Information about a project workspace."""
project_id: str
path: str
state: WorkspaceState
repo_url: str | None = None
current_branch: str | None = None
last_accessed: datetime = field(default_factory=lambda: datetime.now(UTC))
size_bytes: int = 0
lock_holder: str | None = None
lock_expires: datetime | None = None
def to_dict(self) -> dict[str, Any]:
"""Convert to dictionary."""
return {
"project_id": self.project_id,
"path": self.path,
"state": self.state.value,
"repo_url": self.repo_url,
"current_branch": self.current_branch,
"last_accessed": self.last_accessed.isoformat(),
"size_bytes": self.size_bytes,
"lock_holder": self.lock_holder,
"lock_expires": self.lock_expires.isoformat()
if self.lock_expires
else None,
}
# Pydantic Request/Response Models
class CloneRequest(BaseModel):
"""Request to clone a repository."""
project_id: str = Field(..., description="Project ID for scoping")
agent_id: str = Field(..., description="Agent ID making the request")
repo_url: str = Field(..., description="Repository URL to clone")
branch: str | None = Field(
default=None, description="Branch to checkout after clone"
)
depth: int | None = Field(
default=None, ge=1, description="Shallow clone depth (None = full clone)"
)
class CloneResult(BaseModel):
"""Result of a clone operation."""
success: bool = Field(..., description="Whether clone succeeded")
project_id: str = Field(..., description="Project ID")
workspace_path: str = Field(..., description="Path to cloned workspace")
branch: str = Field(..., description="Current branch after clone")
commit_sha: str = Field(..., description="HEAD commit SHA")
error: str | None = Field(default=None, description="Error message if failed")
class StatusRequest(BaseModel):
"""Request for git status."""
project_id: str = Field(..., description="Project ID for scoping")
agent_id: str = Field(..., description="Agent ID making the request")
include_untracked: bool = Field(default=True, description="Include untracked files")
class StatusResult(BaseModel):
"""Result of a status operation."""
project_id: str = Field(..., description="Project ID")
branch: str = Field(..., description="Current branch")
commit_sha: str = Field(..., description="HEAD commit SHA")
is_clean: bool = Field(..., description="Whether working tree is clean")
staged: list[dict[str, Any]] = Field(
default_factory=list, description="Staged changes"
)
unstaged: list[dict[str, Any]] = Field(
default_factory=list, description="Unstaged changes"
)
untracked: list[str] = Field(default_factory=list, description="Untracked files")
ahead: int = Field(default=0, description="Commits ahead of upstream")
behind: int = Field(default=0, description="Commits behind upstream")
class BranchRequest(BaseModel):
"""Request for branch operations."""
project_id: str = Field(..., description="Project ID for scoping")
agent_id: str = Field(..., description="Agent ID making the request")
branch_name: str = Field(..., description="Branch name")
from_ref: str | None = Field(
default=None, description="Reference to create branch from"
)
checkout: bool = Field(default=True, description="Checkout after creation")
class BranchResult(BaseModel):
"""Result of a branch operation."""
success: bool = Field(..., description="Whether operation succeeded")
branch: str = Field(..., description="Branch name")
commit_sha: str | None = Field(default=None, description="HEAD commit SHA")
is_current: bool = Field(default=False, description="Whether branch is checked out")
error: str | None = Field(default=None, description="Error message if failed")
class ListBranchesRequest(BaseModel):
"""Request to list branches."""
project_id: str = Field(..., description="Project ID for scoping")
agent_id: str = Field(..., description="Agent ID making the request")
include_remote: bool = Field(default=False, description="Include remote branches")
class ListBranchesResult(BaseModel):
"""Result of listing branches."""
project_id: str = Field(..., description="Project ID")
current_branch: str = Field(..., description="Currently checked out branch")
local_branches: list[dict[str, Any]] = Field(
default_factory=list, description="Local branches"
)
remote_branches: list[dict[str, Any]] = Field(
default_factory=list, description="Remote branches"
)
class CheckoutRequest(BaseModel):
"""Request to checkout a branch or ref."""
project_id: str = Field(..., description="Project ID for scoping")
agent_id: str = Field(..., description="Agent ID making the request")
ref: str = Field(..., description="Branch, tag, or commit to checkout")
create_branch: bool = Field(default=False, description="Create new branch")
force: bool = Field(default=False, description="Force checkout (discard changes)")
class CheckoutResult(BaseModel):
"""Result of a checkout operation."""
success: bool = Field(..., description="Whether checkout succeeded")
ref: str = Field(..., description="Checked out reference")
commit_sha: str | None = Field(default=None, description="HEAD commit SHA")
error: str | None = Field(default=None, description="Error message if failed")
class CommitRequest(BaseModel):
"""Request to create a commit."""
project_id: str = Field(..., description="Project ID for scoping")
agent_id: str = Field(..., description="Agent ID making the request")
message: str = Field(..., description="Commit message")
files: list[str] | None = Field(
default=None, description="Files to commit (None = all staged)"
)
author_name: str | None = Field(default=None, description="Author name override")
author_email: str | None = Field(default=None, description="Author email override")
allow_empty: bool = Field(default=False, description="Allow empty commit")
class CommitResult(BaseModel):
"""Result of a commit operation."""
success: bool = Field(..., description="Whether commit succeeded")
commit_sha: str | None = Field(default=None, description="New commit SHA")
short_sha: str | None = Field(default=None, description="Short commit SHA")
message: str | None = Field(default=None, description="Commit message")
files_changed: int = Field(default=0, description="Number of files changed")
insertions: int = Field(default=0, description="Lines added")
deletions: int = Field(default=0, description="Lines removed")
error: str | None = Field(default=None, description="Error message if failed")
class PushRequest(BaseModel):
"""Request to push to remote."""
project_id: str = Field(..., description="Project ID for scoping")
agent_id: str = Field(..., description="Agent ID making the request")
branch: str | None = Field(
default=None, description="Branch to push (None = current)"
)
remote: str = Field(default="origin", description="Remote name")
force: bool = Field(default=False, description="Force push")
set_upstream: bool = Field(default=True, description="Set upstream tracking")
class PushResult(BaseModel):
"""Result of a push operation."""
success: bool = Field(..., description="Whether push succeeded")
branch: str = Field(..., description="Pushed branch")
remote: str = Field(..., description="Remote name")
commits_pushed: int = Field(default=0, description="Number of commits pushed")
error: str | None = Field(default=None, description="Error message if failed")
class PullRequest(BaseModel):
"""Request to pull from remote."""
project_id: str = Field(..., description="Project ID for scoping")
agent_id: str = Field(..., description="Agent ID making the request")
branch: str | None = Field(
default=None, description="Branch to pull (None = current)"
)
remote: str = Field(default="origin", description="Remote name")
rebase: bool = Field(default=False, description="Rebase instead of merge")
class PullResult(BaseModel):
"""Result of a pull operation."""
success: bool = Field(..., description="Whether pull succeeded")
branch: str = Field(..., description="Pulled branch")
commits_received: int = Field(default=0, description="New commits received")
fast_forward: bool = Field(default=False, description="Was fast-forward")
conflicts: list[str] = Field(
default_factory=list, description="Conflicting files if any"
)
error: str | None = Field(default=None, description="Error message if failed")
class DiffRequest(BaseModel):
"""Request for diff."""
project_id: str = Field(..., description="Project ID for scoping")
agent_id: str = Field(..., description="Agent ID making the request")
base: str | None = Field(
default=None, description="Base reference (None = working tree)"
)
head: str | None = Field(default=None, description="Head reference (None = HEAD)")
files: list[str] | None = Field(default=None, description="Specific files to diff")
context_lines: int = Field(default=3, ge=0, description="Context lines")
class DiffResult(BaseModel):
"""Result of a diff operation."""
project_id: str = Field(..., description="Project ID")
base: str | None = Field(default=None, description="Base reference")
head: str | None = Field(default=None, description="Head reference")
files: list[dict[str, Any]] = Field(default_factory=list, description="File diffs")
total_additions: int = Field(default=0, description="Total lines added")
total_deletions: int = Field(default=0, description="Total lines removed")
files_changed: int = Field(default=0, description="Number of files changed")
class LogRequest(BaseModel):
"""Request for commit log."""
project_id: str = Field(..., description="Project ID for scoping")
agent_id: str = Field(..., description="Agent ID making the request")
ref: str | None = Field(default=None, description="Reference to start from")
limit: int = Field(default=20, ge=1, le=100, description="Max commits to return")
skip: int = Field(default=0, ge=0, description="Commits to skip")
path: str | None = Field(default=None, description="Filter by path")
class LogResult(BaseModel):
"""Result of a log operation."""
project_id: str = Field(..., description="Project ID")
commits: list[dict[str, Any]] = Field(
default_factory=list, description="Commit history"
)
total_commits: int = Field(default=0, description="Total commits in range")
# PR Operations
class CreatePRRequest(BaseModel):
"""Request to create a pull request."""
project_id: str = Field(..., description="Project ID for scoping")
agent_id: str = Field(..., description="Agent ID making the request")
title: str = Field(..., description="PR title")
body: str = Field(default="", description="PR description")
source_branch: str = Field(..., description="Source branch")
target_branch: str = Field(default="main", description="Target branch")
draft: bool = Field(default=False, description="Create as draft")
labels: list[str] = Field(default_factory=list, description="Labels to add")
assignees: list[str] = Field(default_factory=list, description="Assignees")
reviewers: list[str] = Field(default_factory=list, description="Reviewers")
class CreatePRResult(BaseModel):
"""Result of creating a pull request."""
success: bool = Field(..., description="Whether creation succeeded")
pr_number: int | None = Field(default=None, description="PR number")
pr_url: str | None = Field(default=None, description="PR URL")
error: str | None = Field(default=None, description="Error message if failed")
class GetPRRequest(BaseModel):
"""Request to get a pull request."""
project_id: str = Field(..., description="Project ID for scoping")
agent_id: str = Field(..., description="Agent ID making the request")
pr_number: int = Field(..., description="PR number")
class GetPRResult(BaseModel):
"""Result of getting a pull request."""
success: bool = Field(..., description="Whether fetch succeeded")
pr: dict[str, Any] | None = Field(default=None, description="PR info")
error: str | None = Field(default=None, description="Error message if failed")
class ListPRsRequest(BaseModel):
"""Request to list pull requests."""
project_id: str = Field(..., description="Project ID for scoping")
agent_id: str = Field(..., description="Agent ID making the request")
state: PRState | None = Field(default=None, description="Filter by state")
author: str | None = Field(default=None, description="Filter by author")
limit: int = Field(default=20, ge=1, le=100, description="Max PRs to return")
class ListPRsResult(BaseModel):
"""Result of listing pull requests."""
success: bool = Field(..., description="Whether list succeeded")
pull_requests: list[dict[str, Any]] = Field(default_factory=list, description="PRs")
total_count: int = Field(default=0, description="Total matching PRs")
error: str | None = Field(default=None, description="Error message if failed")
class MergePRRequest(BaseModel):
"""Request to merge a pull request."""
project_id: str = Field(..., description="Project ID for scoping")
agent_id: str = Field(..., description="Agent ID making the request")
pr_number: int = Field(..., description="PR number")
merge_strategy: MergeStrategy = Field(
default=MergeStrategy.MERGE, description="Merge strategy"
)
commit_message: str | None = Field(
default=None, description="Custom merge commit message"
)
delete_branch: bool = Field(
default=True, description="Delete source branch after merge"
)
class MergePRResult(BaseModel):
"""Result of merging a pull request."""
success: bool = Field(..., description="Whether merge succeeded")
merge_commit_sha: str | None = Field(default=None, description="Merge commit SHA")
branch_deleted: bool = Field(
default=False, description="Whether branch was deleted"
)
error: str | None = Field(default=None, description="Error message if failed")
class UpdatePRRequest(BaseModel):
"""Request to update a pull request."""
project_id: str = Field(..., description="Project ID for scoping")
agent_id: str = Field(..., description="Agent ID making the request")
pr_number: int = Field(..., description="PR number")
title: str | None = Field(default=None, description="New title")
body: str | None = Field(default=None, description="New description")
state: PRState | None = Field(default=None, description="New state")
labels: list[str] | None = Field(default=None, description="Replace labels")
assignees: list[str] | None = Field(default=None, description="Replace assignees")
class UpdatePRResult(BaseModel):
"""Result of updating a pull request."""
success: bool = Field(..., description="Whether update succeeded")
pr: dict[str, Any] | None = Field(default=None, description="Updated PR info")
error: str | None = Field(default=None, description="Error message if failed")
# Workspace Operations
class GetWorkspaceRequest(BaseModel):
"""Request to get or create workspace."""
project_id: str = Field(..., description="Project ID")
agent_id: str = Field(..., description="Agent ID making the request")
class GetWorkspaceResult(BaseModel):
"""Result of getting workspace."""
success: bool = Field(..., description="Whether operation succeeded")
workspace: dict[str, Any] | None = Field(default=None, description="Workspace info")
error: str | None = Field(default=None, description="Error message if failed")
class LockWorkspaceRequest(BaseModel):
"""Request to lock a workspace."""
project_id: str = Field(..., description="Project ID")
agent_id: str = Field(..., description="Agent ID requesting lock")
timeout: int = Field(
default=300, ge=10, le=3600, description="Lock timeout seconds"
)
class LockWorkspaceResult(BaseModel):
"""Result of locking workspace."""
success: bool = Field(..., description="Whether lock acquired")
lock_holder: str | None = Field(default=None, description="Current lock holder")
lock_expires: str | None = Field(
default=None, description="Lock expiry ISO timestamp"
)
error: str | None = Field(default=None, description="Error message if failed")
class UnlockWorkspaceRequest(BaseModel):
"""Request to unlock a workspace."""
project_id: str = Field(..., description="Project ID")
agent_id: str = Field(..., description="Agent ID releasing lock")
force: bool = Field(default=False, description="Force unlock (admin only)")
class UnlockWorkspaceResult(BaseModel):
"""Result of unlocking workspace."""
success: bool = Field(..., description="Whether unlock succeeded")
error: str | None = Field(default=None, description="Error message if failed")
# Health and Status
class HealthStatus(BaseModel):
"""Health status response."""
status: str = Field(..., description="Health status")
version: str = Field(..., description="Server version")
workspace_count: int = Field(default=0, description="Active workspaces")
gitea_connected: bool = Field(default=False, description="Gitea connectivity")
github_connected: bool = Field(default=False, description="GitHub connectivity")
gitlab_connected: bool = Field(default=False, description="GitLab connectivity")
redis_connected: bool = Field(default=False, description="Redis connectivity")
class ProviderStatus(BaseModel):
"""Provider connection status."""
provider: str = Field(..., description="Provider name")
connected: bool = Field(..., description="Connection status")
url: str | None = Field(default=None, description="Provider URL")
user: str | None = Field(default=None, description="Authenticated user")
error: str | None = Field(default=None, description="Error if not connected")

View File

@@ -0,0 +1,11 @@
"""
Git provider implementations.
Provides adapters for different git hosting platforms (Gitea, GitHub, GitLab).
"""
from .base import BaseProvider
from .gitea import GiteaProvider
from .github import GitHubProvider
__all__ = ["BaseProvider", "GiteaProvider", "GitHubProvider"]

View File

@@ -0,0 +1,376 @@
"""
Base provider interface for git hosting platforms.
Defines the abstract interface that all git providers must implement.
"""
from abc import ABC, abstractmethod
from typing import Any
from models import (
CreatePRResult,
GetPRResult,
ListPRsResult,
MergePRResult,
MergeStrategy,
PRState,
UpdatePRResult,
)
class BaseProvider(ABC):
"""
Abstract base class for git hosting providers.
All providers (Gitea, GitHub, GitLab) must implement this interface.
"""
@property
@abstractmethod
def name(self) -> str:
"""Return the provider name (e.g., 'gitea', 'github')."""
...
@abstractmethod
async def is_connected(self) -> bool:
"""Check if the provider is connected and authenticated."""
...
@abstractmethod
async def get_authenticated_user(self) -> str | None:
"""Get the username of the authenticated user."""
...
# Repository operations
@abstractmethod
async def get_repo_info(self, owner: str, repo: str) -> dict[str, Any]:
"""
Get repository information.
Args:
owner: Repository owner/organization
repo: Repository name
Returns:
Repository info dict
"""
...
@abstractmethod
async def get_default_branch(self, owner: str, repo: str) -> str:
"""
Get the default branch for a repository.
Args:
owner: Repository owner/organization
repo: Repository name
Returns:
Default branch name
"""
...
# Pull Request operations
@abstractmethod
async def create_pr(
self,
owner: str,
repo: str,
title: str,
body: str,
source_branch: str,
target_branch: str,
draft: bool = False,
labels: list[str] | None = None,
assignees: list[str] | None = None,
reviewers: list[str] | None = None,
) -> CreatePRResult:
"""
Create a pull request.
Args:
owner: Repository owner
repo: Repository name
title: PR title
body: PR description
source_branch: Source branch name
target_branch: Target branch name
draft: Whether to create as draft
labels: Labels to add
assignees: Users to assign
reviewers: Users to request review from
Returns:
CreatePRResult with PR number and URL
"""
...
@abstractmethod
async def get_pr(self, owner: str, repo: str, pr_number: int) -> GetPRResult:
"""
Get a pull request by number.
Args:
owner: Repository owner
repo: Repository name
pr_number: Pull request number
Returns:
GetPRResult with PR details
"""
...
@abstractmethod
async def list_prs(
self,
owner: str,
repo: str,
state: PRState | None = None,
author: str | None = None,
limit: int = 20,
) -> ListPRsResult:
"""
List pull requests.
Args:
owner: Repository owner
repo: Repository name
state: Filter by state (open, closed, merged)
author: Filter by author
limit: Maximum PRs to return
Returns:
ListPRsResult with list of PRs
"""
...
@abstractmethod
async def merge_pr(
self,
owner: str,
repo: str,
pr_number: int,
merge_strategy: MergeStrategy = MergeStrategy.MERGE,
commit_message: str | None = None,
delete_branch: bool = True,
) -> MergePRResult:
"""
Merge a pull request.
Args:
owner: Repository owner
repo: Repository name
pr_number: Pull request number
merge_strategy: Merge strategy to use
commit_message: Custom merge commit message
delete_branch: Whether to delete source branch
Returns:
MergePRResult with merge status
"""
...
@abstractmethod
async def update_pr(
self,
owner: str,
repo: str,
pr_number: int,
title: str | None = None,
body: str | None = None,
state: PRState | None = None,
labels: list[str] | None = None,
assignees: list[str] | None = None,
) -> UpdatePRResult:
"""
Update a pull request.
Args:
owner: Repository owner
repo: Repository name
pr_number: Pull request number
title: New title
body: New description
state: New state (open, closed)
labels: Replace labels
assignees: Replace assignees
Returns:
UpdatePRResult with updated PR info
"""
...
@abstractmethod
async def close_pr(self, owner: str, repo: str, pr_number: int) -> UpdatePRResult:
"""
Close a pull request without merging.
Args:
owner: Repository owner
repo: Repository name
pr_number: Pull request number
Returns:
UpdatePRResult with updated PR info
"""
...
# Branch operations via API (for operations that need to bypass local git)
@abstractmethod
async def delete_remote_branch(self, owner: str, repo: str, branch: str) -> bool:
"""
Delete a remote branch via API.
Args:
owner: Repository owner
repo: Repository name
branch: Branch name to delete
Returns:
True if deleted, False otherwise
"""
...
@abstractmethod
async def get_branch(
self, owner: str, repo: str, branch: str
) -> dict[str, Any] | None:
"""
Get branch information via API.
Args:
owner: Repository owner
repo: Repository name
branch: Branch name
Returns:
Branch info dict or None if not found
"""
...
# Comment operations
@abstractmethod
async def add_pr_comment(
self, owner: str, repo: str, pr_number: int, body: str
) -> dict[str, Any]:
"""
Add a comment to a pull request.
Args:
owner: Repository owner
repo: Repository name
pr_number: Pull request number
body: Comment body
Returns:
Created comment info
"""
...
@abstractmethod
async def list_pr_comments(
self, owner: str, repo: str, pr_number: int
) -> list[dict[str, Any]]:
"""
List comments on a pull request.
Args:
owner: Repository owner
repo: Repository name
pr_number: Pull request number
Returns:
List of comments
"""
...
# Label operations
@abstractmethod
async def add_labels(
self, owner: str, repo: str, pr_number: int, labels: list[str]
) -> list[str]:
"""
Add labels to a pull request.
Args:
owner: Repository owner
repo: Repository name
pr_number: Pull request number
labels: Labels to add
Returns:
Updated list of labels
"""
...
@abstractmethod
async def remove_label(
self, owner: str, repo: str, pr_number: int, label: str
) -> list[str]:
"""
Remove a label from a pull request.
Args:
owner: Repository owner
repo: Repository name
pr_number: Pull request number
label: Label to remove
Returns:
Updated list of labels
"""
...
# Reviewer operations
@abstractmethod
async def request_review(
self, owner: str, repo: str, pr_number: int, reviewers: list[str]
) -> list[str]:
"""
Request review from users.
Args:
owner: Repository owner
repo: Repository name
pr_number: Pull request number
reviewers: Usernames to request review from
Returns:
List of reviewers requested
"""
...
# Utility methods
def parse_repo_url(self, repo_url: str) -> tuple[str, str]:
"""
Parse repository URL to extract owner and repo name.
Args:
repo_url: Repository URL (HTTPS or SSH)
Returns:
Tuple of (owner, repo)
Raises:
ValueError: If URL cannot be parsed
"""
import re
# Handle SSH URLs: git@host:owner/repo.git
ssh_match = re.match(r"git@[^:]+:([^/]+)/([^/]+?)(?:\.git)?$", repo_url)
if ssh_match:
return ssh_match.group(1), ssh_match.group(2)
# Handle HTTPS URLs: https://host/owner/repo.git
https_match = re.match(r"https?://[^/]+/([^/]+)/([^/]+?)(?:\.git)?$", repo_url)
if https_match:
return https_match.group(1), https_match.group(2)
raise ValueError(f"Unable to parse repository URL: {repo_url}")

View File

@@ -0,0 +1,723 @@
"""
Gitea provider implementation.
Implements the BaseProvider interface for Gitea API operations.
"""
import logging
from datetime import UTC, datetime
from typing import Any
import httpx
from config import Settings, get_settings
from exceptions import (
APIError,
AuthenticationError,
PRNotFoundError,
)
from models import (
CreatePRResult,
GetPRResult,
ListPRsResult,
MergePRResult,
MergeStrategy,
PRInfo,
PRState,
UpdatePRResult,
)
from .base import BaseProvider
logger = logging.getLogger(__name__)
class GiteaProvider(BaseProvider):
"""
Gitea API provider implementation.
Supports all PR operations, branch operations, and repository queries.
"""
def __init__(
self,
base_url: str | None = None,
token: str | None = None,
settings: Settings | None = None,
) -> None:
"""
Initialize Gitea provider.
Args:
base_url: Gitea server URL (e.g., https://gitea.example.com)
token: API token
settings: Optional settings override
"""
self.settings = settings or get_settings()
self.base_url = (base_url or self.settings.gitea_base_url).rstrip("/")
self.token = token or self.settings.gitea_token
self._client: httpx.AsyncClient | None = None
self._user: str | None = None
@property
def name(self) -> str:
"""Return the provider name."""
return "gitea"
async def _get_client(self) -> httpx.AsyncClient:
"""Get or create HTTP client."""
if self._client is None:
headers = {
"Accept": "application/json",
"Content-Type": "application/json",
}
if self.token:
headers["Authorization"] = f"token {self.token}"
self._client = httpx.AsyncClient(
base_url=f"{self.base_url}/api/v1",
headers=headers,
timeout=30.0,
)
return self._client
async def close(self) -> None:
"""Close the HTTP client."""
if self._client:
await self._client.aclose()
self._client = None
async def _request(
self,
method: str,
path: str,
**kwargs: Any,
) -> Any:
"""
Make an API request.
Args:
method: HTTP method
path: API path
**kwargs: Additional request arguments
Returns:
Parsed JSON response
Raises:
APIError: On API errors
AuthenticationError: On auth failures
"""
client = await self._get_client()
try:
response = await client.request(method, path, **kwargs)
if response.status_code == 401:
raise AuthenticationError("gitea", "Invalid or expired token")
if response.status_code == 403:
raise AuthenticationError(
"gitea", "Insufficient permissions for this operation"
)
if response.status_code == 404:
return None
if response.status_code >= 400:
error_msg = response.text
try:
error_data = response.json()
error_msg = error_data.get("message", error_msg)
except Exception:
pass
raise APIError("gitea", response.status_code, error_msg)
if response.status_code == 204:
return None
return response.json()
except httpx.RequestError as e:
raise APIError("gitea", 0, f"Request failed: {e}")
async def is_connected(self) -> bool:
"""Check if connected to Gitea."""
if not self.base_url or not self.token:
return False
try:
result = await self._request("GET", "/user")
return result is not None
except Exception:
return False
async def get_authenticated_user(self) -> str | None:
"""Get the authenticated user's username."""
if self._user:
return self._user
try:
result = await self._request("GET", "/user")
if result:
self._user = result.get("login") or result.get("username")
return self._user
except Exception:
pass
return None
# Repository operations
async def get_repo_info(self, owner: str, repo: str) -> dict[str, Any]:
"""Get repository information."""
result = await self._request("GET", f"/repos/{owner}/{repo}")
if result is None:
raise APIError("gitea", 404, f"Repository not found: {owner}/{repo}")
return result
async def get_default_branch(self, owner: str, repo: str) -> str:
"""Get the default branch for a repository."""
repo_info = await self.get_repo_info(owner, repo)
return repo_info.get("default_branch", "main")
# Pull Request operations
async def create_pr(
self,
owner: str,
repo: str,
title: str,
body: str,
source_branch: str,
target_branch: str,
draft: bool = False,
labels: list[str] | None = None,
assignees: list[str] | None = None,
reviewers: list[str] | None = None,
) -> CreatePRResult:
"""Create a pull request."""
try:
data: dict[str, Any] = {
"title": title,
"body": body,
"head": source_branch,
"base": target_branch,
}
# Note: Gitea doesn't have draft PR support in all versions
# Draft support was added in Gitea 1.14+
result = await self._request(
"POST",
f"/repos/{owner}/{repo}/pulls",
json=data,
)
if result is None:
return CreatePRResult(
success=False,
error="Failed to create pull request",
)
pr_number = result["number"]
# Add labels if specified
if labels:
await self.add_labels(owner, repo, pr_number, labels)
# Add assignees if specified (via issue update)
if assignees:
await self._request(
"PATCH",
f"/repos/{owner}/{repo}/issues/{pr_number}",
json={"assignees": assignees},
)
# Request reviewers if specified
if reviewers:
await self.request_review(owner, repo, pr_number, reviewers)
return CreatePRResult(
success=True,
pr_number=pr_number,
pr_url=result.get("html_url"),
)
except APIError as e:
return CreatePRResult(
success=False,
error=str(e),
)
async def get_pr(self, owner: str, repo: str, pr_number: int) -> GetPRResult:
"""Get a pull request by number."""
try:
result = await self._request(
"GET",
f"/repos/{owner}/{repo}/pulls/{pr_number}",
)
if result is None:
raise PRNotFoundError(pr_number, f"{owner}/{repo}")
pr_info = self._parse_pr(result)
return GetPRResult(
success=True,
pr=pr_info.to_dict(),
)
except PRNotFoundError:
return GetPRResult(
success=False,
error=f"Pull request #{pr_number} not found",
)
except APIError as e:
return GetPRResult(
success=False,
error=str(e),
)
async def list_prs(
self,
owner: str,
repo: str,
state: PRState | None = None,
author: str | None = None,
limit: int = 20,
) -> ListPRsResult:
"""List pull requests."""
try:
params: dict[str, Any] = {
"limit": limit,
}
if state:
# Gitea uses different state names
if state == PRState.OPEN:
params["state"] = "open"
elif state == PRState.CLOSED or state == PRState.MERGED:
params["state"] = "closed"
else:
params["state"] = "all"
result = await self._request(
"GET",
f"/repos/{owner}/{repo}/pulls",
params=params,
)
if result is None:
return ListPRsResult(
success=True,
pull_requests=[],
total_count=0,
)
prs = []
for pr_data in result:
# Filter by author if specified
if author:
pr_author = pr_data.get("user", {}).get("login", "")
if pr_author.lower() != author.lower():
continue
# Filter merged PRs if looking specifically for merged
if state == PRState.MERGED:
if not pr_data.get("merged"):
continue
pr_info = self._parse_pr(pr_data)
prs.append(pr_info.to_dict())
return ListPRsResult(
success=True,
pull_requests=prs,
total_count=len(prs),
)
except APIError as e:
return ListPRsResult(
success=False,
error=str(e),
)
async def merge_pr(
self,
owner: str,
repo: str,
pr_number: int,
merge_strategy: MergeStrategy = MergeStrategy.MERGE,
commit_message: str | None = None,
delete_branch: bool = True,
) -> MergePRResult:
"""Merge a pull request."""
try:
# Map merge strategy to Gitea's "Do" values
do_map = {
MergeStrategy.MERGE: "merge",
MergeStrategy.SQUASH: "squash",
MergeStrategy.REBASE: "rebase",
}
data: dict[str, Any] = {
"Do": do_map[merge_strategy],
"delete_branch_after_merge": delete_branch,
}
if commit_message:
data["MergeTitleField"] = commit_message.split("\n")[0]
if "\n" in commit_message:
data["MergeMessageField"] = "\n".join(
commit_message.split("\n")[1:]
)
result = await self._request(
"POST",
f"/repos/{owner}/{repo}/pulls/{pr_number}/merge",
json=data,
)
if result is None:
# Check if PR was actually merged
pr_result = await self.get_pr(owner, repo, pr_number)
if pr_result.success and pr_result.pr:
if pr_result.pr.get("state") == "merged":
return MergePRResult(
success=True,
branch_deleted=delete_branch,
)
return MergePRResult(
success=False,
error="Failed to merge pull request",
)
return MergePRResult(
success=True,
merge_commit_sha=result.get("sha"),
branch_deleted=delete_branch,
)
except APIError as e:
return MergePRResult(
success=False,
error=str(e),
)
async def update_pr(
self,
owner: str,
repo: str,
pr_number: int,
title: str | None = None,
body: str | None = None,
state: PRState | None = None,
labels: list[str] | None = None,
assignees: list[str] | None = None,
) -> UpdatePRResult:
"""Update a pull request."""
try:
data: dict[str, Any] = {}
if title is not None:
data["title"] = title
if body is not None:
data["body"] = body
if state is not None:
if state == PRState.OPEN:
data["state"] = "open"
elif state == PRState.CLOSED:
data["state"] = "closed"
# Update PR if there's data
if data:
await self._request(
"PATCH",
f"/repos/{owner}/{repo}/pulls/{pr_number}",
json=data,
)
# Update labels via issue endpoint
if labels is not None:
# First clear existing labels
await self._request(
"DELETE",
f"/repos/{owner}/{repo}/issues/{pr_number}/labels",
)
# Then add new labels
if labels:
await self.add_labels(owner, repo, pr_number, labels)
# Update assignees via issue endpoint
if assignees is not None:
await self._request(
"PATCH",
f"/repos/{owner}/{repo}/issues/{pr_number}",
json={"assignees": assignees},
)
# Fetch updated PR
result = await self.get_pr(owner, repo, pr_number)
return UpdatePRResult(
success=result.success,
pr=result.pr,
error=result.error,
)
except APIError as e:
return UpdatePRResult(
success=False,
error=str(e),
)
async def close_pr(
self,
owner: str,
repo: str,
pr_number: int,
) -> UpdatePRResult:
"""Close a pull request without merging."""
return await self.update_pr(
owner,
repo,
pr_number,
state=PRState.CLOSED,
)
# Branch operations
async def delete_remote_branch(
self,
owner: str,
repo: str,
branch: str,
) -> bool:
"""Delete a remote branch."""
try:
await self._request(
"DELETE",
f"/repos/{owner}/{repo}/branches/{branch}",
)
return True
except APIError:
return False
async def get_branch(
self,
owner: str,
repo: str,
branch: str,
) -> dict[str, Any] | None:
"""Get branch information."""
return await self._request(
"GET",
f"/repos/{owner}/{repo}/branches/{branch}",
)
# Comment operations
async def add_pr_comment(
self,
owner: str,
repo: str,
pr_number: int,
body: str,
) -> dict[str, Any]:
"""Add a comment to a pull request."""
result = await self._request(
"POST",
f"/repos/{owner}/{repo}/issues/{pr_number}/comments",
json={"body": body},
)
return result or {}
async def list_pr_comments(
self,
owner: str,
repo: str,
pr_number: int,
) -> list[dict[str, Any]]:
"""List comments on a pull request."""
result = await self._request(
"GET",
f"/repos/{owner}/{repo}/issues/{pr_number}/comments",
)
return result or []
# Label operations
async def add_labels(
self,
owner: str,
repo: str,
pr_number: int,
labels: list[str],
) -> list[str]:
"""Add labels to a pull request."""
# First, get or create label IDs
label_ids = []
for label_name in labels:
label_id = await self._get_or_create_label(owner, repo, label_name)
if label_id:
label_ids.append(label_id)
if label_ids:
await self._request(
"POST",
f"/repos/{owner}/{repo}/issues/{pr_number}/labels",
json={"labels": label_ids},
)
# Return current labels
issue = await self._request(
"GET",
f"/repos/{owner}/{repo}/issues/{pr_number}",
)
if issue:
return [lbl["name"] for lbl in issue.get("labels", [])]
return labels
async def remove_label(
self,
owner: str,
repo: str,
pr_number: int,
label: str,
) -> list[str]:
"""Remove a label from a pull request."""
# Get label ID
label_info = await self._request(
"GET",
f"/repos/{owner}/{repo}/labels?name={label}",
)
if label_info and len(label_info) > 0:
label_id = label_info[0]["id"]
await self._request(
"DELETE",
f"/repos/{owner}/{repo}/issues/{pr_number}/labels/{label_id}",
)
# Return remaining labels
issue = await self._request(
"GET",
f"/repos/{owner}/{repo}/issues/{pr_number}",
)
if issue:
return [lbl["name"] for lbl in issue.get("labels", [])]
return []
async def _get_or_create_label(
self,
owner: str,
repo: str,
label_name: str,
) -> int | None:
"""Get or create a label and return its ID."""
# Try to find existing label
labels = await self._request(
"GET",
f"/repos/{owner}/{repo}/labels",
)
if labels:
for label in labels:
if label["name"].lower() == label_name.lower():
return label["id"]
# Create new label with default color
try:
result = await self._request(
"POST",
f"/repos/{owner}/{repo}/labels",
json={
"name": label_name,
"color": "#3B82F6", # Default blue
},
)
if result:
return result["id"]
except APIError:
pass
return None
# Reviewer operations
async def request_review(
self,
owner: str,
repo: str,
pr_number: int,
reviewers: list[str],
) -> list[str]:
"""Request review from users."""
await self._request(
"POST",
f"/repos/{owner}/{repo}/pulls/{pr_number}/requested_reviewers",
json={"reviewers": reviewers},
)
return reviewers
# Helper methods
def _parse_pr(self, data: dict[str, Any]) -> PRInfo:
"""Parse PR API response into PRInfo."""
# Parse dates
created_at = self._parse_datetime(data.get("created_at"))
updated_at = self._parse_datetime(data.get("updated_at"))
merged_at = self._parse_datetime(data.get("merged_at"))
closed_at = self._parse_datetime(data.get("closed_at"))
# Determine state
if data.get("merged"):
state = PRState.MERGED
elif data.get("state") == "closed":
state = PRState.CLOSED
else:
state = PRState.OPEN
# Extract labels
labels = [lbl["name"] for lbl in data.get("labels", [])]
# Extract assignees
assignees = [a["login"] for a in data.get("assignees", [])]
# Extract reviewers
reviewers = []
if "requested_reviewers" in data:
reviewers = [r["login"] for r in data["requested_reviewers"]]
return PRInfo(
number=data["number"],
title=data["title"],
body=data.get("body", ""),
state=state,
source_branch=data.get("head", {}).get("ref", ""),
target_branch=data.get("base", {}).get("ref", ""),
author=data.get("user", {}).get("login", ""),
created_at=created_at,
updated_at=updated_at,
merged_at=merged_at,
closed_at=closed_at,
url=data.get("html_url"),
labels=labels,
assignees=assignees,
reviewers=reviewers,
mergeable=data.get("mergeable"),
draft=data.get("draft", False),
)
def _parse_datetime(self, value: str | None) -> datetime:
"""Parse datetime string from API."""
if not value:
return datetime.now(UTC)
try:
# Handle Gitea's datetime format
if value.endswith("Z"):
value = value[:-1] + "+00:00"
return datetime.fromisoformat(value)
except ValueError:
return datetime.now(UTC)

View File

@@ -0,0 +1,675 @@
"""
GitHub provider implementation.
Implements the BaseProvider interface for GitHub API operations.
"""
import logging
from datetime import UTC, datetime
from typing import Any
import httpx
from config import Settings, get_settings
from exceptions import (
APIError,
AuthenticationError,
PRNotFoundError,
)
from models import (
CreatePRResult,
GetPRResult,
ListPRsResult,
MergePRResult,
MergeStrategy,
PRInfo,
PRState,
UpdatePRResult,
)
from .base import BaseProvider
logger = logging.getLogger(__name__)
class GitHubProvider(BaseProvider):
"""
GitHub API provider implementation.
Supports all PR operations, branch operations, and repository queries.
"""
def __init__(
self,
token: str | None = None,
settings: Settings | None = None,
) -> None:
"""
Initialize GitHub provider.
Args:
token: GitHub personal access token or fine-grained token
settings: Optional settings override
"""
self.settings = settings or get_settings()
self.token = token or self.settings.github_token
self._client: httpx.AsyncClient | None = None
self._user: str | None = None
@property
def name(self) -> str:
"""Return the provider name."""
return "github"
async def _get_client(self) -> httpx.AsyncClient:
"""Get or create HTTP client."""
if self._client is None:
headers = {
"Accept": "application/vnd.github+json",
"X-GitHub-Api-Version": "2022-11-28",
}
if self.token:
headers["Authorization"] = f"Bearer {self.token}"
self._client = httpx.AsyncClient(
base_url="https://api.github.com",
headers=headers,
timeout=30.0,
)
return self._client
async def close(self) -> None:
"""Close the HTTP client."""
if self._client:
await self._client.aclose()
self._client = None
async def _request(
self,
method: str,
path: str,
**kwargs: Any,
) -> Any:
"""
Make an API request.
Args:
method: HTTP method
path: API path
**kwargs: Additional request arguments
Returns:
Parsed JSON response
Raises:
APIError: On API errors
AuthenticationError: On auth failures
"""
client = await self._get_client()
try:
response = await client.request(method, path, **kwargs)
if response.status_code == 401:
raise AuthenticationError("github", "Invalid or expired token")
if response.status_code == 403:
# Check for rate limiting
if "rate limit" in response.text.lower():
raise APIError("github", 403, "GitHub API rate limit exceeded")
raise AuthenticationError(
"github", "Insufficient permissions for this operation"
)
if response.status_code == 404:
return None
if response.status_code >= 400:
error_msg = response.text
try:
error_data = response.json()
error_msg = error_data.get("message", error_msg)
except Exception:
pass
raise APIError("github", response.status_code, error_msg)
if response.status_code == 204:
return None
return response.json()
except httpx.RequestError as e:
raise APIError("github", 0, f"Request failed: {e}")
async def is_connected(self) -> bool:
"""Check if connected to GitHub."""
if not self.token:
return False
try:
result = await self._request("GET", "/user")
return result is not None
except Exception:
return False
async def get_authenticated_user(self) -> str | None:
"""Get the authenticated user's username."""
if self._user:
return self._user
try:
result = await self._request("GET", "/user")
if result:
self._user = result.get("login")
return self._user
except Exception:
pass
return None
# Repository operations
async def get_repo_info(self, owner: str, repo: str) -> dict[str, Any]:
"""Get repository information."""
result = await self._request("GET", f"/repos/{owner}/{repo}")
if result is None:
raise APIError("github", 404, f"Repository not found: {owner}/{repo}")
return result
async def get_default_branch(self, owner: str, repo: str) -> str:
"""Get the default branch for a repository."""
repo_info = await self.get_repo_info(owner, repo)
return repo_info.get("default_branch", "main")
# Pull Request operations
async def create_pr(
self,
owner: str,
repo: str,
title: str,
body: str,
source_branch: str,
target_branch: str,
draft: bool = False,
labels: list[str] | None = None,
assignees: list[str] | None = None,
reviewers: list[str] | None = None,
) -> CreatePRResult:
"""Create a pull request."""
try:
data: dict[str, Any] = {
"title": title,
"body": body,
"head": source_branch,
"base": target_branch,
"draft": draft,
}
result = await self._request(
"POST",
f"/repos/{owner}/{repo}/pulls",
json=data,
)
if result is None:
return CreatePRResult(
success=False,
error="Failed to create pull request",
)
pr_number = result["number"]
# Add labels if specified
if labels:
await self.add_labels(owner, repo, pr_number, labels)
# Add assignees if specified
if assignees:
await self._request(
"POST",
f"/repos/{owner}/{repo}/issues/{pr_number}/assignees",
json={"assignees": assignees},
)
# Request reviewers if specified
if reviewers:
await self.request_review(owner, repo, pr_number, reviewers)
return CreatePRResult(
success=True,
pr_number=pr_number,
pr_url=result.get("html_url"),
)
except APIError as e:
return CreatePRResult(
success=False,
error=str(e),
)
async def get_pr(self, owner: str, repo: str, pr_number: int) -> GetPRResult:
"""Get a pull request by number."""
try:
result = await self._request(
"GET",
f"/repos/{owner}/{repo}/pulls/{pr_number}",
)
if result is None:
raise PRNotFoundError(pr_number, f"{owner}/{repo}")
pr_info = self._parse_pr(result)
return GetPRResult(
success=True,
pr=pr_info.to_dict(),
)
except PRNotFoundError:
return GetPRResult(
success=False,
error=f"Pull request #{pr_number} not found",
)
except APIError as e:
return GetPRResult(
success=False,
error=str(e),
)
async def list_prs(
self,
owner: str,
repo: str,
state: PRState | None = None,
author: str | None = None,
limit: int = 20,
) -> ListPRsResult:
"""List pull requests."""
try:
params: dict[str, Any] = {
"per_page": min(limit, 100), # GitHub max is 100
}
if state:
# GitHub uses 'state' for open/closed only
# Merged PRs are closed PRs with merged_at set
if state == PRState.OPEN:
params["state"] = "open"
elif state in (PRState.CLOSED, PRState.MERGED):
params["state"] = "closed"
else:
params["state"] = "all"
result = await self._request(
"GET",
f"/repos/{owner}/{repo}/pulls",
params=params,
)
if result is None:
return ListPRsResult(
success=True,
pull_requests=[],
total_count=0,
)
prs = []
for pr_data in result:
# Filter by author if specified
if author:
pr_author = pr_data.get("user", {}).get("login", "")
if pr_author.lower() != author.lower():
continue
# Filter merged PRs if looking specifically for merged
if state == PRState.MERGED:
if not pr_data.get("merged_at"):
continue
pr_info = self._parse_pr(pr_data)
prs.append(pr_info.to_dict())
return ListPRsResult(
success=True,
pull_requests=prs,
total_count=len(prs),
)
except APIError as e:
return ListPRsResult(
success=False,
error=str(e),
)
async def merge_pr(
self,
owner: str,
repo: str,
pr_number: int,
merge_strategy: MergeStrategy = MergeStrategy.MERGE,
commit_message: str | None = None,
delete_branch: bool = True,
) -> MergePRResult:
"""Merge a pull request."""
try:
# Map merge strategy to GitHub's merge_method values
method_map = {
MergeStrategy.MERGE: "merge",
MergeStrategy.SQUASH: "squash",
MergeStrategy.REBASE: "rebase",
}
data: dict[str, Any] = {
"merge_method": method_map[merge_strategy],
}
if commit_message:
# For squash, commit_title and commit_message
# For merge, commit_title and commit_message
parts = commit_message.split("\n", 1)
data["commit_title"] = parts[0]
if len(parts) > 1:
data["commit_message"] = parts[1]
result = await self._request(
"PUT",
f"/repos/{owner}/{repo}/pulls/{pr_number}/merge",
json=data,
)
if result is None:
return MergePRResult(
success=False,
error="Failed to merge pull request",
)
branch_deleted = False
# Delete branch if requested
if delete_branch and result.get("merged"):
# Get PR to find the branch name
pr_result = await self.get_pr(owner, repo, pr_number)
if pr_result.success and pr_result.pr:
source_branch = pr_result.pr.get("source_branch")
if source_branch:
branch_deleted = await self.delete_remote_branch(
owner, repo, source_branch
)
return MergePRResult(
success=True,
merge_commit_sha=result.get("sha"),
branch_deleted=branch_deleted,
)
except APIError as e:
return MergePRResult(
success=False,
error=str(e),
)
async def update_pr(
self,
owner: str,
repo: str,
pr_number: int,
title: str | None = None,
body: str | None = None,
state: PRState | None = None,
labels: list[str] | None = None,
assignees: list[str] | None = None,
) -> UpdatePRResult:
"""Update a pull request."""
try:
data: dict[str, Any] = {}
if title is not None:
data["title"] = title
if body is not None:
data["body"] = body
if state is not None:
if state == PRState.OPEN:
data["state"] = "open"
elif state == PRState.CLOSED:
data["state"] = "closed"
# Update PR if there's data
if data:
await self._request(
"PATCH",
f"/repos/{owner}/{repo}/pulls/{pr_number}",
json=data,
)
# Update labels via issue endpoint
if labels is not None:
await self._request(
"PUT",
f"/repos/{owner}/{repo}/issues/{pr_number}/labels",
json={"labels": labels},
)
# Update assignees via issue endpoint
if assignees is not None:
# First remove all assignees
await self._request(
"DELETE",
f"/repos/{owner}/{repo}/issues/{pr_number}/assignees",
json={"assignees": []},
)
# Then add new ones
if assignees:
await self._request(
"POST",
f"/repos/{owner}/{repo}/issues/{pr_number}/assignees",
json={"assignees": assignees},
)
# Fetch updated PR
result = await self.get_pr(owner, repo, pr_number)
return UpdatePRResult(
success=result.success,
pr=result.pr,
error=result.error,
)
except APIError as e:
return UpdatePRResult(
success=False,
error=str(e),
)
async def close_pr(
self,
owner: str,
repo: str,
pr_number: int,
) -> UpdatePRResult:
"""Close a pull request without merging."""
return await self.update_pr(
owner,
repo,
pr_number,
state=PRState.CLOSED,
)
# Branch operations
async def delete_remote_branch(
self,
owner: str,
repo: str,
branch: str,
) -> bool:
"""Delete a remote branch."""
try:
await self._request(
"DELETE",
f"/repos/{owner}/{repo}/git/refs/heads/{branch}",
)
return True
except APIError:
return False
async def get_branch(
self,
owner: str,
repo: str,
branch: str,
) -> dict[str, Any] | None:
"""Get branch information."""
return await self._request(
"GET",
f"/repos/{owner}/{repo}/branches/{branch}",
)
# Comment operations
async def add_pr_comment(
self,
owner: str,
repo: str,
pr_number: int,
body: str,
) -> dict[str, Any]:
"""Add a comment to a pull request."""
result = await self._request(
"POST",
f"/repos/{owner}/{repo}/issues/{pr_number}/comments",
json={"body": body},
)
return result or {}
async def list_pr_comments(
self,
owner: str,
repo: str,
pr_number: int,
) -> list[dict[str, Any]]:
"""List comments on a pull request."""
result = await self._request(
"GET",
f"/repos/{owner}/{repo}/issues/{pr_number}/comments",
)
return result or []
# Label operations
async def add_labels(
self,
owner: str,
repo: str,
pr_number: int,
labels: list[str],
) -> list[str]:
"""Add labels to a pull request."""
# GitHub creates labels automatically if they don't exist (unlike Gitea)
result = await self._request(
"POST",
f"/repos/{owner}/{repo}/issues/{pr_number}/labels",
json={"labels": labels},
)
if result:
return [lbl["name"] for lbl in result]
return labels
async def remove_label(
self,
owner: str,
repo: str,
pr_number: int,
label: str,
) -> list[str]:
"""Remove a label from a pull request."""
await self._request(
"DELETE",
f"/repos/{owner}/{repo}/issues/{pr_number}/labels/{label}",
)
# Return remaining labels
issue = await self._request(
"GET",
f"/repos/{owner}/{repo}/issues/{pr_number}",
)
if issue:
return [lbl["name"] for lbl in issue.get("labels", [])]
return []
# Reviewer operations
async def request_review(
self,
owner: str,
repo: str,
pr_number: int,
reviewers: list[str],
) -> list[str]:
"""Request review from users."""
await self._request(
"POST",
f"/repos/{owner}/{repo}/pulls/{pr_number}/requested_reviewers",
json={"reviewers": reviewers},
)
return reviewers
# Helper methods
def _parse_pr(self, data: dict[str, Any]) -> PRInfo:
"""Parse PR API response into PRInfo."""
# Parse dates
created_at = self._parse_datetime(data.get("created_at"))
updated_at = self._parse_datetime(data.get("updated_at"))
merged_at = self._parse_datetime(data.get("merged_at"))
closed_at = self._parse_datetime(data.get("closed_at"))
# Determine state
if data.get("merged_at"):
state = PRState.MERGED
elif data.get("state") == "closed":
state = PRState.CLOSED
else:
state = PRState.OPEN
# Extract labels
labels = [lbl["name"] for lbl in data.get("labels", [])]
# Extract assignees
assignees = [a["login"] for a in data.get("assignees", [])]
# Extract reviewers
reviewers = []
if "requested_reviewers" in data:
reviewers = [r["login"] for r in data["requested_reviewers"]]
return PRInfo(
number=data["number"],
title=data["title"],
body=data.get("body", "") or "",
state=state,
source_branch=data.get("head", {}).get("ref", ""),
target_branch=data.get("base", {}).get("ref", ""),
author=data.get("user", {}).get("login", ""),
created_at=created_at,
updated_at=updated_at,
merged_at=merged_at,
closed_at=closed_at,
url=data.get("html_url"),
labels=labels,
assignees=assignees,
reviewers=reviewers,
mergeable=data.get("mergeable"),
draft=data.get("draft", False),
)
def _parse_datetime(self, value: str | None) -> datetime:
"""Parse datetime string from API."""
if not value:
return datetime.now(UTC)
try:
# GitHub uses ISO 8601 format with Z suffix
if value.endswith("Z"):
value = value[:-1] + "+00:00"
return datetime.fromisoformat(value)
except ValueError:
return datetime.now(UTC)

View File

@@ -0,0 +1,120 @@
[project]
name = "syndarix-mcp-git-ops"
version = "0.1.0"
description = "Syndarix Git Operations MCP Server - Repository management, branching, commits, and PR workflows"
requires-python = ">=3.12"
dependencies = [
"fastmcp>=2.0.0",
"gitpython>=3.1.0",
"httpx>=0.27.0",
"redis>=5.0.0",
"pydantic>=2.0.0",
"pydantic-settings>=2.0.0",
"uvicorn>=0.30.0",
"fastapi>=0.115.0",
"filelock>=3.15.0",
"aiofiles>=24.1.0",
]
[project.optional-dependencies]
dev = [
"pytest>=8.0.0",
"pytest-asyncio>=0.24.0",
"pytest-cov>=5.0.0",
"fakeredis>=2.25.0",
"ruff>=0.8.0",
"mypy>=1.11.0",
"respx>=0.21.0",
]
[project.scripts]
git-ops = "server:main"
[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"
[tool.hatch.build.targets.wheel]
packages = ["."]
exclude = ["tests/", "*.md", "Dockerfile"]
[tool.hatch.build.targets.sdist]
include = ["*.py", "pyproject.toml"]
[tool.ruff]
target-version = "py312"
line-length = 88
[tool.ruff.lint]
select = [
"E", # pycodestyle errors
"W", # pycodestyle warnings
"F", # pyflakes
"I", # isort
"B", # flake8-bugbear
"C4", # flake8-comprehensions
"UP", # pyupgrade
"ARG", # flake8-unused-arguments
"SIM", # flake8-simplify
"S", # flake8-bandit (security)
]
ignore = [
"E501", # line too long (handled by formatter)
"B008", # do not perform function calls in argument defaults
"B904", # raise from in except (too noisy)
"S104", # possible binding to all interfaces
"S110", # try-except-pass (intentional for optional operations)
"S603", # subprocess without shell=True (safe usage in git wrapper)
"S607", # starting a process with a partial path (git CLI)
"ARG002", # unused method arguments (for API compatibility)
"SIM102", # nested if statements (sometimes more readable)
"SIM105", # contextlib.suppress (sometimes more readable)
"SIM108", # ternary operator (sometimes more readable)
"SIM118", # dict.keys() (explicit is fine)
]
[tool.ruff.lint.isort]
known-first-party = ["config", "models", "exceptions", "git_wrapper", "workspace", "providers"]
[tool.ruff.lint.per-file-ignores]
"tests/**/*.py" = ["S101", "ARG001", "S105", "S106", "S108", "F841", "B007"]
[tool.pytest.ini_options]
asyncio_mode = "auto"
asyncio_default_fixture_loop_scope = "function"
testpaths = ["tests"]
addopts = "-v --tb=short"
filterwarnings = [
"ignore::DeprecationWarning",
]
[tool.coverage.run]
source = ["."]
omit = ["tests/*", "conftest.py"]
branch = true
[tool.coverage.report]
exclude_lines = [
"pragma: no cover",
"def __repr__",
"raise NotImplementedError",
"if TYPE_CHECKING:",
"if __name__ == .__main__.:",
]
fail_under = 78
show_missing = true
[tool.mypy]
python_version = "3.12"
warn_return_any = false
warn_unused_ignores = false
disallow_untyped_defs = true
ignore_missing_imports = true
plugins = ["pydantic.mypy"]
files = ["server.py", "config.py", "models.py", "exceptions.py", "git_wrapper.py", "workspace.py", "providers/"]
exclude = ["tests/"]
[[tool.mypy.overrides]]
module = "tests.*"
disallow_untyped_defs = false
ignore_errors = true

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1 @@
"""Tests for Git Operations MCP Server."""

View File

@@ -0,0 +1,297 @@
"""
Test configuration and fixtures for Git Operations MCP Server.
"""
import os
import shutil
import tempfile
from collections.abc import AsyncIterator, Iterator
from pathlib import Path
from unittest.mock import AsyncMock, MagicMock
import pytest
from git import Repo as GitRepo
# Set test environment
os.environ["IS_TEST"] = "true"
os.environ["GIT_OPS_WORKSPACE_BASE_PATH"] = "/tmp/test-workspaces"
os.environ["GIT_OPS_GITEA_BASE_URL"] = "https://gitea.test.com"
os.environ["GIT_OPS_GITEA_TOKEN"] = "test-token"
@pytest.fixture(scope="session", autouse=True)
def reset_settings_session():
"""Reset settings at start and end of test session."""
from config import reset_settings
reset_settings()
yield
reset_settings()
@pytest.fixture
def reset_settings():
"""Reset settings before each test that needs it."""
from config import reset_settings
reset_settings()
yield
reset_settings()
@pytest.fixture
def test_settings():
"""Get test settings."""
from config import Settings
return Settings(
workspace_base_path=Path("/tmp/test-workspaces"),
gitea_base_url="https://gitea.test.com",
gitea_token="test-token",
github_token="github-test-token",
git_author_name="Test Agent",
git_author_email="test@syndarix.ai",
enable_force_push=False,
debug=True,
)
@pytest.fixture
def temp_dir() -> Iterator[Path]:
"""Create a temporary directory for tests."""
temp_path = Path(tempfile.mkdtemp())
yield temp_path
if temp_path.exists():
shutil.rmtree(temp_path)
@pytest.fixture
def temp_workspace(temp_dir: Path) -> Path:
"""Create a temporary workspace directory."""
workspace = temp_dir / "workspace"
workspace.mkdir(parents=True, exist_ok=True)
return workspace
@pytest.fixture
def git_repo(temp_workspace: Path) -> GitRepo:
"""Create a git repository in the temp workspace."""
# Initialize with main branch (Git 2.28+)
repo = GitRepo.init(temp_workspace, initial_branch="main")
# Configure git
with repo.config_writer() as cw:
cw.set_value("user", "name", "Test User")
cw.set_value("user", "email", "test@example.com")
# Create initial commit
test_file = temp_workspace / "README.md"
test_file.write_text("# Test Repository\n")
repo.index.add(["README.md"])
repo.index.commit("Initial commit")
return repo
@pytest.fixture
def git_repo_with_remote(git_repo: GitRepo, temp_dir: Path) -> tuple[GitRepo, GitRepo]:
"""Create a git repository with a 'remote' (bare repo)."""
# Create bare repo as remote
remote_path = temp_dir / "remote.git"
remote_repo = GitRepo.init(remote_path, bare=True)
# Add remote to main repo
git_repo.create_remote("origin", str(remote_path))
# Push initial commit
git_repo.remotes.origin.push("main:main")
# Set up tracking
git_repo.heads.main.set_tracking_branch(git_repo.remotes.origin.refs.main)
return git_repo, remote_repo
@pytest.fixture
def workspace_manager(temp_dir: Path, test_settings):
"""Create a WorkspaceManager with test settings."""
from workspace import WorkspaceManager
test_settings.workspace_base_path = temp_dir / "workspaces"
return WorkspaceManager(test_settings)
@pytest.fixture
def git_wrapper(temp_workspace: Path, test_settings):
"""Create a GitWrapper for the temp workspace."""
from git_wrapper import GitWrapper
return GitWrapper(temp_workspace, test_settings)
@pytest.fixture
def git_wrapper_with_repo(git_repo: GitRepo, test_settings):
"""Create a GitWrapper for a repo that's already initialized."""
from git_wrapper import GitWrapper
return GitWrapper(Path(git_repo.working_dir), test_settings)
@pytest.fixture
def mock_gitea_provider():
"""Create a mock Gitea provider."""
provider = AsyncMock()
provider.name = "gitea"
provider.is_connected = AsyncMock(return_value=True)
provider.get_authenticated_user = AsyncMock(return_value="test-user")
provider.parse_repo_url = MagicMock(return_value=("owner", "repo"))
return provider
@pytest.fixture
def mock_httpx_client():
"""Create a mock httpx client for provider tests."""
from unittest.mock import AsyncMock
mock_response = AsyncMock()
mock_response.status_code = 200
mock_response.json = MagicMock(return_value={})
mock_response.text = ""
mock_client = AsyncMock()
mock_client.request = AsyncMock(return_value=mock_response)
mock_client.get = AsyncMock(return_value=mock_response)
mock_client.post = AsyncMock(return_value=mock_response)
mock_client.patch = AsyncMock(return_value=mock_response)
mock_client.delete = AsyncMock(return_value=mock_response)
return mock_client
@pytest.fixture
async def gitea_provider(test_settings, mock_httpx_client):
"""Create a GiteaProvider with mocked HTTP client."""
from providers.gitea import GiteaProvider
provider = GiteaProvider(
base_url=test_settings.gitea_base_url,
token=test_settings.gitea_token,
settings=test_settings,
)
provider._client = mock_httpx_client
yield provider
await provider.close()
@pytest.fixture
def sample_pr_data():
"""Sample PR data from Gitea API."""
return {
"number": 42,
"title": "Test PR",
"body": "This is a test pull request",
"state": "open",
"head": {"ref": "feature-branch"},
"base": {"ref": "main"},
"user": {"login": "test-user"},
"created_at": "2024-01-15T10:00:00Z",
"updated_at": "2024-01-15T12:00:00Z",
"merged_at": None,
"closed_at": None,
"html_url": "https://gitea.test.com/owner/repo/pull/42",
"labels": [{"name": "enhancement"}],
"assignees": [{"login": "assignee1"}],
"requested_reviewers": [{"login": "reviewer1"}],
"mergeable": True,
"draft": False,
}
@pytest.fixture
def sample_commit_data():
"""Sample commit data."""
return {
"sha": "abc123def456",
"short_sha": "abc123d",
"message": "Test commit message",
"author": {
"name": "Test Author",
"email": "author@test.com",
"date": "2024-01-15T10:00:00Z",
},
"committer": {
"name": "Test Committer",
"email": "committer@test.com",
"date": "2024-01-15T10:00:00Z",
},
}
@pytest.fixture
def mock_fastapi_app():
"""Create a test FastAPI app."""
from fastapi import FastAPI
from fastapi.testclient import TestClient
app = FastAPI()
@app.get("/health")
def health():
return {"status": "healthy"}
return TestClient(app)
# Async fixtures
@pytest.fixture
async def async_workspace_manager(temp_dir: Path, test_settings) -> AsyncIterator:
"""Async fixture for workspace manager."""
from workspace import WorkspaceManager
test_settings.workspace_base_path = temp_dir / "workspaces"
manager = WorkspaceManager(test_settings)
yield manager
# Test data fixtures
@pytest.fixture
def valid_project_id() -> str:
"""Valid project ID for tests."""
return "test-project-123"
@pytest.fixture
def valid_agent_id() -> str:
"""Valid agent ID for tests."""
return "agent-456"
@pytest.fixture
def invalid_ids() -> list[str]:
"""Invalid IDs for validation tests."""
return [
"",
" ",
"a" * 200, # Too long
"test@invalid", # Invalid character
"test!invalid",
"../path/traversal",
]
@pytest.fixture
def sample_repo_url() -> str:
"""Sample repository URL."""
return "https://gitea.test.com/owner/repo.git"
@pytest.fixture
def sample_ssh_repo_url() -> str:
"""Sample SSH repository URL."""
return "git@gitea.test.com:owner/repo.git"

View File

@@ -0,0 +1,440 @@
"""
Tests for FastAPI endpoints.
Tests health check and MCP JSON-RPC endpoints.
"""
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
class TestHealthEndpoint:
"""Tests for health check endpoint."""
@pytest.mark.asyncio
async def test_health_no_providers(self):
"""Test health check when no providers configured."""
from httpx import ASGITransport, AsyncClient
from server import app
with (
patch("server._settings", MagicMock()),
patch("server._workspace_manager", None),
patch("server._gitea_provider", None),
patch("server._github_provider", None),
):
async with AsyncClient(
transport=ASGITransport(app=app), base_url="http://test"
) as client:
response = await client.get("/health")
assert response.status_code == 200
data = response.json()
assert data["status"] in ["healthy", "degraded"]
assert data["service"] == "git-ops"
@pytest.mark.asyncio
async def test_health_with_gitea_connected(self):
"""Test health check with Gitea provider connected."""
from httpx import ASGITransport, AsyncClient
from server import app
mock_gitea = AsyncMock()
mock_gitea.is_connected = AsyncMock(return_value=True)
mock_gitea.get_authenticated_user = AsyncMock(return_value="test-user")
with (
patch("server._settings", MagicMock()),
patch("server._workspace_manager", None),
patch("server._gitea_provider", mock_gitea),
patch("server._github_provider", None),
):
async with AsyncClient(
transport=ASGITransport(app=app), base_url="http://test"
) as client:
response = await client.get("/health")
assert response.status_code == 200
data = response.json()
assert "gitea" in data["dependencies"]
@pytest.mark.asyncio
async def test_health_with_gitea_not_connected(self):
"""Test health check when Gitea is not connected."""
from httpx import ASGITransport, AsyncClient
from server import app
mock_gitea = AsyncMock()
mock_gitea.is_connected = AsyncMock(return_value=False)
with (
patch("server._settings", MagicMock()),
patch("server._workspace_manager", None),
patch("server._gitea_provider", mock_gitea),
patch("server._github_provider", None),
):
async with AsyncClient(
transport=ASGITransport(app=app), base_url="http://test"
) as client:
response = await client.get("/health")
assert response.status_code == 200
data = response.json()
assert data["status"] == "degraded"
@pytest.mark.asyncio
async def test_health_with_gitea_error(self):
"""Test health check when Gitea throws error."""
from httpx import ASGITransport, AsyncClient
from server import app
mock_gitea = AsyncMock()
mock_gitea.is_connected = AsyncMock(side_effect=Exception("Connection failed"))
with (
patch("server._settings", MagicMock()),
patch("server._workspace_manager", None),
patch("server._gitea_provider", mock_gitea),
patch("server._github_provider", None),
):
async with AsyncClient(
transport=ASGITransport(app=app), base_url="http://test"
) as client:
response = await client.get("/health")
assert response.status_code == 200
data = response.json()
assert data["status"] == "degraded"
@pytest.mark.asyncio
async def test_health_with_github_connected(self):
"""Test health check with GitHub provider connected."""
from httpx import ASGITransport, AsyncClient
from server import app
mock_github = AsyncMock()
mock_github.is_connected = AsyncMock(return_value=True)
mock_github.get_authenticated_user = AsyncMock(return_value="github-user")
with (
patch("server._settings", MagicMock()),
patch("server._workspace_manager", None),
patch("server._gitea_provider", None),
patch("server._github_provider", mock_github),
):
async with AsyncClient(
transport=ASGITransport(app=app), base_url="http://test"
) as client:
response = await client.get("/health")
assert response.status_code == 200
data = response.json()
assert "github" in data["dependencies"]
@pytest.mark.asyncio
async def test_health_with_github_not_connected(self):
"""Test health check when GitHub is not connected."""
from httpx import ASGITransport, AsyncClient
from server import app
mock_github = AsyncMock()
mock_github.is_connected = AsyncMock(return_value=False)
with (
patch("server._settings", MagicMock()),
patch("server._workspace_manager", None),
patch("server._gitea_provider", None),
patch("server._github_provider", mock_github),
):
async with AsyncClient(
transport=ASGITransport(app=app), base_url="http://test"
) as client:
response = await client.get("/health")
assert response.status_code == 200
data = response.json()
assert data["status"] == "degraded"
@pytest.mark.asyncio
async def test_health_with_github_error(self):
"""Test health check when GitHub throws error."""
from httpx import ASGITransport, AsyncClient
from server import app
mock_github = AsyncMock()
mock_github.is_connected = AsyncMock(side_effect=Exception("Auth failed"))
with (
patch("server._settings", MagicMock()),
patch("server._workspace_manager", None),
patch("server._gitea_provider", None),
patch("server._github_provider", mock_github),
):
async with AsyncClient(
transport=ASGITransport(app=app), base_url="http://test"
) as client:
response = await client.get("/health")
assert response.status_code == 200
data = response.json()
assert data["status"] == "degraded"
@pytest.mark.asyncio
async def test_health_with_workspace_manager(self):
"""Test health check with workspace manager."""
from pathlib import Path
from httpx import ASGITransport, AsyncClient
from server import app
mock_manager = AsyncMock()
mock_manager.base_path = Path("/tmp/workspaces")
mock_manager.list_workspaces = AsyncMock(return_value=[])
with (
patch("server._settings", MagicMock()),
patch("server._workspace_manager", mock_manager),
patch("server._gitea_provider", None),
patch("server._github_provider", None),
):
async with AsyncClient(
transport=ASGITransport(app=app), base_url="http://test"
) as client:
response = await client.get("/health")
assert response.status_code == 200
data = response.json()
assert "workspace" in data["dependencies"]
@pytest.mark.asyncio
async def test_health_workspace_error(self):
"""Test health check when workspace manager throws error."""
from pathlib import Path
from httpx import ASGITransport, AsyncClient
from server import app
mock_manager = AsyncMock()
mock_manager.base_path = Path("/tmp/workspaces")
mock_manager.list_workspaces = AsyncMock(side_effect=Exception("Disk full"))
with (
patch("server._settings", MagicMock()),
patch("server._workspace_manager", mock_manager),
patch("server._gitea_provider", None),
patch("server._github_provider", None),
):
async with AsyncClient(
transport=ASGITransport(app=app), base_url="http://test"
) as client:
response = await client.get("/health")
assert response.status_code == 200
data = response.json()
assert data["status"] == "degraded"
class TestMCPToolsEndpoint:
"""Tests for MCP tools list endpoint."""
@pytest.mark.asyncio
async def test_list_mcp_tools(self):
"""Test listing MCP tools."""
from httpx import ASGITransport, AsyncClient
from server import app
with patch("server._settings", MagicMock()):
async with AsyncClient(
transport=ASGITransport(app=app), base_url="http://test"
) as client:
response = await client.get("/mcp/tools")
assert response.status_code == 200
data = response.json()
assert "tools" in data
class TestMCPRPCEndpoint:
"""Tests for MCP JSON-RPC endpoint."""
@pytest.mark.asyncio
async def test_mcp_rpc_invalid_json(self):
"""Test RPC with invalid JSON."""
from httpx import ASGITransport, AsyncClient
from server import app
async with AsyncClient(
transport=ASGITransport(app=app), base_url="http://test"
) as client:
response = await client.post(
"/mcp",
content="not valid json",
headers={"Content-Type": "application/json"},
)
assert response.status_code == 400
data = response.json()
assert data["error"]["code"] == -32700
@pytest.mark.asyncio
async def test_mcp_rpc_invalid_jsonrpc(self):
"""Test RPC with invalid jsonrpc version."""
from httpx import ASGITransport, AsyncClient
from server import app
async with AsyncClient(
transport=ASGITransport(app=app), base_url="http://test"
) as client:
response = await client.post(
"/mcp", json={"jsonrpc": "1.0", "method": "test", "id": 1}
)
assert response.status_code == 400
data = response.json()
assert data["error"]["code"] == -32600
@pytest.mark.asyncio
async def test_mcp_rpc_missing_method(self):
"""Test RPC with missing method."""
from httpx import ASGITransport, AsyncClient
from server import app
async with AsyncClient(
transport=ASGITransport(app=app), base_url="http://test"
) as client:
response = await client.post("/mcp", json={"jsonrpc": "2.0", "id": 1})
assert response.status_code == 400
data = response.json()
assert data["error"]["code"] == -32600
@pytest.mark.asyncio
async def test_mcp_rpc_method_not_found(self):
"""Test RPC with unknown method."""
from httpx import ASGITransport, AsyncClient
from server import app
async with AsyncClient(
transport=ASGITransport(app=app), base_url="http://test"
) as client:
response = await client.post(
"/mcp",
json={
"jsonrpc": "2.0",
"method": "unknown_method",
"params": {},
"id": 1,
},
)
assert response.status_code == 404
data = response.json()
assert data["error"]["code"] == -32601
class TestTypeSchemaConversion:
"""Tests for type to JSON schema conversion."""
def test_python_type_to_json_schema_str(self):
"""Test converting str type to JSON schema."""
from server import _python_type_to_json_schema
result = _python_type_to_json_schema(str)
assert result == {"type": "string"}
def test_python_type_to_json_schema_int(self):
"""Test converting int type to JSON schema."""
from server import _python_type_to_json_schema
result = _python_type_to_json_schema(int)
assert result == {"type": "integer"}
def test_python_type_to_json_schema_float(self):
"""Test converting float type to JSON schema."""
from server import _python_type_to_json_schema
result = _python_type_to_json_schema(float)
assert result == {"type": "number"}
def test_python_type_to_json_schema_bool(self):
"""Test converting bool type to JSON schema."""
from server import _python_type_to_json_schema
result = _python_type_to_json_schema(bool)
assert result == {"type": "boolean"}
def test_python_type_to_json_schema_none(self):
"""Test converting NoneType to JSON schema."""
from server import _python_type_to_json_schema
result = _python_type_to_json_schema(type(None))
assert result == {"type": "null"}
def test_python_type_to_json_schema_list(self):
"""Test converting list type to JSON schema."""
from server import _python_type_to_json_schema
result = _python_type_to_json_schema(list[str])
assert result["type"] == "array"
def test_python_type_to_json_schema_dict(self):
"""Test converting dict type to JSON schema."""
from server import _python_type_to_json_schema
result = _python_type_to_json_schema(dict[str, int])
assert result == {"type": "object"}
def test_python_type_to_json_schema_optional(self):
"""Test converting Optional type to JSON schema."""
from server import _python_type_to_json_schema
result = _python_type_to_json_schema(str | None)
# The function returns object type for complex union types
assert "type" in result
class TestToolSchema:
"""Tests for tool schema extraction."""
def test_get_tool_schema_simple(self):
"""Test getting schema from simple function."""
from server import _get_tool_schema
def simple_func(name: str, count: int) -> str:
return f"{name}: {count}"
result = _get_tool_schema(simple_func)
assert "properties" in result
assert "name" in result["properties"]
assert "count" in result["properties"]
def test_register_and_get_tool(self):
"""Test registering a tool."""
from server import _register_tool, _tool_registry
async def test_tool(x: str) -> str:
"""A test tool."""
return x
_register_tool("test_tool", test_tool, "Test description")
assert "test_tool" in _tool_registry
assert _tool_registry["test_tool"]["description"] == "Test description"
# Clean up
del _tool_registry["test_tool"]

View File

@@ -0,0 +1,943 @@
"""
Tests for the git_wrapper module.
"""
from pathlib import Path
from unittest.mock import MagicMock, patch
import pytest
from git import GitCommandError
from exceptions import (
BranchExistsError,
BranchNotFoundError,
CheckoutError,
CloneError,
CommitError,
GitError,
PullError,
PushError,
)
from git_wrapper import GitWrapper, run_in_executor
from models import FileChangeType
class TestGitWrapperInit:
"""Tests for GitWrapper initialization."""
def test_init_with_valid_path(self, temp_workspace, test_settings):
"""Test initialization with a valid path."""
wrapper = GitWrapper(temp_workspace, test_settings)
assert wrapper.workspace_path == temp_workspace
assert wrapper.settings == test_settings
def test_repo_property_raises_on_non_git(self, temp_workspace, test_settings):
"""Test that accessing repo on non-git dir raises error."""
wrapper = GitWrapper(temp_workspace, test_settings)
with pytest.raises(GitError, match="Not a git repository"):
_ = wrapper.repo
def test_repo_property_works_on_git_dir(self, git_repo, test_settings):
"""Test that repo property works for git directory."""
wrapper = GitWrapper(Path(git_repo.working_dir), test_settings)
assert wrapper.repo is not None
assert wrapper.repo.head is not None
class TestGitWrapperStatus:
"""Tests for git status operations."""
@pytest.mark.asyncio
async def test_status_clean_repo(self, git_wrapper_with_repo):
"""Test status on a clean repository."""
result = await git_wrapper_with_repo.status()
assert result.branch == "main"
assert result.is_clean is True
assert len(result.staged) == 0
assert len(result.unstaged) == 0
assert len(result.untracked) == 0
@pytest.mark.asyncio
async def test_status_with_untracked(self, git_wrapper_with_repo, git_repo):
"""Test status with untracked files."""
# Create untracked file
untracked_file = Path(git_repo.working_dir) / "untracked.txt"
untracked_file.write_text("untracked content")
result = await git_wrapper_with_repo.status()
assert result.is_clean is False
assert "untracked.txt" in result.untracked
@pytest.mark.asyncio
async def test_status_with_modified(self, git_wrapper_with_repo, git_repo):
"""Test status with modified files."""
# Modify existing file
readme = Path(git_repo.working_dir) / "README.md"
readme.write_text("# Modified content\n")
result = await git_wrapper_with_repo.status()
assert result.is_clean is False
assert len(result.unstaged) > 0
@pytest.mark.asyncio
async def test_status_with_staged(self, git_wrapper_with_repo, git_repo):
"""Test status with staged changes."""
# Create and stage a file
new_file = Path(git_repo.working_dir) / "staged.txt"
new_file.write_text("staged content")
git_repo.index.add(["staged.txt"])
result = await git_wrapper_with_repo.status()
assert result.is_clean is False
assert len(result.staged) > 0
@pytest.mark.asyncio
async def test_status_exclude_untracked(self, git_wrapper_with_repo, git_repo):
"""Test status without untracked files."""
untracked_file = Path(git_repo.working_dir) / "untracked.txt"
untracked_file.write_text("untracked")
result = await git_wrapper_with_repo.status(include_untracked=False)
assert len(result.untracked) == 0
class TestGitWrapperBranch:
"""Tests for branch operations."""
@pytest.mark.asyncio
async def test_create_branch(self, git_wrapper_with_repo):
"""Test creating a new branch."""
result = await git_wrapper_with_repo.create_branch("feature-test")
assert result.success is True
assert result.branch == "feature-test"
assert result.is_current is True
@pytest.mark.asyncio
async def test_create_branch_without_checkout(self, git_wrapper_with_repo):
"""Test creating branch without checkout."""
result = await git_wrapper_with_repo.create_branch(
"feature-no-checkout", checkout=False
)
assert result.success is True
assert result.branch == "feature-no-checkout"
assert result.is_current is False
@pytest.mark.asyncio
async def test_create_branch_exists_error(self, git_wrapper_with_repo):
"""Test error when branch already exists."""
await git_wrapper_with_repo.create_branch("existing-branch", checkout=False)
with pytest.raises(BranchExistsError):
await git_wrapper_with_repo.create_branch("existing-branch")
@pytest.mark.asyncio
async def test_delete_branch(self, git_wrapper_with_repo):
"""Test deleting a branch."""
# Create branch first
await git_wrapper_with_repo.create_branch("to-delete", checkout=False)
# Delete it
result = await git_wrapper_with_repo.delete_branch("to-delete")
assert result.success is True
assert result.branch == "to-delete"
@pytest.mark.asyncio
async def test_delete_branch_not_found(self, git_wrapper_with_repo):
"""Test error when deleting non-existent branch."""
with pytest.raises(BranchNotFoundError):
await git_wrapper_with_repo.delete_branch("nonexistent")
@pytest.mark.asyncio
async def test_delete_current_branch_error(self, git_wrapper_with_repo):
"""Test error when deleting current branch."""
with pytest.raises(GitError, match="Cannot delete current branch"):
await git_wrapper_with_repo.delete_branch("main")
@pytest.mark.asyncio
async def test_list_branches(self, git_wrapper_with_repo):
"""Test listing branches."""
# Create some branches
await git_wrapper_with_repo.create_branch("branch-a", checkout=False)
await git_wrapper_with_repo.create_branch("branch-b", checkout=False)
result = await git_wrapper_with_repo.list_branches()
assert result.current_branch == "main"
branch_names = [b["name"] for b in result.local_branches]
assert "main" in branch_names
assert "branch-a" in branch_names
assert "branch-b" in branch_names
class TestGitWrapperCheckout:
"""Tests for checkout operations."""
@pytest.mark.asyncio
async def test_checkout_existing_branch(self, git_wrapper_with_repo):
"""Test checkout of existing branch."""
# Create branch first
await git_wrapper_with_repo.create_branch("test-branch", checkout=False)
result = await git_wrapper_with_repo.checkout("test-branch")
assert result.success is True
assert result.ref == "test-branch"
@pytest.mark.asyncio
async def test_checkout_create_new(self, git_wrapper_with_repo):
"""Test checkout with branch creation."""
result = await git_wrapper_with_repo.checkout("new-branch", create_branch=True)
assert result.success is True
assert result.ref == "new-branch"
@pytest.mark.asyncio
async def test_checkout_nonexistent_error(self, git_wrapper_with_repo):
"""Test error when checking out non-existent ref."""
with pytest.raises(CheckoutError):
await git_wrapper_with_repo.checkout("nonexistent-branch")
class TestGitWrapperCommit:
"""Tests for commit operations."""
@pytest.mark.asyncio
async def test_commit_staged_changes(self, git_wrapper_with_repo, git_repo):
"""Test committing staged changes."""
# Create and stage a file
new_file = Path(git_repo.working_dir) / "newfile.txt"
new_file.write_text("new content")
git_repo.index.add(["newfile.txt"])
result = await git_wrapper_with_repo.commit("Add new file")
assert result.success is True
assert result.message == "Add new file"
assert result.files_changed == 1
@pytest.mark.asyncio
async def test_commit_all_changes(self, git_wrapper_with_repo, git_repo):
"""Test committing all changes (auto-stage)."""
# Create a file without staging
new_file = Path(git_repo.working_dir) / "unstaged.txt"
new_file.write_text("content")
result = await git_wrapper_with_repo.commit("Commit unstaged")
assert result.success is True
@pytest.mark.asyncio
async def test_commit_nothing_to_commit(self, git_wrapper_with_repo):
"""Test error when nothing to commit."""
with pytest.raises(CommitError, match="Nothing to commit"):
await git_wrapper_with_repo.commit("Empty commit")
@pytest.mark.asyncio
async def test_commit_with_author(self, git_wrapper_with_repo, git_repo):
"""Test commit with custom author."""
new_file = Path(git_repo.working_dir) / "authored.txt"
new_file.write_text("authored content")
result = await git_wrapper_with_repo.commit(
"Custom author commit",
author_name="Custom Author",
author_email="custom@test.com",
)
assert result.success is True
class TestGitWrapperDiff:
"""Tests for diff operations."""
@pytest.mark.asyncio
async def test_diff_no_changes(self, git_wrapper_with_repo):
"""Test diff with no changes."""
result = await git_wrapper_with_repo.diff()
assert result.files_changed == 0
assert result.total_additions == 0
assert result.total_deletions == 0
@pytest.mark.asyncio
async def test_diff_with_changes(self, git_wrapper_with_repo, git_repo):
"""Test diff with modified files."""
# Modify a file
readme = Path(git_repo.working_dir) / "README.md"
readme.write_text("# Modified\nNew line\n")
result = await git_wrapper_with_repo.diff()
assert result.files_changed > 0
class TestGitWrapperLog:
"""Tests for log operations."""
@pytest.mark.asyncio
async def test_log_basic(self, git_wrapper_with_repo):
"""Test basic log."""
result = await git_wrapper_with_repo.log()
assert result.total_commits > 0
assert len(result.commits) > 0
@pytest.mark.asyncio
async def test_log_with_limit(self, git_wrapper_with_repo, git_repo):
"""Test log with limit."""
# Create more commits
for i in range(5):
file_path = Path(git_repo.working_dir) / f"file{i}.txt"
file_path.write_text(f"content {i}")
git_repo.index.add([f"file{i}.txt"])
git_repo.index.commit(f"Commit {i}")
result = await git_wrapper_with_repo.log(limit=3)
assert len(result.commits) == 3
@pytest.mark.asyncio
async def test_log_commit_info(self, git_wrapper_with_repo):
"""Test that log returns proper commit info."""
result = await git_wrapper_with_repo.log(limit=1)
commit = result.commits[0]
assert "sha" in commit
assert "message" in commit
assert "author_name" in commit
assert "author_email" in commit
class TestGitWrapperUtilities:
"""Tests for utility methods."""
@pytest.mark.asyncio
async def test_is_valid_ref_true(self, git_wrapper_with_repo):
"""Test valid ref detection."""
is_valid = await git_wrapper_with_repo.is_valid_ref("main")
assert is_valid is True
@pytest.mark.asyncio
async def test_is_valid_ref_false(self, git_wrapper_with_repo):
"""Test invalid ref detection."""
is_valid = await git_wrapper_with_repo.is_valid_ref("nonexistent")
assert is_valid is False
def test_diff_to_change_type(self, git_wrapper_with_repo):
"""Test change type conversion."""
wrapper = git_wrapper_with_repo
assert wrapper._diff_to_change_type("A") == FileChangeType.ADDED
assert wrapper._diff_to_change_type("M") == FileChangeType.MODIFIED
assert wrapper._diff_to_change_type("D") == FileChangeType.DELETED
assert wrapper._diff_to_change_type("R") == FileChangeType.RENAMED
class TestGitWrapperStage:
"""Tests for staging operations."""
@pytest.mark.asyncio
async def test_stage_specific_files(self, git_wrapper_with_repo, git_repo):
"""Test staging specific files."""
# Create files
file1 = Path(git_repo.working_dir) / "file1.txt"
file2 = Path(git_repo.working_dir) / "file2.txt"
file1.write_text("content 1")
file2.write_text("content 2")
count = await git_wrapper_with_repo.stage(["file1.txt"])
assert count == 1
@pytest.mark.asyncio
async def test_stage_all(self, git_wrapper_with_repo, git_repo):
"""Test staging all files."""
file1 = Path(git_repo.working_dir) / "all1.txt"
file2 = Path(git_repo.working_dir) / "all2.txt"
file1.write_text("content 1")
file2.write_text("content 2")
count = await git_wrapper_with_repo.stage()
assert count >= 2
@pytest.mark.asyncio
async def test_unstage_files(self, git_wrapper_with_repo, git_repo):
"""Test unstaging files."""
# Create and stage file
file1 = Path(git_repo.working_dir) / "unstage.txt"
file1.write_text("to unstage")
git_repo.index.add(["unstage.txt"])
count = await git_wrapper_with_repo.unstage()
assert count >= 1
class TestGitWrapperReset:
"""Tests for reset operations."""
@pytest.mark.asyncio
async def test_reset_soft(self, git_wrapper_with_repo, git_repo):
"""Test soft reset."""
# Create a commit to reset
file1 = Path(git_repo.working_dir) / "reset_soft.txt"
file1.write_text("content")
git_repo.index.add(["reset_soft.txt"])
git_repo.index.commit("Commit to reset")
result = await git_wrapper_with_repo.reset("HEAD~1", mode="soft")
assert result is True
@pytest.mark.asyncio
async def test_reset_mixed(self, git_wrapper_with_repo, git_repo):
"""Test mixed reset (default)."""
file1 = Path(git_repo.working_dir) / "reset_mixed.txt"
file1.write_text("content")
git_repo.index.add(["reset_mixed.txt"])
git_repo.index.commit("Commit to reset")
result = await git_wrapper_with_repo.reset("HEAD~1", mode="mixed")
assert result is True
@pytest.mark.asyncio
async def test_reset_invalid_mode(self, git_wrapper_with_repo):
"""Test error on invalid reset mode."""
with pytest.raises(GitError, match="Invalid reset mode"):
await git_wrapper_with_repo.reset("HEAD", mode="invalid")
class TestGitWrapperStash:
"""Tests for stash operations."""
@pytest.mark.asyncio
async def test_stash_changes(self, git_wrapper_with_repo, git_repo):
"""Test stashing changes."""
# Make changes
readme = Path(git_repo.working_dir) / "README.md"
readme.write_text("Modified for stash")
result = await git_wrapper_with_repo.stash("Test stash")
# Result should be stash ref or None if nothing to stash
# (depends on whether changes were already staged)
assert result is None or result.startswith("stash@")
@pytest.mark.asyncio
async def test_stash_nothing(self, git_wrapper_with_repo):
"""Test stash with no changes."""
result = await git_wrapper_with_repo.stash()
assert result is None
@pytest.mark.asyncio
async def test_stash_pop(self, git_wrapper_with_repo, git_repo):
"""Test popping a stash."""
# Make changes and stash them
readme = Path(git_repo.working_dir) / "README.md"
original_content = readme.read_text()
readme.write_text("Modified for stash pop test")
git_repo.index.add(["README.md"])
stash_ref = await git_wrapper_with_repo.stash("Test stash for pop")
if stash_ref:
# Pop the stash
result = await git_wrapper_with_repo.stash_pop()
assert result is True
class TestGitWrapperRepoProperty:
"""Tests for repo property edge cases."""
def test_repo_property_path_not_exists(self, test_settings):
"""Test that accessing repo on non-existent path raises error."""
wrapper = GitWrapper(
Path("/nonexistent/path/that/does/not/exist"), test_settings
)
with pytest.raises(GitError, match="Path does not exist"):
_ = wrapper.repo
def test_refresh_repo(self, git_wrapper_with_repo):
"""Test _refresh_repo clears cached repo."""
# Access repo to cache it
_ = git_wrapper_with_repo.repo
assert git_wrapper_with_repo._repo is not None
# Refresh should clear it
git_wrapper_with_repo._refresh_repo()
assert git_wrapper_with_repo._repo is None
class TestGitWrapperBranchAdvanced:
"""Advanced tests for branch operations."""
@pytest.mark.asyncio
async def test_create_branch_from_ref(self, git_wrapper_with_repo, git_repo):
"""Test creating branch from specific ref."""
# Get current HEAD SHA
head_sha = git_repo.head.commit.hexsha
result = await git_wrapper_with_repo.create_branch(
"feature-from-ref",
from_ref=head_sha,
checkout=False,
)
assert result.success is True
assert result.branch == "feature-from-ref"
@pytest.mark.asyncio
async def test_delete_branch_force(self, git_wrapper_with_repo, git_repo):
"""Test force deleting a branch."""
# Create branch and add unmerged commit
await git_wrapper_with_repo.create_branch("unmerged-branch", checkout=True)
new_file = Path(git_repo.working_dir) / "unmerged.txt"
new_file.write_text("unmerged content")
git_repo.index.add(["unmerged.txt"])
git_repo.index.commit("Unmerged commit")
# Switch back to main
await git_wrapper_with_repo.checkout("main")
# Force delete
result = await git_wrapper_with_repo.delete_branch(
"unmerged-branch", force=True
)
assert result.success is True
class TestGitWrapperListBranchesRemote:
"""Tests for listing remote branches."""
@pytest.mark.asyncio
async def test_list_branches_with_remote(self, git_wrapper_with_repo):
"""Test listing branches including remote."""
# Even without remotes, this should work
result = await git_wrapper_with_repo.list_branches(include_remote=True)
assert result.current_branch == "main"
# Remote branches list should be empty for local repo
assert len(result.remote_branches) == 0
class TestGitWrapperCheckoutAdvanced:
"""Advanced tests for checkout operations."""
@pytest.mark.asyncio
async def test_checkout_create_existing_error(self, git_wrapper_with_repo):
"""Test error when creating branch that already exists."""
with pytest.raises(BranchExistsError):
await git_wrapper_with_repo.checkout("main", create_branch=True)
@pytest.mark.asyncio
async def test_checkout_force(self, git_wrapper_with_repo, git_repo):
"""Test force checkout discards local changes."""
# Create branch
await git_wrapper_with_repo.create_branch("force-test", checkout=False)
# Make local changes
readme = Path(git_repo.working_dir) / "README.md"
readme.write_text("local changes")
# Force checkout should work
result = await git_wrapper_with_repo.checkout("force-test", force=True)
assert result.success is True
class TestGitWrapperCommitAdvanced:
"""Advanced tests for commit operations."""
@pytest.mark.asyncio
async def test_commit_specific_files(self, git_wrapper_with_repo, git_repo):
"""Test committing specific files only."""
# Create multiple files
file1 = Path(git_repo.working_dir) / "commit_specific1.txt"
file2 = Path(git_repo.working_dir) / "commit_specific2.txt"
file1.write_text("content 1")
file2.write_text("content 2")
result = await git_wrapper_with_repo.commit(
"Commit specific file",
files=["commit_specific1.txt"],
)
assert result.success is True
assert result.files_changed == 1
@pytest.mark.asyncio
async def test_commit_with_partial_author(self, git_wrapper_with_repo, git_repo):
"""Test commit with only author name."""
new_file = Path(git_repo.working_dir) / "partial_author.txt"
new_file.write_text("content")
result = await git_wrapper_with_repo.commit(
"Partial author commit",
author_name="Test Author",
)
assert result.success is True
@pytest.mark.asyncio
async def test_commit_allow_empty(self, git_wrapper_with_repo):
"""Test allowing empty commits."""
result = await git_wrapper_with_repo.commit(
"Empty commit allowed",
allow_empty=True,
)
assert result.success is True
class TestGitWrapperUnstageAdvanced:
"""Advanced tests for unstaging operations."""
@pytest.mark.asyncio
async def test_unstage_specific_files(self, git_wrapper_with_repo, git_repo):
"""Test unstaging specific files."""
# Create and stage files
file1 = Path(git_repo.working_dir) / "unstage1.txt"
file2 = Path(git_repo.working_dir) / "unstage2.txt"
file1.write_text("content 1")
file2.write_text("content 2")
git_repo.index.add(["unstage1.txt", "unstage2.txt"])
count = await git_wrapper_with_repo.unstage(["unstage1.txt"])
assert count == 1
class TestGitWrapperResetAdvanced:
"""Advanced tests for reset operations."""
@pytest.mark.asyncio
async def test_reset_hard(self, git_wrapper_with_repo, git_repo):
"""Test hard reset."""
# Create a commit
file1 = Path(git_repo.working_dir) / "reset_hard.txt"
file1.write_text("content")
git_repo.index.add(["reset_hard.txt"])
git_repo.index.commit("Commit for hard reset")
result = await git_wrapper_with_repo.reset("HEAD~1", mode="hard")
assert result is True
# File should be gone after hard reset
assert not file1.exists()
@pytest.mark.asyncio
async def test_reset_specific_files(self, git_wrapper_with_repo, git_repo):
"""Test resetting specific files."""
# Create and stage a file
file1 = Path(git_repo.working_dir) / "reset_file.txt"
file1.write_text("content")
git_repo.index.add(["reset_file.txt"])
result = await git_wrapper_with_repo.reset("HEAD", files=["reset_file.txt"])
assert result is True
class TestGitWrapperDiffAdvanced:
"""Advanced tests for diff operations."""
@pytest.mark.asyncio
async def test_diff_between_refs(self, git_wrapper_with_repo, git_repo):
"""Test diff between two refs."""
# Create initial commit
file1 = Path(git_repo.working_dir) / "diff_ref.txt"
file1.write_text("initial")
git_repo.index.add(["diff_ref.txt"])
commit1 = git_repo.index.commit("First commit for diff")
# Create second commit
file1.write_text("modified")
git_repo.index.add(["diff_ref.txt"])
commit2 = git_repo.index.commit("Second commit for diff")
result = await git_wrapper_with_repo.diff(
base=commit1.hexsha,
head=commit2.hexsha,
)
assert result.files_changed > 0
@pytest.mark.asyncio
async def test_diff_specific_files(self, git_wrapper_with_repo, git_repo):
"""Test diff for specific files only."""
# Create files
file1 = Path(git_repo.working_dir) / "diff_specific1.txt"
file2 = Path(git_repo.working_dir) / "diff_specific2.txt"
file1.write_text("content 1")
file2.write_text("content 2")
result = await git_wrapper_with_repo.diff(files=["diff_specific1.txt"])
# Should only show changes for specified file
for f in result.files:
assert "diff_specific2.txt" not in f.get("path", "")
@pytest.mark.asyncio
async def test_diff_base_only(self, git_wrapper_with_repo, git_repo):
"""Test diff with base ref only (vs HEAD)."""
# Create commit
file1 = Path(git_repo.working_dir) / "diff_base.txt"
file1.write_text("content")
git_repo.index.add(["diff_base.txt"])
commit = git_repo.index.commit("Commit for diff base test")
# Get parent commit
parent = commit.parents[0] if commit.parents else commit
result = await git_wrapper_with_repo.diff(base=parent.hexsha)
assert isinstance(result.files_changed, int)
class TestGitWrapperLogAdvanced:
"""Advanced tests for log operations."""
@pytest.mark.asyncio
async def test_log_with_ref(self, git_wrapper_with_repo, git_repo):
"""Test log starting from specific ref."""
# Create branch with commits
await git_wrapper_with_repo.create_branch("log-test", checkout=True)
file1 = Path(git_repo.working_dir) / "log_ref.txt"
file1.write_text("content")
git_repo.index.add(["log_ref.txt"])
git_repo.index.commit("Commit on log-test branch")
result = await git_wrapper_with_repo.log(ref="log-test", limit=5)
assert result.total_commits > 0
@pytest.mark.asyncio
async def test_log_with_path(self, git_wrapper_with_repo, git_repo):
"""Test log filtered by path."""
# Create file and commit
file1 = Path(git_repo.working_dir) / "log_path.txt"
file1.write_text("content")
git_repo.index.add(["log_path.txt"])
git_repo.index.commit("Commit for path log")
result = await git_wrapper_with_repo.log(path="log_path.txt")
assert result.total_commits >= 1
@pytest.mark.asyncio
async def test_log_with_skip(self, git_wrapper_with_repo, git_repo):
"""Test log with skip parameter."""
# Create multiple commits
for i in range(3):
file_path = Path(git_repo.working_dir) / f"skip_test{i}.txt"
file_path.write_text(f"content {i}")
git_repo.index.add([f"skip_test{i}.txt"])
git_repo.index.commit(f"Skip test commit {i}")
result = await git_wrapper_with_repo.log(skip=1, limit=2)
# Should have skipped first commit
assert len(result.commits) <= 2
class TestGitWrapperRemoteUrl:
"""Tests for remote URL operations."""
@pytest.mark.asyncio
async def test_get_remote_url_nonexistent(self, git_wrapper_with_repo):
"""Test getting URL for non-existent remote."""
url = await git_wrapper_with_repo.get_remote_url("nonexistent")
assert url is None
class TestGitWrapperConfig:
"""Tests for git config operations."""
@pytest.mark.asyncio
async def test_set_and_get_config(self, git_wrapper_with_repo):
"""Test setting and getting config value."""
await git_wrapper_with_repo.set_config("test.key", "test_value")
value = await git_wrapper_with_repo.get_config("test.key")
assert value == "test_value"
@pytest.mark.asyncio
async def test_get_config_nonexistent(self, git_wrapper_with_repo):
"""Test getting non-existent config value."""
value = await git_wrapper_with_repo.get_config("nonexistent.key")
assert value is None
class TestGitWrapperClone:
"""Tests for clone operations."""
@pytest.mark.asyncio
async def test_clone_success(self, temp_workspace, test_settings):
"""Test successful clone."""
wrapper = GitWrapper(temp_workspace, test_settings)
# Mock the clone operation
with patch("git_wrapper.GitRepo") as mock_repo_class:
mock_repo = MagicMock()
mock_repo.active_branch.name = "main"
mock_repo.head.commit.hexsha = "abc123"
mock_repo_class.clone_from.return_value = mock_repo
result = await wrapper.clone("https://github.com/test/repo.git")
assert result.success is True
assert result.branch == "main"
assert result.commit_sha == "abc123"
@pytest.mark.asyncio
async def test_clone_with_auth_token(self, temp_workspace, test_settings):
"""Test clone with auth token."""
wrapper = GitWrapper(temp_workspace, test_settings)
with patch("git_wrapper.GitRepo") as mock_repo_class:
mock_repo = MagicMock()
mock_repo.active_branch.name = "main"
mock_repo.head.commit.hexsha = "abc123"
mock_repo_class.clone_from.return_value = mock_repo
result = await wrapper.clone(
"https://github.com/test/repo.git",
auth_token="test-token",
)
assert result.success is True
# Verify token was injected in URL
call_args = mock_repo_class.clone_from.call_args
assert "test-token@" in call_args.kwargs["url"]
@pytest.mark.asyncio
async def test_clone_with_branch_and_depth(self, temp_workspace, test_settings):
"""Test clone with branch and depth parameters."""
wrapper = GitWrapper(temp_workspace, test_settings)
with patch("git_wrapper.GitRepo") as mock_repo_class:
mock_repo = MagicMock()
mock_repo.active_branch.name = "develop"
mock_repo.head.commit.hexsha = "def456"
mock_repo_class.clone_from.return_value = mock_repo
result = await wrapper.clone(
"https://github.com/test/repo.git",
branch="develop",
depth=1,
)
assert result.success is True
call_args = mock_repo_class.clone_from.call_args
assert call_args.kwargs["branch"] == "develop"
assert call_args.kwargs["depth"] == 1
@pytest.mark.asyncio
async def test_clone_failure(self, temp_workspace, test_settings):
"""Test clone failure raises CloneError."""
wrapper = GitWrapper(temp_workspace, test_settings)
with patch("git_wrapper.GitRepo") as mock_repo_class:
mock_repo_class.clone_from.side_effect = GitCommandError(
"git clone", 128, stderr="Authentication failed"
)
with pytest.raises(CloneError):
await wrapper.clone("https://github.com/test/repo.git")
class TestGitWrapperPush:
"""Tests for push operations."""
@pytest.mark.asyncio
async def test_push_force_disabled(self, git_wrapper_with_repo, git_repo):
"""Test force push is disabled by default."""
git_repo.create_remote("origin", "https://github.com/test/repo.git")
with pytest.raises(PushError, match="Force push is disabled"):
await git_wrapper_with_repo.push(force=True)
@pytest.mark.asyncio
async def test_push_remote_not_found(self, git_wrapper_with_repo):
"""Test push to non-existent remote."""
with pytest.raises(PushError, match="Remote not found"):
await git_wrapper_with_repo.push(remote="nonexistent")
class TestGitWrapperPull:
"""Tests for pull operations."""
@pytest.mark.asyncio
async def test_pull_remote_not_found(self, git_wrapper_with_repo):
"""Test pull from non-existent remote."""
with pytest.raises(PullError, match="Remote not found"):
await git_wrapper_with_repo.pull(remote="nonexistent")
class TestGitWrapperFetch:
"""Tests for fetch operations."""
@pytest.mark.asyncio
async def test_fetch_remote_not_found(self, git_wrapper_with_repo):
"""Test fetch from non-existent remote."""
with pytest.raises(GitError, match="Remote not found"):
await git_wrapper_with_repo.fetch(remote="nonexistent")
class TestGitWrapperDiffHeadOnly:
"""Tests for diff with head ref only."""
@pytest.mark.asyncio
async def test_diff_head_only(self, git_wrapper_with_repo, git_repo):
"""Test diff with head ref only (working tree vs ref)."""
# Make some changes
readme = Path(git_repo.working_dir) / "README.md"
readme.write_text("modified content")
# This tests the head-only branch (base=None, head=specified)
result = await git_wrapper_with_repo.diff(head="HEAD")
assert isinstance(result.files_changed, int)
class TestGitWrapperRemoteWithUrl:
"""Tests for getting remote URL when remote exists."""
@pytest.mark.asyncio
async def test_get_remote_url_exists(self, git_wrapper_with_repo, git_repo):
"""Test getting URL for existing remote."""
git_repo.create_remote("origin", "https://github.com/test/repo.git")
url = await git_wrapper_with_repo.get_remote_url("origin")
assert url == "https://github.com/test/repo.git"
class TestRunInExecutor:
"""Tests for run_in_executor utility."""
@pytest.mark.asyncio
async def test_run_in_executor(self):
"""Test running function in executor."""
def blocking_func(x, y):
return x + y
result = await run_in_executor(blocking_func, 1, 2)
assert result == 3

View File

@@ -0,0 +1,620 @@
"""
Tests for GitHub provider implementation.
"""
from unittest.mock import MagicMock
import pytest
from exceptions import APIError, AuthenticationError
from models import MergeStrategy, PRState
from providers.github import GitHubProvider
class TestGitHubProviderBasics:
"""Tests for GitHubProvider basic operations."""
def test_provider_name(self):
"""Test provider name is github."""
provider = GitHubProvider(token="test-token")
assert provider.name == "github"
def test_parse_repo_url_https(self):
"""Test parsing HTTPS repo URL."""
provider = GitHubProvider(token="test-token")
owner, repo = provider.parse_repo_url("https://github.com/owner/repo.git")
assert owner == "owner"
assert repo == "repo"
def test_parse_repo_url_https_no_git(self):
"""Test parsing HTTPS URL without .git suffix."""
provider = GitHubProvider(token="test-token")
owner, repo = provider.parse_repo_url("https://github.com/owner/repo")
assert owner == "owner"
assert repo == "repo"
def test_parse_repo_url_ssh(self):
"""Test parsing SSH repo URL."""
provider = GitHubProvider(token="test-token")
owner, repo = provider.parse_repo_url("git@github.com:owner/repo.git")
assert owner == "owner"
assert repo == "repo"
def test_parse_repo_url_invalid(self):
"""Test error on invalid URL."""
provider = GitHubProvider(token="test-token")
with pytest.raises(ValueError, match="Unable to parse"):
provider.parse_repo_url("invalid-url")
@pytest.fixture
def mock_github_httpx_client():
"""Create a mock httpx client for GitHub provider tests."""
from unittest.mock import AsyncMock
mock_response = AsyncMock()
mock_response.status_code = 200
mock_response.json = MagicMock(return_value={})
mock_response.text = ""
mock_client = AsyncMock()
mock_client.request = AsyncMock(return_value=mock_response)
mock_client.get = AsyncMock(return_value=mock_response)
mock_client.post = AsyncMock(return_value=mock_response)
mock_client.patch = AsyncMock(return_value=mock_response)
mock_client.put = AsyncMock(return_value=mock_response)
mock_client.delete = AsyncMock(return_value=mock_response)
return mock_client
@pytest.fixture
async def github_provider(test_settings, mock_github_httpx_client):
"""Create a GitHubProvider with mocked HTTP client."""
provider = GitHubProvider(
token=test_settings.github_token,
settings=test_settings,
)
provider._client = mock_github_httpx_client
yield provider
await provider.close()
@pytest.fixture
def github_pr_data():
"""Sample PR data from GitHub API."""
return {
"number": 42,
"title": "Test PR",
"body": "This is a test pull request",
"state": "open",
"head": {"ref": "feature-branch"},
"base": {"ref": "main"},
"user": {"login": "test-user"},
"created_at": "2024-01-15T10:00:00Z",
"updated_at": "2024-01-15T12:00:00Z",
"merged_at": None,
"closed_at": None,
"html_url": "https://github.com/owner/repo/pull/42",
"labels": [{"name": "enhancement"}],
"assignees": [{"login": "assignee1"}],
"requested_reviewers": [{"login": "reviewer1"}],
"mergeable": True,
"draft": False,
}
class TestGitHubProviderConnection:
"""Tests for GitHub provider connection."""
@pytest.mark.asyncio
async def test_is_connected(self, github_provider, mock_github_httpx_client):
"""Test connection check."""
mock_github_httpx_client.request.return_value.json = MagicMock(
return_value={"login": "test-user"}
)
result = await github_provider.is_connected()
assert result is True
@pytest.mark.asyncio
async def test_is_connected_no_token(self, test_settings):
"""Test connection fails without token."""
provider = GitHubProvider(
token="",
settings=test_settings,
)
result = await provider.is_connected()
assert result is False
await provider.close()
@pytest.mark.asyncio
async def test_get_authenticated_user(
self, github_provider, mock_github_httpx_client
):
"""Test getting authenticated user."""
mock_github_httpx_client.request.return_value.json = MagicMock(
return_value={"login": "test-user"}
)
user = await github_provider.get_authenticated_user()
assert user == "test-user"
class TestGitHubProviderRepoOperations:
"""Tests for GitHub repository operations."""
@pytest.mark.asyncio
async def test_get_repo_info(self, github_provider, mock_github_httpx_client):
"""Test getting repository info."""
mock_github_httpx_client.request.return_value.json = MagicMock(
return_value={
"name": "repo",
"full_name": "owner/repo",
"default_branch": "main",
}
)
result = await github_provider.get_repo_info("owner", "repo")
assert result["name"] == "repo"
assert result["default_branch"] == "main"
@pytest.mark.asyncio
async def test_get_default_branch(self, github_provider, mock_github_httpx_client):
"""Test getting default branch."""
mock_github_httpx_client.request.return_value.json = MagicMock(
return_value={"default_branch": "develop"}
)
branch = await github_provider.get_default_branch("owner", "repo")
assert branch == "develop"
class TestGitHubPROperations:
"""Tests for GitHub PR operations."""
@pytest.mark.asyncio
async def test_create_pr(self, github_provider, mock_github_httpx_client):
"""Test creating a pull request."""
mock_github_httpx_client.request.return_value.json = MagicMock(
return_value={
"number": 42,
"html_url": "https://github.com/owner/repo/pull/42",
}
)
result = await github_provider.create_pr(
owner="owner",
repo="repo",
title="Test PR",
body="Test body",
source_branch="feature",
target_branch="main",
)
assert result.success is True
assert result.pr_number == 42
assert result.pr_url == "https://github.com/owner/repo/pull/42"
@pytest.mark.asyncio
async def test_create_pr_with_draft(
self, github_provider, mock_github_httpx_client
):
"""Test creating a draft PR."""
mock_github_httpx_client.request.return_value.json = MagicMock(
return_value={
"number": 43,
"html_url": "https://github.com/owner/repo/pull/43",
}
)
result = await github_provider.create_pr(
owner="owner",
repo="repo",
title="Draft PR",
body="Draft body",
source_branch="feature",
target_branch="main",
draft=True,
)
assert result.success is True
assert result.pr_number == 43
@pytest.mark.asyncio
async def test_create_pr_with_options(
self, github_provider, mock_github_httpx_client
):
"""Test creating PR with labels, assignees, reviewers."""
mock_responses = [
{
"number": 44,
"html_url": "https://github.com/owner/repo/pull/44",
}, # Create PR
[{"name": "enhancement"}], # POST add labels
{}, # POST add assignees
{}, # POST request reviewers
]
mock_github_httpx_client.request.return_value.json = MagicMock(
side_effect=mock_responses
)
result = await github_provider.create_pr(
owner="owner",
repo="repo",
title="Test PR",
body="Test body",
source_branch="feature",
target_branch="main",
labels=["enhancement"],
assignees=["user1"],
reviewers=["reviewer1"],
)
assert result.success is True
@pytest.mark.asyncio
async def test_get_pr(
self, github_provider, mock_github_httpx_client, github_pr_data
):
"""Test getting a pull request."""
mock_github_httpx_client.request.return_value.json = MagicMock(
return_value=github_pr_data
)
result = await github_provider.get_pr("owner", "repo", 42)
assert result.success is True
assert result.pr["number"] == 42
assert result.pr["title"] == "Test PR"
@pytest.mark.asyncio
async def test_get_pr_not_found(self, github_provider, mock_github_httpx_client):
"""Test getting non-existent PR."""
mock_github_httpx_client.request.return_value.status_code = 404
mock_github_httpx_client.request.return_value.json = MagicMock(
return_value=None
)
result = await github_provider.get_pr("owner", "repo", 999)
assert result.success is False
@pytest.mark.asyncio
async def test_list_prs(
self, github_provider, mock_github_httpx_client, github_pr_data
):
"""Test listing pull requests."""
mock_github_httpx_client.request.return_value.json = MagicMock(
return_value=[github_pr_data, github_pr_data]
)
result = await github_provider.list_prs("owner", "repo")
assert result.success is True
assert len(result.pull_requests) == 2
@pytest.mark.asyncio
async def test_list_prs_with_state_filter(
self, github_provider, mock_github_httpx_client, github_pr_data
):
"""Test listing PRs with state filter."""
mock_github_httpx_client.request.return_value.json = MagicMock(
return_value=[github_pr_data]
)
result = await github_provider.list_prs("owner", "repo", state=PRState.OPEN)
assert result.success is True
@pytest.mark.asyncio
async def test_merge_pr(
self, github_provider, mock_github_httpx_client, github_pr_data
):
"""Test merging a pull request."""
# Merge returns sha, then get_pr returns the PR data, then delete branch
mock_responses = [
{"sha": "merge-commit-sha", "merged": True}, # PUT merge
github_pr_data, # GET PR for branch info
None, # DELETE branch
]
mock_github_httpx_client.request.return_value.json = MagicMock(
side_effect=mock_responses
)
result = await github_provider.merge_pr(
"owner",
"repo",
42,
merge_strategy=MergeStrategy.SQUASH,
)
assert result.success is True
assert result.merge_commit_sha == "merge-commit-sha"
@pytest.mark.asyncio
async def test_merge_pr_rebase(
self, github_provider, mock_github_httpx_client, github_pr_data
):
"""Test merging with rebase strategy."""
mock_responses = [
{"sha": "rebase-commit-sha", "merged": True}, # PUT merge
github_pr_data, # GET PR for branch info
None, # DELETE branch
]
mock_github_httpx_client.request.return_value.json = MagicMock(
side_effect=mock_responses
)
result = await github_provider.merge_pr(
"owner",
"repo",
42,
merge_strategy=MergeStrategy.REBASE,
)
assert result.success is True
@pytest.mark.asyncio
async def test_update_pr(
self, github_provider, mock_github_httpx_client, github_pr_data
):
"""Test updating a pull request."""
mock_github_httpx_client.request.return_value.json = MagicMock(
return_value=github_pr_data
)
result = await github_provider.update_pr(
"owner",
"repo",
42,
title="Updated Title",
body="Updated body",
)
assert result.success is True
@pytest.mark.asyncio
async def test_close_pr(
self, github_provider, mock_github_httpx_client, github_pr_data
):
"""Test closing a pull request."""
github_pr_data["state"] = "closed"
mock_github_httpx_client.request.return_value.json = MagicMock(
return_value=github_pr_data
)
result = await github_provider.close_pr("owner", "repo", 42)
assert result.success is True
class TestGitHubBranchOperations:
"""Tests for GitHub branch operations."""
@pytest.mark.asyncio
async def test_get_branch(self, github_provider, mock_github_httpx_client):
"""Test getting branch info."""
mock_github_httpx_client.request.return_value.json = MagicMock(
return_value={
"name": "main",
"commit": {"sha": "abc123"},
}
)
result = await github_provider.get_branch("owner", "repo", "main")
assert result["name"] == "main"
@pytest.mark.asyncio
async def test_delete_remote_branch(
self, github_provider, mock_github_httpx_client
):
"""Test deleting a remote branch."""
mock_github_httpx_client.request.return_value.status_code = 204
result = await github_provider.delete_remote_branch(
"owner", "repo", "old-branch"
)
assert result is True
class TestGitHubCommentOperations:
"""Tests for GitHub comment operations."""
@pytest.mark.asyncio
async def test_add_pr_comment(self, github_provider, mock_github_httpx_client):
"""Test adding a comment to a PR."""
mock_github_httpx_client.request.return_value.json = MagicMock(
return_value={"id": 1, "body": "Test comment"}
)
result = await github_provider.add_pr_comment(
"owner", "repo", 42, "Test comment"
)
assert result["body"] == "Test comment"
@pytest.mark.asyncio
async def test_list_pr_comments(self, github_provider, mock_github_httpx_client):
"""Test listing PR comments."""
mock_github_httpx_client.request.return_value.json = MagicMock(
return_value=[
{"id": 1, "body": "Comment 1"},
{"id": 2, "body": "Comment 2"},
]
)
result = await github_provider.list_pr_comments("owner", "repo", 42)
assert len(result) == 2
class TestGitHubLabelOperations:
"""Tests for GitHub label operations."""
@pytest.mark.asyncio
async def test_add_labels(self, github_provider, mock_github_httpx_client):
"""Test adding labels to a PR."""
mock_github_httpx_client.request.return_value.json = MagicMock(
return_value=[{"name": "bug"}, {"name": "urgent"}]
)
result = await github_provider.add_labels(
"owner", "repo", 42, ["bug", "urgent"]
)
assert "bug" in result
assert "urgent" in result
@pytest.mark.asyncio
async def test_remove_label(self, github_provider, mock_github_httpx_client):
"""Test removing a label from a PR."""
mock_responses = [
None, # DELETE label
{"labels": []}, # GET issue
]
mock_github_httpx_client.request.return_value.json = MagicMock(
side_effect=mock_responses
)
result = await github_provider.remove_label("owner", "repo", 42, "bug")
assert isinstance(result, list)
class TestGitHubReviewerOperations:
"""Tests for GitHub reviewer operations."""
@pytest.mark.asyncio
async def test_request_review(self, github_provider, mock_github_httpx_client):
"""Test requesting review from users."""
mock_github_httpx_client.request.return_value.json = MagicMock(return_value={})
result = await github_provider.request_review(
"owner", "repo", 42, ["reviewer1", "reviewer2"]
)
assert result == ["reviewer1", "reviewer2"]
class TestGitHubErrorHandling:
"""Tests for error handling in GitHub provider."""
@pytest.mark.asyncio
async def test_authentication_error(
self, github_provider, mock_github_httpx_client
):
"""Test handling authentication errors."""
mock_github_httpx_client.request.return_value.status_code = 401
with pytest.raises(AuthenticationError):
await github_provider._request("GET", "/user")
@pytest.mark.asyncio
async def test_permission_denied(self, github_provider, mock_github_httpx_client):
"""Test handling permission denied errors."""
mock_github_httpx_client.request.return_value.status_code = 403
mock_github_httpx_client.request.return_value.text = "Permission denied"
with pytest.raises(AuthenticationError, match="Insufficient permissions"):
await github_provider._request("GET", "/protected")
@pytest.mark.asyncio
async def test_rate_limit_error(self, github_provider, mock_github_httpx_client):
"""Test handling rate limit errors."""
mock_github_httpx_client.request.return_value.status_code = 403
mock_github_httpx_client.request.return_value.text = "API rate limit exceeded"
with pytest.raises(APIError, match="rate limit"):
await github_provider._request("GET", "/user")
@pytest.mark.asyncio
async def test_api_error(self, github_provider, mock_github_httpx_client):
"""Test handling general API errors."""
mock_github_httpx_client.request.return_value.status_code = 500
mock_github_httpx_client.request.return_value.text = "Internal Server Error"
mock_github_httpx_client.request.return_value.json = MagicMock(
return_value={"message": "Server error"}
)
with pytest.raises(APIError):
await github_provider._request("GET", "/error")
class TestGitHubPRParsing:
"""Tests for PR data parsing."""
def test_parse_pr_open(self, github_provider, github_pr_data):
"""Test parsing open PR."""
pr_info = github_provider._parse_pr(github_pr_data)
assert pr_info.number == 42
assert pr_info.state == PRState.OPEN
assert pr_info.title == "Test PR"
assert pr_info.source_branch == "feature-branch"
assert pr_info.target_branch == "main"
def test_parse_pr_merged(self, github_provider, github_pr_data):
"""Test parsing merged PR."""
github_pr_data["merged_at"] = "2024-01-16T10:00:00Z"
pr_info = github_provider._parse_pr(github_pr_data)
assert pr_info.state == PRState.MERGED
def test_parse_pr_closed(self, github_provider, github_pr_data):
"""Test parsing closed PR."""
github_pr_data["state"] = "closed"
github_pr_data["closed_at"] = "2024-01-16T10:00:00Z"
pr_info = github_provider._parse_pr(github_pr_data)
assert pr_info.state == PRState.CLOSED
def test_parse_pr_draft(self, github_provider, github_pr_data):
"""Test parsing draft PR."""
github_pr_data["draft"] = True
pr_info = github_provider._parse_pr(github_pr_data)
assert pr_info.draft is True
def test_parse_datetime_iso(self, github_provider):
"""Test parsing ISO datetime strings."""
dt = github_provider._parse_datetime("2024-01-15T10:30:00Z")
assert dt.year == 2024
assert dt.month == 1
assert dt.day == 15
def test_parse_datetime_none(self, github_provider):
"""Test parsing None datetime returns now."""
dt = github_provider._parse_datetime(None)
assert dt is not None
assert dt.tzinfo is not None
def test_parse_pr_with_null_body(self, github_provider, github_pr_data):
"""Test parsing PR with null body."""
github_pr_data["body"] = None
pr_info = github_provider._parse_pr(github_pr_data)
assert pr_info.body == ""

View File

@@ -0,0 +1,486 @@
"""
Tests for git provider implementations.
"""
from unittest.mock import MagicMock
import pytest
from exceptions import APIError, AuthenticationError
from models import MergeStrategy, PRState
from providers.gitea import GiteaProvider
class TestBaseProvider:
"""Tests for BaseProvider interface."""
def test_parse_repo_url_https(self, mock_gitea_provider):
"""Test parsing HTTPS repo URL."""
# The mock needs parse_repo_url to work
provider = GiteaProvider(base_url="https://gitea.test.com", token="test-token")
owner, repo = provider.parse_repo_url("https://gitea.test.com/owner/repo.git")
assert owner == "owner"
assert repo == "repo"
def test_parse_repo_url_https_no_git(self):
"""Test parsing HTTPS URL without .git suffix."""
provider = GiteaProvider(base_url="https://gitea.test.com", token="test-token")
owner, repo = provider.parse_repo_url("https://gitea.test.com/owner/repo")
assert owner == "owner"
assert repo == "repo"
def test_parse_repo_url_ssh(self):
"""Test parsing SSH repo URL."""
provider = GiteaProvider(base_url="https://gitea.test.com", token="test-token")
owner, repo = provider.parse_repo_url("git@gitea.test.com:owner/repo.git")
assert owner == "owner"
assert repo == "repo"
def test_parse_repo_url_invalid(self):
"""Test error on invalid URL."""
provider = GiteaProvider(base_url="https://gitea.test.com", token="test-token")
with pytest.raises(ValueError, match="Unable to parse"):
provider.parse_repo_url("invalid-url")
class TestGiteaProvider:
"""Tests for GiteaProvider."""
@pytest.mark.asyncio
async def test_is_connected(self, gitea_provider, mock_httpx_client):
"""Test connection check."""
mock_httpx_client.request.return_value.json = MagicMock(
return_value={"login": "test-user"}
)
result = await gitea_provider.is_connected()
assert result is True
@pytest.mark.asyncio
async def test_is_connected_no_token(self, test_settings):
"""Test connection fails without token."""
provider = GiteaProvider(
base_url="https://gitea.test.com",
token="",
settings=test_settings,
)
result = await provider.is_connected()
assert result is False
await provider.close()
@pytest.mark.asyncio
async def test_get_authenticated_user(self, gitea_provider, mock_httpx_client):
"""Test getting authenticated user."""
mock_httpx_client.request.return_value.json = MagicMock(
return_value={"login": "test-user"}
)
user = await gitea_provider.get_authenticated_user()
assert user == "test-user"
@pytest.mark.asyncio
async def test_get_repo_info(self, gitea_provider, mock_httpx_client):
"""Test getting repository info."""
mock_httpx_client.request.return_value.json = MagicMock(
return_value={
"name": "repo",
"full_name": "owner/repo",
"default_branch": "main",
}
)
result = await gitea_provider.get_repo_info("owner", "repo")
assert result["name"] == "repo"
assert result["default_branch"] == "main"
@pytest.mark.asyncio
async def test_get_default_branch(self, gitea_provider, mock_httpx_client):
"""Test getting default branch."""
mock_httpx_client.request.return_value.json = MagicMock(
return_value={"default_branch": "develop"}
)
branch = await gitea_provider.get_default_branch("owner", "repo")
assert branch == "develop"
class TestGiteaPROperations:
"""Tests for Gitea PR operations."""
@pytest.mark.asyncio
async def test_create_pr(self, gitea_provider, mock_httpx_client):
"""Test creating a pull request."""
mock_httpx_client.request.return_value.json = MagicMock(
return_value={
"number": 42,
"html_url": "https://gitea.test.com/owner/repo/pull/42",
}
)
result = await gitea_provider.create_pr(
owner="owner",
repo="repo",
title="Test PR",
body="Test body",
source_branch="feature",
target_branch="main",
)
assert result.success is True
assert result.pr_number == 42
assert result.pr_url == "https://gitea.test.com/owner/repo/pull/42"
@pytest.mark.asyncio
async def test_create_pr_with_options(self, gitea_provider, mock_httpx_client):
"""Test creating PR with labels, assignees, reviewers."""
# Use side_effect for multiple API calls:
# 1. POST create PR
# 2. GET labels (for "enhancement") - in add_labels -> _get_or_create_label
# 3. POST add labels to PR - in add_labels
# 4. GET issue to return labels - in add_labels
# 5. PATCH add assignees
# 6. POST request reviewers
mock_responses = [
{
"number": 43,
"html_url": "https://gitea.test.com/owner/repo/pull/43",
}, # Create PR
[{"id": 1, "name": "enhancement"}], # GET labels (found)
{}, # POST add labels to PR
{"labels": [{"name": "enhancement"}]}, # GET issue to return current labels
{}, # PATCH add assignees
{}, # POST request reviewers
]
mock_httpx_client.request.return_value.json = MagicMock(
side_effect=mock_responses
)
result = await gitea_provider.create_pr(
owner="owner",
repo="repo",
title="Test PR",
body="Test body",
source_branch="feature",
target_branch="main",
labels=["enhancement"],
assignees=["user1"],
reviewers=["reviewer1"],
)
assert result.success is True
@pytest.mark.asyncio
async def test_get_pr(self, gitea_provider, mock_httpx_client, sample_pr_data):
"""Test getting a pull request."""
mock_httpx_client.request.return_value.json = MagicMock(
return_value=sample_pr_data
)
result = await gitea_provider.get_pr("owner", "repo", 42)
assert result.success is True
assert result.pr["number"] == 42
assert result.pr["title"] == "Test PR"
@pytest.mark.asyncio
async def test_get_pr_not_found(self, gitea_provider, mock_httpx_client):
"""Test getting non-existent PR."""
mock_httpx_client.request.return_value.status_code = 404
mock_httpx_client.request.return_value.json = MagicMock(return_value=None)
result = await gitea_provider.get_pr("owner", "repo", 999)
assert result.success is False
@pytest.mark.asyncio
async def test_list_prs(self, gitea_provider, mock_httpx_client, sample_pr_data):
"""Test listing pull requests."""
mock_httpx_client.request.return_value.json = MagicMock(
return_value=[sample_pr_data, sample_pr_data]
)
result = await gitea_provider.list_prs("owner", "repo")
assert result.success is True
assert len(result.pull_requests) == 2
@pytest.mark.asyncio
async def test_list_prs_with_state_filter(
self, gitea_provider, mock_httpx_client, sample_pr_data
):
"""Test listing PRs with state filter."""
mock_httpx_client.request.return_value.json = MagicMock(
return_value=[sample_pr_data]
)
result = await gitea_provider.list_prs("owner", "repo", state=PRState.OPEN)
assert result.success is True
@pytest.mark.asyncio
async def test_merge_pr(self, gitea_provider, mock_httpx_client):
"""Test merging a pull request."""
# First call returns merge result
mock_httpx_client.request.return_value.json = MagicMock(
return_value={"sha": "merge-commit-sha"}
)
result = await gitea_provider.merge_pr(
"owner",
"repo",
42,
merge_strategy=MergeStrategy.SQUASH,
)
assert result.success is True
assert result.merge_commit_sha == "merge-commit-sha"
@pytest.mark.asyncio
async def test_update_pr(self, gitea_provider, mock_httpx_client, sample_pr_data):
"""Test updating a pull request."""
mock_httpx_client.request.return_value.json = MagicMock(
return_value=sample_pr_data
)
result = await gitea_provider.update_pr(
"owner",
"repo",
42,
title="Updated Title",
body="Updated body",
)
assert result.success is True
@pytest.mark.asyncio
async def test_close_pr(self, gitea_provider, mock_httpx_client, sample_pr_data):
"""Test closing a pull request."""
sample_pr_data["state"] = "closed"
mock_httpx_client.request.return_value.json = MagicMock(
return_value=sample_pr_data
)
result = await gitea_provider.close_pr("owner", "repo", 42)
assert result.success is True
class TestGiteaBranchOperations:
"""Tests for Gitea branch operations."""
@pytest.mark.asyncio
async def test_get_branch(self, gitea_provider, mock_httpx_client):
"""Test getting branch info."""
mock_httpx_client.request.return_value.json = MagicMock(
return_value={
"name": "main",
"commit": {"sha": "abc123"},
}
)
result = await gitea_provider.get_branch("owner", "repo", "main")
assert result["name"] == "main"
@pytest.mark.asyncio
async def test_delete_remote_branch(self, gitea_provider, mock_httpx_client):
"""Test deleting a remote branch."""
mock_httpx_client.request.return_value.status_code = 204
result = await gitea_provider.delete_remote_branch(
"owner", "repo", "old-branch"
)
assert result is True
class TestGiteaCommentOperations:
"""Tests for Gitea comment operations."""
@pytest.mark.asyncio
async def test_add_pr_comment(self, gitea_provider, mock_httpx_client):
"""Test adding a comment to a PR."""
mock_httpx_client.request.return_value.json = MagicMock(
return_value={"id": 1, "body": "Test comment"}
)
result = await gitea_provider.add_pr_comment(
"owner", "repo", 42, "Test comment"
)
assert result["body"] == "Test comment"
@pytest.mark.asyncio
async def test_list_pr_comments(self, gitea_provider, mock_httpx_client):
"""Test listing PR comments."""
mock_httpx_client.request.return_value.json = MagicMock(
return_value=[
{"id": 1, "body": "Comment 1"},
{"id": 2, "body": "Comment 2"},
]
)
result = await gitea_provider.list_pr_comments("owner", "repo", 42)
assert len(result) == 2
class TestGiteaLabelOperations:
"""Tests for Gitea label operations."""
@pytest.mark.asyncio
async def test_add_labels(self, gitea_provider, mock_httpx_client):
"""Test adding labels to a PR."""
# Use side_effect to return different values for different calls
# 1. GET labels (for "bug") - returns existing labels
# 2. POST to create "bug" label
# 3. GET labels (for "urgent")
# 4. POST to create "urgent" label
# 5. POST labels to PR
# 6. GET issue to return final labels
mock_responses = [
[{"id": 1, "name": "existing"}], # GET labels (bug not found)
{"id": 2, "name": "bug"}, # POST create bug
[
{"id": 1, "name": "existing"},
{"id": 2, "name": "bug"},
], # GET labels (urgent not found)
{"id": 3, "name": "urgent"}, # POST create urgent
{}, # POST add labels to PR
{"labels": [{"name": "bug"}, {"name": "urgent"}]}, # GET issue
]
mock_httpx_client.request.return_value.json = MagicMock(
side_effect=mock_responses
)
result = await gitea_provider.add_labels("owner", "repo", 42, ["bug", "urgent"])
# Should return updated label list
assert isinstance(result, list)
@pytest.mark.asyncio
async def test_remove_label(self, gitea_provider, mock_httpx_client):
"""Test removing a label from a PR."""
# Use side_effect for multiple calls
# 1. GET labels to find the label ID
# 2. DELETE the label from the PR
# 3. GET issue to return remaining labels
mock_responses = [
[{"id": 1, "name": "bug"}], # GET labels
{}, # DELETE label
{"labels": []}, # GET issue
]
mock_httpx_client.request.return_value.json = MagicMock(
side_effect=mock_responses
)
result = await gitea_provider.remove_label("owner", "repo", 42, "bug")
assert isinstance(result, list)
class TestGiteaReviewerOperations:
"""Tests for Gitea reviewer operations."""
@pytest.mark.asyncio
async def test_request_review(self, gitea_provider, mock_httpx_client):
"""Test requesting review from users."""
mock_httpx_client.request.return_value.json = MagicMock(return_value={})
result = await gitea_provider.request_review(
"owner", "repo", 42, ["reviewer1", "reviewer2"]
)
assert result == ["reviewer1", "reviewer2"]
class TestGiteaErrorHandling:
"""Tests for error handling in Gitea provider."""
@pytest.mark.asyncio
async def test_authentication_error(self, gitea_provider, mock_httpx_client):
"""Test handling authentication errors."""
mock_httpx_client.request.return_value.status_code = 401
with pytest.raises(AuthenticationError):
await gitea_provider._request("GET", "/user")
@pytest.mark.asyncio
async def test_permission_denied(self, gitea_provider, mock_httpx_client):
"""Test handling permission denied errors."""
mock_httpx_client.request.return_value.status_code = 403
with pytest.raises(AuthenticationError, match="Insufficient permissions"):
await gitea_provider._request("GET", "/protected")
@pytest.mark.asyncio
async def test_api_error(self, gitea_provider, mock_httpx_client):
"""Test handling general API errors."""
mock_httpx_client.request.return_value.status_code = 500
mock_httpx_client.request.return_value.text = "Internal Server Error"
mock_httpx_client.request.return_value.json = MagicMock(
return_value={"message": "Server error"}
)
with pytest.raises(APIError):
await gitea_provider._request("GET", "/error")
class TestGiteaPRParsing:
"""Tests for PR data parsing."""
def test_parse_pr_open(self, gitea_provider, sample_pr_data):
"""Test parsing open PR."""
pr_info = gitea_provider._parse_pr(sample_pr_data)
assert pr_info.number == 42
assert pr_info.state == PRState.OPEN
assert pr_info.title == "Test PR"
assert pr_info.source_branch == "feature-branch"
assert pr_info.target_branch == "main"
def test_parse_pr_merged(self, gitea_provider, sample_pr_data):
"""Test parsing merged PR."""
sample_pr_data["merged"] = True
sample_pr_data["merged_at"] = "2024-01-16T10:00:00Z"
pr_info = gitea_provider._parse_pr(sample_pr_data)
assert pr_info.state == PRState.MERGED
def test_parse_pr_closed(self, gitea_provider, sample_pr_data):
"""Test parsing closed PR."""
sample_pr_data["state"] = "closed"
sample_pr_data["closed_at"] = "2024-01-16T10:00:00Z"
pr_info = gitea_provider._parse_pr(sample_pr_data)
assert pr_info.state == PRState.CLOSED
def test_parse_datetime_iso(self, gitea_provider):
"""Test parsing ISO datetime strings."""
dt = gitea_provider._parse_datetime("2024-01-15T10:30:00Z")
assert dt.year == 2024
assert dt.month == 1
assert dt.day == 15
def test_parse_datetime_none(self, gitea_provider):
"""Test parsing None datetime returns now."""
dt = gitea_provider._parse_datetime(None)
assert dt is not None
assert dt.tzinfo is not None

View File

@@ -0,0 +1,522 @@
"""
Tests for the MCP server and tools.
"""
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from exceptions import ErrorCode
class TestInputValidation:
"""Tests for input validation functions."""
def test_validate_id_valid(self):
"""Test valid IDs pass validation."""
from server import _validate_id
assert _validate_id("test-123", "project_id") is None
assert _validate_id("my_project", "project_id") is None
assert _validate_id("Agent-001", "agent_id") is None
def test_validate_id_empty(self):
"""Test empty ID fails validation."""
from server import _validate_id
error = _validate_id("", "project_id")
assert error is not None
assert "required" in error.lower()
def test_validate_id_too_long(self):
"""Test too-long ID fails validation."""
from server import _validate_id
error = _validate_id("a" * 200, "project_id")
assert error is not None
assert "1-128" in error
def test_validate_id_invalid_chars(self):
"""Test invalid characters fail validation."""
from server import _validate_id
assert _validate_id("test@invalid", "project_id") is not None
assert _validate_id("test!project", "project_id") is not None
assert _validate_id("test project", "project_id") is not None
def test_validate_branch_valid(self):
"""Test valid branch names."""
from server import _validate_branch
assert _validate_branch("main") is None
assert _validate_branch("feature/new-thing") is None
assert _validate_branch("release-1.0.0") is None
assert _validate_branch("hotfix.urgent") is None
def test_validate_branch_invalid(self):
"""Test invalid branch names."""
from server import _validate_branch
assert _validate_branch("") is not None
assert _validate_branch("a" * 300) is not None
def test_validate_url_valid(self):
"""Test valid repository URLs."""
from server import _validate_url
assert _validate_url("https://github.com/owner/repo.git") is None
assert _validate_url("https://gitea.example.com/owner/repo") is None
assert _validate_url("git@github.com:owner/repo.git") is None
def test_validate_url_invalid(self):
"""Test invalid repository URLs."""
from server import _validate_url
assert _validate_url("") is not None
assert _validate_url("not-a-url") is not None
assert _validate_url("ftp://invalid.com/repo") is not None
class TestHealthCheck:
"""Tests for health check endpoint."""
@pytest.mark.asyncio
async def test_health_check_structure(self):
"""Test health check returns proper structure."""
from server import health_check
with (
patch("server._gitea_provider", None),
patch("server._workspace_manager", None),
):
result = await health_check()
assert "status" in result
assert "service" in result
assert "version" in result
assert "timestamp" in result
assert "dependencies" in result
@pytest.mark.asyncio
async def test_health_check_no_providers(self):
"""Test health check without providers configured."""
from server import health_check
with (
patch("server._gitea_provider", None),
patch("server._workspace_manager", None),
):
result = await health_check()
assert result["dependencies"]["gitea"] == "not configured"
class TestToolRegistry:
"""Tests for tool registration."""
def test_tool_registry_populated(self):
"""Test that tools are registered."""
from server import _tool_registry
assert len(_tool_registry) > 0
assert "clone_repository" in _tool_registry
assert "git_status" in _tool_registry
assert "create_branch" in _tool_registry
assert "commit" in _tool_registry
def test_tool_schema_structure(self):
"""Test tool schemas have proper structure."""
from server import _tool_registry
for name, info in _tool_registry.items():
assert "func" in info
assert "description" in info
assert "schema" in info
assert info["schema"]["type"] == "object"
assert "properties" in info["schema"]
class TestCloneRepository:
"""Tests for clone_repository tool."""
@pytest.mark.asyncio
async def test_clone_invalid_project_id(self):
"""Test clone with invalid project ID."""
from server import clone_repository
# Access the underlying function via .fn
result = await clone_repository.fn(
project_id="invalid@id",
agent_id="agent-1",
repo_url="https://github.com/owner/repo.git",
)
assert result["success"] is False
assert "project_id" in result["error"].lower()
@pytest.mark.asyncio
async def test_clone_invalid_repo_url(self):
"""Test clone with invalid repo URL."""
from server import clone_repository
result = await clone_repository.fn(
project_id="valid-project",
agent_id="agent-1",
repo_url="not-a-valid-url",
)
assert result["success"] is False
assert "url" in result["error"].lower()
class TestGitStatus:
"""Tests for git_status tool."""
@pytest.mark.asyncio
async def test_status_workspace_not_found(self):
"""Test status when workspace doesn't exist."""
from server import git_status
mock_manager = AsyncMock()
mock_manager.get_workspace = AsyncMock(return_value=None)
with patch("server._workspace_manager", mock_manager):
result = await git_status.fn(
project_id="nonexistent",
agent_id="agent-1",
)
assert result["success"] is False
assert result["code"] == ErrorCode.WORKSPACE_NOT_FOUND.value
class TestBranchOperations:
"""Tests for branch operation tools."""
@pytest.mark.asyncio
async def test_create_branch_invalid_name(self):
"""Test creating branch with invalid name."""
from server import create_branch
result = await create_branch.fn(
project_id="test-project",
agent_id="agent-1",
branch_name="", # Invalid
)
assert result["success"] is False
@pytest.mark.asyncio
async def test_list_branches_workspace_not_found(self):
"""Test listing branches when workspace doesn't exist."""
from server import list_branches
mock_manager = AsyncMock()
mock_manager.get_workspace = AsyncMock(return_value=None)
with patch("server._workspace_manager", mock_manager):
result = await list_branches.fn(
project_id="nonexistent",
agent_id="agent-1",
)
assert result["success"] is False
@pytest.mark.asyncio
async def test_checkout_invalid_project(self):
"""Test checkout with invalid project ID."""
from server import checkout
result = await checkout.fn(
project_id="inv@lid",
agent_id="agent-1",
ref="main",
)
assert result["success"] is False
class TestCommitOperations:
"""Tests for commit operation tools."""
@pytest.mark.asyncio
async def test_commit_invalid_project(self):
"""Test commit with invalid project ID."""
from server import commit
result = await commit.fn(
project_id="inv@lid",
agent_id="agent-1",
message="Test commit",
)
assert result["success"] is False
class TestPushPullOperations:
"""Tests for push/pull operation tools."""
@pytest.mark.asyncio
async def test_push_workspace_not_found(self):
"""Test push when workspace doesn't exist."""
from server import push
mock_manager = AsyncMock()
mock_manager.get_workspace = AsyncMock(return_value=None)
with patch("server._workspace_manager", mock_manager):
result = await push.fn(
project_id="nonexistent",
agent_id="agent-1",
)
assert result["success"] is False
@pytest.mark.asyncio
async def test_pull_workspace_not_found(self):
"""Test pull when workspace doesn't exist."""
from server import pull
mock_manager = AsyncMock()
mock_manager.get_workspace = AsyncMock(return_value=None)
with patch("server._workspace_manager", mock_manager):
result = await pull.fn(
project_id="nonexistent",
agent_id="agent-1",
)
assert result["success"] is False
class TestDiffLogOperations:
"""Tests for diff and log operation tools."""
@pytest.mark.asyncio
async def test_diff_workspace_not_found(self):
"""Test diff when workspace doesn't exist."""
from server import diff
mock_manager = AsyncMock()
mock_manager.get_workspace = AsyncMock(return_value=None)
with patch("server._workspace_manager", mock_manager):
result = await diff.fn(
project_id="nonexistent",
agent_id="agent-1",
)
assert result["success"] is False
@pytest.mark.asyncio
async def test_log_workspace_not_found(self):
"""Test log when workspace doesn't exist."""
from server import log
mock_manager = AsyncMock()
mock_manager.get_workspace = AsyncMock(return_value=None)
with patch("server._workspace_manager", mock_manager):
result = await log.fn(
project_id="nonexistent",
agent_id="agent-1",
)
assert result["success"] is False
class TestPROperations:
"""Tests for pull request operation tools."""
@pytest.mark.asyncio
async def test_create_pr_no_repo_url(self):
"""Test create PR when workspace has no repo URL."""
from models import WorkspaceInfo, WorkspaceState
from server import create_pull_request
mock_workspace = WorkspaceInfo(
project_id="test-project",
path="/tmp/test",
state=WorkspaceState.READY,
repo_url=None, # No repo URL
)
mock_manager = AsyncMock()
mock_manager.get_workspace = AsyncMock(return_value=mock_workspace)
with patch("server._workspace_manager", mock_manager):
result = await create_pull_request.fn(
project_id="test-project",
agent_id="agent-1",
title="Test PR",
source_branch="feature",
target_branch="main",
)
assert result["success"] is False
assert "repository URL" in result["error"]
@pytest.mark.asyncio
async def test_list_prs_invalid_state(self):
"""Test list PRs with invalid state filter."""
from models import WorkspaceInfo, WorkspaceState
from server import list_pull_requests
mock_workspace = WorkspaceInfo(
project_id="test-project",
path="/tmp/test",
state=WorkspaceState.READY,
repo_url="https://gitea.test.com/owner/repo.git",
)
mock_manager = AsyncMock()
mock_manager.get_workspace = AsyncMock(return_value=mock_workspace)
mock_provider = AsyncMock()
mock_provider.parse_repo_url = MagicMock(return_value=("owner", "repo"))
with (
patch("server._workspace_manager", mock_manager),
patch("server._get_provider_for_url", return_value=mock_provider),
):
result = await list_pull_requests.fn(
project_id="test-project",
agent_id="agent-1",
state="invalid-state",
)
assert result["success"] is False
assert "Invalid state" in result["error"]
@pytest.mark.asyncio
async def test_merge_pr_invalid_strategy(self):
"""Test merge PR with invalid strategy."""
from models import WorkspaceInfo, WorkspaceState
from server import merge_pull_request
mock_workspace = WorkspaceInfo(
project_id="test-project",
path="/tmp/test",
state=WorkspaceState.READY,
repo_url="https://gitea.test.com/owner/repo.git",
)
mock_manager = AsyncMock()
mock_manager.get_workspace = AsyncMock(return_value=mock_workspace)
mock_provider = AsyncMock()
mock_provider.parse_repo_url = MagicMock(return_value=("owner", "repo"))
with (
patch("server._workspace_manager", mock_manager),
patch("server._get_provider_for_url", return_value=mock_provider),
):
result = await merge_pull_request.fn(
project_id="test-project",
agent_id="agent-1",
pr_number=42,
merge_strategy="invalid-strategy",
)
assert result["success"] is False
assert "Invalid strategy" in result["error"]
class TestWorkspaceOperations:
"""Tests for workspace operation tools."""
@pytest.mark.asyncio
async def test_get_workspace_not_found(self):
"""Test get workspace when it doesn't exist."""
from server import get_workspace
mock_manager = AsyncMock()
mock_manager.get_workspace = AsyncMock(return_value=None)
with patch("server._workspace_manager", mock_manager):
result = await get_workspace.fn(
project_id="nonexistent",
agent_id="agent-1",
)
assert result["success"] is False
@pytest.mark.asyncio
async def test_lock_workspace_success(self):
"""Test successful workspace locking."""
from datetime import UTC, datetime, timedelta
from models import WorkspaceInfo, WorkspaceState
from server import lock_workspace
mock_workspace = WorkspaceInfo(
project_id="test-project",
path="/tmp/test",
state=WorkspaceState.LOCKED,
lock_holder="agent-1",
lock_expires=datetime.now(UTC) + timedelta(seconds=300),
)
mock_manager = AsyncMock()
mock_manager.lock_workspace = AsyncMock(return_value=True)
mock_manager.get_workspace = AsyncMock(return_value=mock_workspace)
with patch("server._workspace_manager", mock_manager):
result = await lock_workspace.fn(
project_id="test-project",
agent_id="agent-1",
timeout=300,
)
assert result["success"] is True
assert result["lock_holder"] == "agent-1"
@pytest.mark.asyncio
async def test_unlock_workspace_success(self):
"""Test successful workspace unlocking."""
from server import unlock_workspace
mock_manager = AsyncMock()
mock_manager.unlock_workspace = AsyncMock(return_value=True)
with patch("server._workspace_manager", mock_manager):
result = await unlock_workspace.fn(
project_id="test-project",
agent_id="agent-1",
)
assert result["success"] is True
class TestJSONRPCEndpoint:
"""Tests for the JSON-RPC endpoint."""
def test_python_type_to_json_schema_str(self):
"""Test string type conversion."""
from server import _python_type_to_json_schema
result = _python_type_to_json_schema(str)
assert result["type"] == "string"
def test_python_type_to_json_schema_int(self):
"""Test int type conversion."""
from server import _python_type_to_json_schema
result = _python_type_to_json_schema(int)
assert result["type"] == "integer"
def test_python_type_to_json_schema_bool(self):
"""Test bool type conversion."""
from server import _python_type_to_json_schema
result = _python_type_to_json_schema(bool)
assert result["type"] == "boolean"
def test_python_type_to_json_schema_list(self):
"""Test list type conversion."""
from server import _python_type_to_json_schema
result = _python_type_to_json_schema(list[str])
assert result["type"] == "array"
assert result["items"]["type"] == "string"

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,358 @@
"""
Tests for the workspace management module.
"""
from datetime import UTC, datetime, timedelta
from pathlib import Path
import pytest
from exceptions import WorkspaceLockedError, WorkspaceNotFoundError
from models import WorkspaceState
from workspace import FileLockManager, WorkspaceLock
class TestWorkspaceManager:
"""Tests for WorkspaceManager."""
@pytest.mark.asyncio
async def test_create_workspace(self, workspace_manager, valid_project_id):
"""Test creating a new workspace."""
workspace = await workspace_manager.create_workspace(valid_project_id)
assert workspace.project_id == valid_project_id
assert workspace.state == WorkspaceState.INITIALIZING
assert Path(workspace.path).exists()
@pytest.mark.asyncio
async def test_create_workspace_with_repo_url(
self, workspace_manager, valid_project_id, sample_repo_url
):
"""Test creating workspace with repository URL."""
workspace = await workspace_manager.create_workspace(
valid_project_id, repo_url=sample_repo_url
)
assert workspace.repo_url == sample_repo_url
@pytest.mark.asyncio
async def test_get_workspace(self, workspace_manager, valid_project_id):
"""Test getting an existing workspace."""
# Create first
await workspace_manager.create_workspace(valid_project_id)
# Get it
workspace = await workspace_manager.get_workspace(valid_project_id)
assert workspace is not None
assert workspace.project_id == valid_project_id
@pytest.mark.asyncio
async def test_get_workspace_not_found(self, workspace_manager):
"""Test getting non-existent workspace."""
workspace = await workspace_manager.get_workspace("nonexistent")
assert workspace is None
@pytest.mark.asyncio
async def test_delete_workspace(self, workspace_manager, valid_project_id):
"""Test deleting a workspace."""
# Create first
workspace = await workspace_manager.create_workspace(valid_project_id)
workspace_path = Path(workspace.path)
assert workspace_path.exists()
# Delete
result = await workspace_manager.delete_workspace(valid_project_id)
assert result is True
assert not workspace_path.exists()
@pytest.mark.asyncio
async def test_delete_nonexistent_workspace(self, workspace_manager):
"""Test deleting non-existent workspace returns True."""
result = await workspace_manager.delete_workspace("nonexistent")
assert result is True
@pytest.mark.asyncio
async def test_list_workspaces(self, workspace_manager):
"""Test listing workspaces."""
# Create multiple workspaces
await workspace_manager.create_workspace("project-1")
await workspace_manager.create_workspace("project-2")
await workspace_manager.create_workspace("project-3")
workspaces = await workspace_manager.list_workspaces()
assert len(workspaces) >= 3
project_ids = [w.project_id for w in workspaces]
assert "project-1" in project_ids
assert "project-2" in project_ids
assert "project-3" in project_ids
class TestWorkspaceLocking:
"""Tests for workspace locking."""
@pytest.mark.asyncio
async def test_lock_workspace(
self, workspace_manager, valid_project_id, valid_agent_id
):
"""Test locking a workspace."""
await workspace_manager.create_workspace(valid_project_id)
result = await workspace_manager.lock_workspace(
valid_project_id, valid_agent_id, timeout=60
)
assert result is True
workspace = await workspace_manager.get_workspace(valid_project_id)
assert workspace.state == WorkspaceState.LOCKED
assert workspace.lock_holder == valid_agent_id
@pytest.mark.asyncio
async def test_lock_already_locked(self, workspace_manager, valid_project_id):
"""Test locking already-locked workspace by different holder."""
await workspace_manager.create_workspace(valid_project_id)
await workspace_manager.lock_workspace(valid_project_id, "agent-1", timeout=60)
with pytest.raises(WorkspaceLockedError):
await workspace_manager.lock_workspace(
valid_project_id, "agent-2", timeout=60
)
@pytest.mark.asyncio
async def test_lock_same_holder(
self, workspace_manager, valid_project_id, valid_agent_id
):
"""Test re-locking by same holder extends lock."""
await workspace_manager.create_workspace(valid_project_id)
await workspace_manager.lock_workspace(
valid_project_id, valid_agent_id, timeout=60
)
# Same holder can re-lock
result = await workspace_manager.lock_workspace(
valid_project_id, valid_agent_id, timeout=120
)
assert result is True
@pytest.mark.asyncio
async def test_unlock_workspace(
self, workspace_manager, valid_project_id, valid_agent_id
):
"""Test unlocking a workspace."""
await workspace_manager.create_workspace(valid_project_id)
await workspace_manager.lock_workspace(valid_project_id, valid_agent_id)
result = await workspace_manager.unlock_workspace(
valid_project_id, valid_agent_id
)
assert result is True
workspace = await workspace_manager.get_workspace(valid_project_id)
assert workspace.state == WorkspaceState.READY
assert workspace.lock_holder is None
@pytest.mark.asyncio
async def test_unlock_wrong_holder(self, workspace_manager, valid_project_id):
"""Test unlock fails with wrong holder."""
await workspace_manager.create_workspace(valid_project_id)
await workspace_manager.lock_workspace(valid_project_id, "agent-1")
with pytest.raises(WorkspaceLockedError):
await workspace_manager.unlock_workspace(valid_project_id, "agent-2")
@pytest.mark.asyncio
async def test_force_unlock(self, workspace_manager, valid_project_id):
"""Test force unlock works regardless of holder."""
await workspace_manager.create_workspace(valid_project_id)
await workspace_manager.lock_workspace(valid_project_id, "agent-1")
result = await workspace_manager.unlock_workspace(
valid_project_id, "admin", force=True
)
assert result is True
@pytest.mark.asyncio
async def test_lock_nonexistent_workspace(self, workspace_manager, valid_agent_id):
"""Test locking non-existent workspace raises error."""
with pytest.raises(WorkspaceNotFoundError):
await workspace_manager.lock_workspace("nonexistent", valid_agent_id)
class TestWorkspaceLockContextManager:
"""Tests for WorkspaceLock context manager."""
@pytest.mark.asyncio
async def test_lock_context_manager(
self, workspace_manager, valid_project_id, valid_agent_id
):
"""Test using WorkspaceLock as context manager."""
await workspace_manager.create_workspace(valid_project_id)
async with WorkspaceLock(
workspace_manager, valid_project_id, valid_agent_id
) as lock:
workspace = await workspace_manager.get_workspace(valid_project_id)
assert workspace.state == WorkspaceState.LOCKED
# After exiting context, should be unlocked
workspace = await workspace_manager.get_workspace(valid_project_id)
assert workspace.lock_holder is None
@pytest.mark.asyncio
async def test_lock_context_manager_error(
self, workspace_manager, valid_project_id, valid_agent_id
):
"""Test WorkspaceLock releases on exception."""
await workspace_manager.create_workspace(valid_project_id)
try:
async with WorkspaceLock(
workspace_manager, valid_project_id, valid_agent_id
):
raise ValueError("Test error")
except ValueError:
pass
workspace = await workspace_manager.get_workspace(valid_project_id)
assert workspace.lock_holder is None
class TestWorkspaceMetadata:
"""Tests for workspace metadata operations."""
@pytest.mark.asyncio
async def test_touch_workspace(self, workspace_manager, valid_project_id):
"""Test updating workspace access time."""
workspace = await workspace_manager.create_workspace(valid_project_id)
original_time = workspace.last_accessed
await workspace_manager.touch_workspace(valid_project_id)
updated = await workspace_manager.get_workspace(valid_project_id)
assert updated.last_accessed >= original_time
@pytest.mark.asyncio
async def test_update_workspace_branch(self, workspace_manager, valid_project_id):
"""Test updating workspace branch."""
await workspace_manager.create_workspace(valid_project_id)
await workspace_manager.update_workspace_branch(
valid_project_id, "feature-branch"
)
workspace = await workspace_manager.get_workspace(valid_project_id)
assert workspace.current_branch == "feature-branch"
class TestWorkspaceSize:
"""Tests for workspace size management."""
@pytest.mark.asyncio
async def test_check_size_within_limit(self, workspace_manager, valid_project_id):
"""Test size check passes for small workspace."""
await workspace_manager.create_workspace(valid_project_id)
# Should not raise
result = await workspace_manager.check_size_limit(valid_project_id)
assert result is True
@pytest.mark.asyncio
async def test_get_total_size(self, workspace_manager, valid_project_id):
"""Test getting total workspace size."""
workspace = await workspace_manager.create_workspace(valid_project_id)
# Add some content
content_file = Path(workspace.path) / "content.txt"
content_file.write_text("x" * 1000)
total_size = await workspace_manager.get_total_size()
assert total_size >= 1000
class TestFileLockManager:
"""Tests for file-based locking."""
def test_acquire_lock(self, temp_dir):
"""Test acquiring a file lock."""
manager = FileLockManager(temp_dir / "locks")
result = manager.acquire("test-key")
assert result is True
# Cleanup
manager.release("test-key")
def test_release_lock(self, temp_dir):
"""Test releasing a file lock."""
manager = FileLockManager(temp_dir / "locks")
manager.acquire("test-key")
result = manager.release("test-key")
assert result is True
def test_is_locked(self, temp_dir):
"""Test checking if locked."""
manager = FileLockManager(temp_dir / "locks")
assert manager.is_locked("test-key") is False
manager.acquire("test-key")
assert manager.is_locked("test-key") is True
manager.release("test-key")
def test_release_nonexistent_lock(self, temp_dir):
"""Test releasing a lock that doesn't exist."""
manager = FileLockManager(temp_dir / "locks")
# Should not raise
result = manager.release("nonexistent")
assert result is False
class TestWorkspaceCleanup:
"""Tests for workspace cleanup operations."""
@pytest.mark.asyncio
async def test_cleanup_stale_workspaces(self, workspace_manager, test_settings):
"""Test cleaning up stale workspaces."""
# Create workspace
workspace = await workspace_manager.create_workspace("stale-project")
# Manually set it as stale by updating metadata
await workspace_manager._update_metadata(
"stale-project",
last_accessed=(datetime.now(UTC) - timedelta(days=30)).isoformat(),
)
# Run cleanup
cleaned = await workspace_manager.cleanup_stale_workspaces()
assert cleaned >= 1
@pytest.mark.asyncio
async def test_delete_locked_workspace_blocked(
self, workspace_manager, valid_project_id, valid_agent_id
):
"""Test deleting locked workspace is blocked without force."""
await workspace_manager.create_workspace(valid_project_id)
await workspace_manager.lock_workspace(valid_project_id, valid_agent_id)
with pytest.raises(WorkspaceLockedError):
await workspace_manager.delete_workspace(valid_project_id)
@pytest.mark.asyncio
async def test_delete_locked_workspace_force(
self, workspace_manager, valid_project_id, valid_agent_id
):
"""Test force deleting locked workspace."""
await workspace_manager.create_workspace(valid_project_id)
await workspace_manager.lock_workspace(valid_project_id, valid_agent_id)
result = await workspace_manager.delete_workspace(valid_project_id, force=True)
assert result is True

1853
mcp-servers/git-ops/uv.lock generated Normal file

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,614 @@
"""
Workspace management for Git Operations MCP Server.
Handles isolated workspaces for each project, including creation,
locking, cleanup, and size management.
"""
import asyncio
import json
import logging
import shutil
from datetime import UTC, datetime, timedelta
from pathlib import Path
from typing import Any
import aiofiles # type: ignore[import-untyped]
from filelock import FileLock, Timeout
from config import Settings, get_settings
from exceptions import (
WorkspaceLockedError,
WorkspaceNotFoundError,
WorkspaceSizeExceededError,
)
from models import WorkspaceInfo, WorkspaceState
logger = logging.getLogger(__name__)
# Metadata file name
WORKSPACE_METADATA_FILE = ".syndarix-workspace.json"
class WorkspaceManager:
"""
Manages git workspaces for projects.
Each project gets an isolated workspace directory for git operations.
Supports distributed locking via Redis or local file locks.
"""
def __init__(self, settings: Settings | None = None) -> None:
"""
Initialize WorkspaceManager.
Args:
settings: Optional settings override
"""
self.settings = settings or get_settings()
self.base_path = self.settings.workspace_base_path
self._ensure_base_path()
def _ensure_base_path(self) -> None:
"""Ensure the base workspace directory exists."""
self.base_path.mkdir(parents=True, exist_ok=True)
def _get_workspace_path(self, project_id: str) -> Path:
"""Get the path for a project workspace with path traversal protection."""
# Sanitize project ID for filesystem
safe_id = "".join(c if c.isalnum() or c in "-_" else "_" for c in project_id)
# Reject reserved names
reserved_names = {".", "..", "con", "prn", "aux", "nul"}
if safe_id.lower() in reserved_names:
raise ValueError(f"Invalid project ID: reserved name '{project_id}'")
# Construct path and verify it's within base_path (prevent path traversal)
workspace_path = (self.base_path / safe_id).resolve()
base_resolved = self.base_path.resolve()
if not workspace_path.is_relative_to(base_resolved):
raise ValueError(
f"Invalid project ID: path traversal detected '{project_id}'"
)
return workspace_path
def _get_lock_path(self, project_id: str) -> Path:
"""Get the lock file path for a workspace."""
return self._get_workspace_path(project_id) / ".lock"
def _get_metadata_path(self, project_id: str) -> Path:
"""Get the metadata file path for a workspace."""
return self._get_workspace_path(project_id) / WORKSPACE_METADATA_FILE
async def get_workspace(self, project_id: str) -> WorkspaceInfo | None:
"""
Get workspace info for a project.
Args:
project_id: Project identifier
Returns:
WorkspaceInfo or None if not found
"""
workspace_path = self._get_workspace_path(project_id)
if not workspace_path.exists():
return None
# Load metadata
metadata = await self._load_metadata(project_id)
# Calculate size
size_bytes = await self._calculate_size(workspace_path)
# Check lock status
lock_holder = None
lock_expires = None
if metadata:
lock_holder = metadata.get("lock_holder")
if metadata.get("lock_expires"):
lock_expires = datetime.fromisoformat(metadata["lock_expires"])
# Clear expired locks
if lock_expires < datetime.now(UTC):
lock_holder = None
lock_expires = None
# Determine state
state = WorkspaceState.READY
if lock_holder:
state = WorkspaceState.LOCKED
# Check if stale
last_accessed = datetime.now(UTC)
if metadata and metadata.get("last_accessed"):
last_accessed = datetime.fromisoformat(metadata["last_accessed"])
stale_threshold = datetime.now(UTC) - timedelta(
days=self.settings.workspace_stale_days
)
if last_accessed < stale_threshold:
state = WorkspaceState.STALE
return WorkspaceInfo(
project_id=project_id,
path=str(workspace_path),
state=state,
repo_url=metadata.get("repo_url") if metadata else None,
current_branch=metadata.get("current_branch") if metadata else None,
last_accessed=last_accessed,
size_bytes=size_bytes,
lock_holder=lock_holder,
lock_expires=lock_expires,
)
async def create_workspace(
self,
project_id: str,
repo_url: str | None = None,
) -> WorkspaceInfo:
"""
Create or get a workspace for a project.
Args:
project_id: Project identifier
repo_url: Optional repository URL
Returns:
WorkspaceInfo for the workspace
"""
workspace_path = self._get_workspace_path(project_id)
if workspace_path.exists():
# Workspace already exists, update metadata
await self._update_metadata(project_id, repo_url=repo_url)
workspace = await self.get_workspace(project_id)
if workspace:
return workspace
# Create workspace directory
workspace_path.mkdir(parents=True, exist_ok=True)
# Create initial metadata
metadata = {
"project_id": project_id,
"repo_url": repo_url,
"created_at": datetime.now(UTC).isoformat(),
"last_accessed": datetime.now(UTC).isoformat(),
}
await self._save_metadata(project_id, metadata)
return WorkspaceInfo(
project_id=project_id,
path=str(workspace_path),
state=WorkspaceState.INITIALIZING,
repo_url=repo_url,
last_accessed=datetime.now(UTC),
size_bytes=0,
)
async def delete_workspace(self, project_id: str, force: bool = False) -> bool:
"""
Delete a workspace.
Args:
project_id: Project identifier
force: Force delete even if locked
Returns:
True if deleted
"""
workspace_path = self._get_workspace_path(project_id)
if not workspace_path.exists():
return True
# Check lock
if not force:
workspace = await self.get_workspace(project_id)
if workspace and workspace.state == WorkspaceState.LOCKED:
raise WorkspaceLockedError(project_id, workspace.lock_holder)
try:
# Use shutil.rmtree for robust deletion
shutil.rmtree(workspace_path)
logger.info(f"Deleted workspace for project: {project_id}")
return True
except Exception as e:
logger.error(f"Failed to delete workspace {project_id}: {e}")
return False
async def lock_workspace(
self,
project_id: str,
holder: str,
timeout: int | None = None,
) -> bool:
"""
Acquire a lock on a workspace.
Args:
project_id: Project identifier
holder: Lock holder identifier (agent_id)
timeout: Lock timeout in seconds
Returns:
True if lock acquired
Raises:
WorkspaceNotFoundError: If workspace doesn't exist
WorkspaceLockedError: If already locked by another
"""
workspace = await self.get_workspace(project_id)
if workspace is None:
raise WorkspaceNotFoundError(project_id)
# Check if already locked by someone else
if workspace.state == WorkspaceState.LOCKED and workspace.lock_holder != holder:
# Check if lock expired
if workspace.lock_expires and workspace.lock_expires > datetime.now(UTC):
raise WorkspaceLockedError(project_id, workspace.lock_holder)
# Calculate lock expiry
lock_timeout = timeout or self.settings.workspace_lock_timeout
lock_expires = datetime.now(UTC) + timedelta(seconds=lock_timeout)
# Update metadata with lock info
await self._update_metadata(
project_id,
lock_holder=holder,
lock_expires=lock_expires.isoformat(),
)
logger.info(f"Workspace {project_id} locked by {holder}")
return True
async def unlock_workspace(
self,
project_id: str,
holder: str,
force: bool = False,
) -> bool:
"""
Release a lock on a workspace.
Args:
project_id: Project identifier
holder: Lock holder identifier
force: Force unlock regardless of holder
Returns:
True if unlocked
"""
workspace = await self.get_workspace(project_id)
if workspace is None:
raise WorkspaceNotFoundError(project_id)
# Verify holder
if not force and workspace.lock_holder and workspace.lock_holder != holder:
raise WorkspaceLockedError(project_id, workspace.lock_holder)
# Clear lock
await self._update_metadata(
project_id,
lock_holder=None,
lock_expires=None,
)
logger.info(f"Workspace {project_id} unlocked by {holder}")
return True
async def touch_workspace(self, project_id: str) -> None:
"""
Update last accessed time for a workspace.
Args:
project_id: Project identifier
"""
await self._update_metadata(
project_id,
last_accessed=datetime.now(UTC).isoformat(),
)
async def update_workspace_branch(
self,
project_id: str,
branch: str,
) -> None:
"""
Update the current branch in workspace metadata.
Args:
project_id: Project identifier
branch: Current branch name
"""
await self._update_metadata(
project_id,
current_branch=branch,
last_accessed=datetime.now(UTC).isoformat(),
)
async def check_size_limit(self, project_id: str) -> bool:
"""
Check if workspace exceeds size limit.
Args:
project_id: Project identifier
Returns:
True if within limits
Raises:
WorkspaceSizeExceededError: If size exceeds limit
"""
workspace_path = self._get_workspace_path(project_id)
if not workspace_path.exists():
return True
size_bytes = await self._calculate_size(workspace_path)
size_gb = size_bytes / (1024**3)
max_size_gb = self.settings.workspace_max_size_gb
if size_gb > max_size_gb:
raise WorkspaceSizeExceededError(project_id, size_gb, max_size_gb)
return True
async def list_workspaces(
self,
include_stale: bool = False,
) -> list[WorkspaceInfo]:
"""
List all workspaces.
Args:
include_stale: Include stale workspaces
Returns:
List of WorkspaceInfo
"""
workspaces: list[WorkspaceInfo] = []
if not self.base_path.exists():
return workspaces
for entry in self.base_path.iterdir():
if entry.is_dir() and not entry.name.startswith("."):
# Extract project_id from directory name
workspace = await self.get_workspace(entry.name)
if workspace:
if not include_stale and workspace.state == WorkspaceState.STALE:
continue
workspaces.append(workspace)
return workspaces
async def cleanup_stale_workspaces(self) -> int:
"""
Clean up stale workspaces.
Returns:
Number of workspaces cleaned up
"""
cleaned = 0
workspaces = await self.list_workspaces(include_stale=True)
for workspace in workspaces:
if workspace.state == WorkspaceState.STALE:
try:
await self.delete_workspace(workspace.project_id, force=True)
cleaned += 1
except Exception as e:
logger.error(
f"Failed to cleanup stale workspace {workspace.project_id}: {e}"
)
if cleaned > 0:
logger.info(f"Cleaned up {cleaned} stale workspaces")
return cleaned
async def get_total_size(self) -> int:
"""
Get total size of all workspaces.
Returns:
Total size in bytes
"""
return await self._calculate_size(self.base_path)
# Private methods
async def _load_metadata(self, project_id: str) -> dict[str, Any] | None:
"""Load workspace metadata from file."""
metadata_path = self._get_metadata_path(project_id)
if not metadata_path.exists():
return None
try:
async with aiofiles.open(metadata_path) as f:
content = await f.read()
return json.loads(content)
except Exception as e:
logger.warning(f"Failed to load metadata for {project_id}: {e}")
return None
async def _save_metadata(
self,
project_id: str,
metadata: dict[str, Any],
) -> None:
"""Save workspace metadata to file."""
metadata_path = self._get_metadata_path(project_id)
# Ensure parent directory exists
metadata_path.parent.mkdir(parents=True, exist_ok=True)
try:
async with aiofiles.open(metadata_path, "w") as f:
await f.write(json.dumps(metadata, indent=2))
except Exception as e:
logger.error(f"Failed to save metadata for {project_id}: {e}")
async def _update_metadata(
self,
project_id: str,
**updates: Any,
) -> None:
"""Update specific fields in workspace metadata."""
metadata = await self._load_metadata(project_id) or {}
# Handle None values (to clear fields)
for key, value in updates.items():
if value is None:
metadata.pop(key, None)
else:
metadata[key] = value
await self._save_metadata(project_id, metadata)
async def _calculate_size(self, path: Path) -> int:
"""Calculate total size of a directory."""
def _calc_size() -> int:
total = 0
try:
for entry in path.rglob("*"):
if entry.is_file():
try:
total += entry.stat().st_size
except OSError:
pass
except Exception:
pass
return total
# Run in executor for async compatibility
loop = asyncio.get_event_loop()
return await loop.run_in_executor(None, _calc_size)
class WorkspaceLock:
"""
Context manager for workspace locking.
Provides automatic locking/unlocking with proper cleanup.
"""
def __init__(
self,
manager: WorkspaceManager,
project_id: str,
holder: str,
timeout: int | None = None,
) -> None:
"""
Initialize workspace lock.
Args:
manager: WorkspaceManager instance
project_id: Project identifier
holder: Lock holder identifier
timeout: Lock timeout in seconds
"""
self.manager = manager
self.project_id = project_id
self.holder = holder
self.timeout = timeout
self._acquired = False
async def __aenter__(self) -> "WorkspaceLock":
"""Acquire lock on enter."""
await self.manager.lock_workspace(
self.project_id,
self.holder,
self.timeout,
)
self._acquired = True
return self
async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
"""Release lock on exit."""
if self._acquired:
try:
await self.manager.unlock_workspace(
self.project_id,
self.holder,
)
except Exception as e:
logger.warning(f"Failed to release lock for {self.project_id}: {e}")
class FileLockManager:
"""
File-based locking for single-instance deployments.
Uses filelock for local locking when Redis is not available.
"""
def __init__(self, lock_dir: Path) -> None:
"""
Initialize file lock manager.
Args:
lock_dir: Directory for lock files
"""
self.lock_dir = lock_dir
self.lock_dir.mkdir(parents=True, exist_ok=True)
self._locks: dict[str, FileLock] = {}
def _get_lock(self, key: str) -> FileLock:
"""Get or create a file lock for a key."""
if key not in self._locks:
lock_path = self.lock_dir / f"{key}.lock"
self._locks[key] = FileLock(lock_path)
return self._locks[key]
def acquire(
self,
key: str,
timeout: float = 10.0,
) -> bool:
"""
Acquire a lock.
Args:
key: Lock key
timeout: Timeout in seconds
Returns:
True if acquired
"""
lock = self._get_lock(key)
try:
lock.acquire(timeout=timeout)
return True
except Timeout:
return False
def release(self, key: str) -> bool:
"""
Release a lock.
Args:
key: Lock key
Returns:
True if released
"""
if key in self._locks:
try:
self._locks[key].release()
return True
except Exception:
pass
return False
def is_locked(self, key: str) -> bool:
"""Check if a key is locked."""
lock = self._get_lock(key)
return lock.is_locked

View File

@@ -1,4 +1,8 @@
.PHONY: help install install-dev lint lint-fix format type-check test test-cov validate clean run
.PHONY: help install install-dev lint lint-fix format format-check type-check test test-cov validate clean run
# Ensure commands in this project don't inherit an external Python virtualenv
# (prevents uv warnings about mismatched VIRTUAL_ENV when running from repo root)
unexport VIRTUAL_ENV
# Default target
help:
@@ -12,6 +16,7 @@ help:
@echo " make lint - Run Ruff linter"
@echo " make lint-fix - Run Ruff linter with auto-fix"
@echo " make format - Format code with Ruff"
@echo " make format-check - Check if code is formatted"
@echo " make type-check - Run mypy type checker"
@echo ""
@echo "Testing:"
@@ -19,7 +24,7 @@ help:
@echo " make test-cov - Run pytest with coverage"
@echo ""
@echo "All-in-one:"
@echo " make validate - Run lint, type-check, and tests"
@echo " make validate - Run all checks (lint + format + types)"
@echo ""
@echo "Running:"
@echo " make run - Run the server locally"
@@ -49,6 +54,10 @@ format:
@echo "Formatting code..."
@uv run ruff format .
format-check:
@echo "Checking code formatting..."
@uv run ruff format --check .
type-check:
@echo "Running mypy..."
@uv run mypy . --ignore-missing-imports
@@ -62,8 +71,9 @@ test-cov:
@echo "Running tests with coverage..."
@uv run pytest tests/ -v --cov=. --cov-report=term-missing --cov-report=html
# All-in-one validation
validate: lint type-check test
validate: lint format-check type-check
@echo "All validations passed!"
# Running

View File

@@ -184,7 +184,12 @@ class ChunkerFactory:
if file_type:
if file_type == FileType.MARKDOWN:
return self._get_markdown_chunker()
elif file_type in (FileType.TEXT, FileType.JSON, FileType.YAML, FileType.TOML):
elif file_type in (
FileType.TEXT,
FileType.JSON,
FileType.YAML,
FileType.TOML,
):
return self._get_text_chunker()
else:
# Code files
@@ -193,7 +198,9 @@ class ChunkerFactory:
# Default to text chunker
return self._get_text_chunker()
def get_chunker_for_path(self, source_path: str) -> tuple[BaseChunker, FileType | None]:
def get_chunker_for_path(
self, source_path: str
) -> tuple[BaseChunker, FileType | None]:
"""
Get chunker based on file path extension.

View File

@@ -151,7 +151,7 @@ class CodeChunker(BaseChunker):
for struct_type, pattern in patterns.items():
for match in pattern.finditer(content):
# Convert character position to line number
line_num = content[:match.start()].count("\n")
line_num = content[: match.start()].count("\n")
boundaries.append((line_num, struct_type))
if not boundaries:

View File

@@ -69,9 +69,7 @@ class MarkdownChunker(BaseChunker):
if not sections:
# No headings, chunk as plain text
return self._chunk_text_block(
content, source_path, file_type, metadata, []
)
return self._chunk_text_block(content, source_path, file_type, metadata, [])
chunks: list[Chunk] = []
heading_stack: list[tuple[int, str]] = [] # (level, text)
@@ -292,7 +290,10 @@ class MarkdownChunker(BaseChunker):
)
# Overlap: include last paragraph if it fits
if current_content and self.count_tokens(current_content[-1]) <= self.chunk_overlap:
if (
current_content
and self.count_tokens(current_content[-1]) <= self.chunk_overlap
):
current_content = [current_content[-1]]
current_tokens = self.count_tokens(current_content[-1])
else:
@@ -341,12 +342,14 @@ class MarkdownChunker(BaseChunker):
# Start of code block - save previous paragraph
if current_para and any(p.strip() for p in current_para):
para_content = "\n".join(current_para)
paragraphs.append({
"content": para_content,
"tokens": self.count_tokens(para_content),
"start_line": para_start,
"end_line": i - 1,
})
paragraphs.append(
{
"content": para_content,
"tokens": self.count_tokens(para_content),
"start_line": para_start,
"end_line": i - 1,
}
)
current_para = [line]
para_start = i
in_code_block = True
@@ -360,12 +363,14 @@ class MarkdownChunker(BaseChunker):
if not line.strip():
if current_para and any(p.strip() for p in current_para):
para_content = "\n".join(current_para)
paragraphs.append({
"content": para_content,
"tokens": self.count_tokens(para_content),
"start_line": para_start,
"end_line": i - 1,
})
paragraphs.append(
{
"content": para_content,
"tokens": self.count_tokens(para_content),
"start_line": para_start,
"end_line": i - 1,
}
)
current_para = []
para_start = i + 1
else:
@@ -376,12 +381,14 @@ class MarkdownChunker(BaseChunker):
# Final paragraph
if current_para and any(p.strip() for p in current_para):
para_content = "\n".join(current_para)
paragraphs.append({
"content": para_content,
"tokens": self.count_tokens(para_content),
"start_line": para_start,
"end_line": len(lines) - 1,
})
paragraphs.append(
{
"content": para_content,
"tokens": self.count_tokens(para_content),
"start_line": para_start,
"end_line": len(lines) - 1,
}
)
return paragraphs
@@ -448,7 +455,10 @@ class MarkdownChunker(BaseChunker):
)
# Overlap with last sentence
if current_content and self.count_tokens(current_content[-1]) <= self.chunk_overlap:
if (
current_content
and self.count_tokens(current_content[-1]) <= self.chunk_overlap
):
current_content = [current_content[-1]]
current_tokens = self.count_tokens(current_content[-1])
else:

View File

@@ -79,9 +79,7 @@ class TextChunker(BaseChunker):
)
# Fall back to sentence-based chunking
return self._chunk_by_sentences(
content, source_path, file_type, metadata
)
return self._chunk_by_sentences(content, source_path, file_type, metadata)
def _split_paragraphs(self, content: str) -> list[dict[str, Any]]:
"""Split content into paragraphs."""
@@ -97,12 +95,14 @@ class TextChunker(BaseChunker):
continue
para_lines = para.count("\n") + 1
paragraphs.append({
"content": para,
"tokens": self.count_tokens(para),
"start_line": line_num,
"end_line": line_num + para_lines - 1,
})
paragraphs.append(
{
"content": para,
"tokens": self.count_tokens(para),
"start_line": line_num,
"end_line": line_num + para_lines - 1,
}
)
line_num += para_lines + 1 # +1 for blank line between paragraphs
return paragraphs
@@ -172,7 +172,10 @@ class TextChunker(BaseChunker):
# Overlap: keep last paragraph if small enough
overlap_para = None
if current_paras and self.count_tokens(current_paras[-1]) <= self.chunk_overlap:
if (
current_paras
and self.count_tokens(current_paras[-1]) <= self.chunk_overlap
):
overlap_para = current_paras[-1]
current_paras = [overlap_para] if overlap_para else []
@@ -266,7 +269,10 @@ class TextChunker(BaseChunker):
# Overlap: keep last sentence if small enough
overlap = None
if current_sentences and self.count_tokens(current_sentences[-1]) <= self.chunk_overlap:
if (
current_sentences
and self.count_tokens(current_sentences[-1]) <= self.chunk_overlap
):
overlap = current_sentences[-1]
current_sentences = [overlap] if overlap else []
@@ -317,14 +323,10 @@ class TextChunker(BaseChunker):
sentences = self._split_sentences(text)
if len(sentences) > 1:
return self._chunk_by_sentences(
text, source_path, file_type, metadata
)
return self._chunk_by_sentences(text, source_path, file_type, metadata)
# Fall back to word-based splitting
return self._chunk_by_words(
text, source_path, file_type, metadata, base_line
)
return self._chunk_by_words(text, source_path, file_type, metadata, base_line)
def _chunk_by_words(
self,

View File

@@ -328,14 +328,18 @@ class CollectionManager:
"source_path": chunk.source_path or source_path,
"start_line": chunk.start_line,
"end_line": chunk.end_line,
"file_type": effective_file_type.value if (effective_file_type := chunk.file_type or file_type) else None,
"file_type": effective_file_type.value
if (effective_file_type := chunk.file_type or file_type)
else None,
}
embeddings_data.append((
chunk.content,
embedding,
chunk.chunk_type,
chunk_metadata,
))
embeddings_data.append(
(
chunk.content,
embedding,
chunk.chunk_type,
chunk_metadata,
)
)
# Atomically replace old embeddings with new ones
_, chunk_ids = await self.database.replace_source_embeddings(

View File

@@ -57,6 +57,9 @@ class DatabaseManager:
async def initialize(self) -> None:
"""Initialize connection pool and create schema."""
try:
# First, create pgvector extension (required before register_vector in pool init)
await self._ensure_pgvector_extension()
self._pool = await asyncpg.create_pool(
self._settings.database_url,
min_size=2,
@@ -66,7 +69,7 @@ class DatabaseManager:
)
logger.info("Database pool created successfully")
# Create schema
# Create schema (tables and indexes)
await self._create_schema()
logger.info("Database schema initialized")
@@ -77,6 +80,19 @@ class DatabaseManager:
cause=e,
)
async def _ensure_pgvector_extension(self) -> None:
"""Ensure pgvector extension exists before pool creation.
This must run before creating the connection pool because
register_vector() in _init_connection requires the extension to exist.
"""
conn = await asyncpg.connect(self._settings.database_url)
try:
await conn.execute("CREATE EXTENSION IF NOT EXISTS vector")
logger.info("pgvector extension ensured")
finally:
await conn.close()
async def _init_connection(self, conn: asyncpg.Connection) -> None: # type: ignore[type-arg]
"""Initialize a connection with pgvector support."""
await register_vector(conn)
@@ -84,8 +100,7 @@ class DatabaseManager:
async def _create_schema(self) -> None:
"""Create database schema if not exists."""
async with self.pool.acquire() as conn:
# Enable pgvector extension
await conn.execute("CREATE EXTENSION IF NOT EXISTS vector")
# Note: pgvector extension is created in _ensure_pgvector_extension()
# Create main embeddings table
await conn.execute("""
@@ -286,7 +301,14 @@ class DatabaseManager:
try:
async with self.acquire() as conn, conn.transaction():
# Wrap in transaction for all-or-nothing batch semantics
for project_id, collection, content, embedding, chunk_type, metadata in embeddings:
for (
project_id,
collection,
content,
embedding,
chunk_type,
metadata,
) in embeddings:
content_hash = self.compute_content_hash(content)
source_path = metadata.get("source_path")
start_line = metadata.get("start_line")
@@ -397,7 +419,9 @@ class DatabaseManager:
source_path=row["source_path"],
start_line=row["start_line"],
end_line=row["end_line"],
file_type=FileType(row["file_type"]) if row["file_type"] else None,
file_type=FileType(row["file_type"])
if row["file_type"]
else None,
metadata=row["metadata"] or {},
content_hash=row["content_hash"],
created_at=row["created_at"],
@@ -476,7 +500,9 @@ class DatabaseManager:
source_path=row["source_path"],
start_line=row["start_line"],
end_line=row["end_line"],
file_type=FileType(row["file_type"]) if row["file_type"] else None,
file_type=FileType(row["file_type"])
if row["file_type"]
else None,
metadata=row["metadata"] or {},
content_hash=row["content_hash"],
created_at=row["created_at"],

View File

@@ -214,9 +214,7 @@ class EmbeddingGenerator:
return cached
# Generate via LLM Gateway
embeddings = await self._call_llm_gateway(
[text], project_id, agent_id
)
embeddings = await self._call_llm_gateway([text], project_id, agent_id)
if not embeddings:
raise EmbeddingGenerationError(
@@ -277,9 +275,7 @@ class EmbeddingGenerator:
for i in range(0, len(texts_to_embed), batch_size):
batch = texts_to_embed[i : i + batch_size]
batch_embeddings = await self._call_llm_gateway(
batch, project_id, agent_id
)
batch_embeddings = await self._call_llm_gateway(batch, project_id, agent_id)
new_embeddings.extend(batch_embeddings)
# Validate dimensions

View File

@@ -149,12 +149,8 @@ class IngestRequest(BaseModel):
source_path: str | None = Field(
default=None, description="Source file path for reference"
)
collection: str = Field(
default="default", description="Collection to store in"
)
chunk_type: ChunkType = Field(
default=ChunkType.TEXT, description="Type of content"
)
collection: str = Field(default="default", description="Collection to store in")
chunk_type: ChunkType = Field(default=ChunkType.TEXT, description="Type of content")
file_type: FileType | None = Field(
default=None, description="File type for code chunking"
)
@@ -255,12 +251,8 @@ class DeleteRequest(BaseModel):
project_id: str = Field(..., description="Project ID for scoping")
agent_id: str = Field(..., description="Agent ID making the request")
source_path: str | None = Field(
default=None, description="Delete by source path"
)
collection: str | None = Field(
default=None, description="Delete entire collection"
)
source_path: str | None = Field(default=None, description="Delete by source path")
collection: str | None = Field(default=None, description="Delete entire collection")
chunk_ids: list[str] | None = Field(
default=None, description="Delete specific chunks"
)

View File

@@ -145,8 +145,7 @@ class SearchEngine:
# Filter by threshold (keyword search scores are normalized)
filtered = [
(emb, score) for emb, score in results
if score >= request.threshold
(emb, score) for emb, score in results if score >= request.threshold
]
return [
@@ -204,10 +203,9 @@ class SearchEngine:
)
# Filter by threshold and limit
filtered = [
result for result in fused
if result.score >= request.threshold
][:request.limit]
filtered = [result for result in fused if result.score >= request.threshold][
: request.limit
]
return filtered

View File

@@ -93,6 +93,7 @@ def _validate_source_path(value: str | None) -> str | None:
return None
# Configure logging
logging.basicConfig(
level=logging.INFO,
@@ -213,7 +214,9 @@ async def health_check() -> dict[str, Any]:
if response.status_code == 200:
status["dependencies"]["llm_gateway"] = "connected"
else:
status["dependencies"]["llm_gateway"] = f"unhealthy (status {response.status_code})"
status["dependencies"]["llm_gateway"] = (
f"unhealthy (status {response.status_code})"
)
is_degraded = True
else:
status["dependencies"]["llm_gateway"] = "not initialized"
@@ -328,7 +331,9 @@ def _get_tool_schema(func: Any) -> dict[str, Any]:
}
def _register_tool(name: str, tool_or_func: Any, description: str | None = None) -> None:
def _register_tool(
name: str, tool_or_func: Any, description: str | None = None
) -> None:
"""Register a tool in the registry.
Handles both raw functions and FastMCP FunctionTool objects.
@@ -337,7 +342,11 @@ def _register_tool(name: str, tool_or_func: Any, description: str | None = None)
if hasattr(tool_or_func, "fn"):
func = tool_or_func.fn
# Use FunctionTool's description if available
if not description and hasattr(tool_or_func, "description") and tool_or_func.description:
if (
not description
and hasattr(tool_or_func, "description")
and tool_or_func.description
):
description = tool_or_func.description
else:
func = tool_or_func
@@ -358,11 +367,13 @@ async def list_mcp_tools() -> dict[str, Any]:
"""
tools = []
for name, info in _tool_registry.items():
tools.append({
"name": name,
"description": info["description"],
"inputSchema": info["schema"],
})
tools.append(
{
"name": name,
"description": info["description"],
"inputSchema": info["schema"],
}
)
return {"tools": tools}
@@ -410,7 +421,10 @@ async def mcp_rpc(request: Request) -> JSONResponse:
status_code=400,
content={
"jsonrpc": "2.0",
"error": {"code": -32600, "message": "Invalid Request: jsonrpc must be '2.0'"},
"error": {
"code": -32600,
"message": "Invalid Request: jsonrpc must be '2.0'",
},
"id": request_id,
},
)
@@ -420,7 +434,10 @@ async def mcp_rpc(request: Request) -> JSONResponse:
status_code=400,
content={
"jsonrpc": "2.0",
"error": {"code": -32600, "message": "Invalid Request: method is required"},
"error": {
"code": -32600,
"message": "Invalid Request: method is required",
},
"id": request_id,
},
)
@@ -528,11 +545,23 @@ async def search_knowledge(
try:
# Validate inputs
if error := _validate_id(project_id, "project_id"):
return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value}
return {
"success": False,
"error": error,
"code": ErrorCode.INVALID_REQUEST.value,
}
if error := _validate_id(agent_id, "agent_id"):
return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value}
return {
"success": False,
"error": error,
"code": ErrorCode.INVALID_REQUEST.value,
}
if collection and (error := _validate_collection(collection)):
return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value}
return {
"success": False,
"error": error,
"code": ErrorCode.INVALID_REQUEST.value,
}
# Parse search type
try:
@@ -644,13 +673,29 @@ async def ingest_content(
try:
# Validate inputs
if error := _validate_id(project_id, "project_id"):
return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value}
return {
"success": False,
"error": error,
"code": ErrorCode.INVALID_REQUEST.value,
}
if error := _validate_id(agent_id, "agent_id"):
return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value}
return {
"success": False,
"error": error,
"code": ErrorCode.INVALID_REQUEST.value,
}
if error := _validate_collection(collection):
return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value}
return {
"success": False,
"error": error,
"code": ErrorCode.INVALID_REQUEST.value,
}
if error := _validate_source_path(source_path):
return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value}
return {
"success": False,
"error": error,
"code": ErrorCode.INVALID_REQUEST.value,
}
# Validate content size to prevent DoS
settings = get_settings()
@@ -750,13 +795,29 @@ async def delete_content(
try:
# Validate inputs
if error := _validate_id(project_id, "project_id"):
return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value}
return {
"success": False,
"error": error,
"code": ErrorCode.INVALID_REQUEST.value,
}
if error := _validate_id(agent_id, "agent_id"):
return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value}
return {
"success": False,
"error": error,
"code": ErrorCode.INVALID_REQUEST.value,
}
if collection and (error := _validate_collection(collection)):
return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value}
return {
"success": False,
"error": error,
"code": ErrorCode.INVALID_REQUEST.value,
}
if error := _validate_source_path(source_path):
return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value}
return {
"success": False,
"error": error,
"code": ErrorCode.INVALID_REQUEST.value,
}
request = DeleteRequest(
project_id=project_id,
@@ -803,9 +864,17 @@ async def list_collections(
try:
# Validate inputs
if error := _validate_id(project_id, "project_id"):
return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value}
return {
"success": False,
"error": error,
"code": ErrorCode.INVALID_REQUEST.value,
}
if error := _validate_id(agent_id, "agent_id"):
return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value}
return {
"success": False,
"error": error,
"code": ErrorCode.INVALID_REQUEST.value,
}
result = await _collections.list_collections(project_id) # type: ignore[union-attr]
@@ -856,11 +925,23 @@ async def get_collection_stats(
try:
# Validate inputs
if error := _validate_id(project_id, "project_id"):
return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value}
return {
"success": False,
"error": error,
"code": ErrorCode.INVALID_REQUEST.value,
}
if error := _validate_id(agent_id, "agent_id"):
return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value}
return {
"success": False,
"error": error,
"code": ErrorCode.INVALID_REQUEST.value,
}
if error := _validate_collection(collection):
return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value}
return {
"success": False,
"error": error,
"code": ErrorCode.INVALID_REQUEST.value,
}
stats = await _collections.get_collection_stats(project_id, collection) # type: ignore[union-attr]
@@ -874,8 +955,12 @@ async def get_collection_stats(
"avg_chunk_size": stats.avg_chunk_size,
"chunk_types": stats.chunk_types,
"file_types": stats.file_types,
"oldest_chunk": stats.oldest_chunk.isoformat() if stats.oldest_chunk else None,
"newest_chunk": stats.newest_chunk.isoformat() if stats.newest_chunk else None,
"oldest_chunk": stats.oldest_chunk.isoformat()
if stats.oldest_chunk
else None,
"newest_chunk": stats.newest_chunk.isoformat()
if stats.newest_chunk
else None,
}
except KnowledgeBaseError as e:
@@ -925,13 +1010,29 @@ async def update_document(
try:
# Validate inputs
if error := _validate_id(project_id, "project_id"):
return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value}
return {
"success": False,
"error": error,
"code": ErrorCode.INVALID_REQUEST.value,
}
if error := _validate_id(agent_id, "agent_id"):
return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value}
return {
"success": False,
"error": error,
"code": ErrorCode.INVALID_REQUEST.value,
}
if error := _validate_collection(collection):
return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value}
return {
"success": False,
"error": error,
"code": ErrorCode.INVALID_REQUEST.value,
}
if error := _validate_source_path(source_path):
return {"success": False, "error": error, "code": ErrorCode.INVALID_REQUEST.value}
return {
"success": False,
"error": error,
"code": ErrorCode.INVALID_REQUEST.value,
}
# Validate content size to prevent DoS
settings = get_settings()

View File

@@ -83,7 +83,9 @@ def mock_embeddings():
return [0.1] * 1536
mock_emb.generate = AsyncMock(return_value=fake_embedding())
mock_emb.generate_batch = AsyncMock(side_effect=lambda texts, **_kwargs: [fake_embedding() for _ in texts])
mock_emb.generate_batch = AsyncMock(
side_effect=lambda texts, **_kwargs: [fake_embedding() for _ in texts]
)
return mock_emb
@@ -137,7 +139,7 @@ async def async_function() -> None:
@pytest.fixture
def sample_markdown():
"""Sample Markdown content for chunking tests."""
return '''# Project Documentation
return """# Project Documentation
This is the main documentation for our project.
@@ -182,20 +184,20 @@ The search endpoint allows you to query the knowledge base.
## Contributing
We welcome contributions! Please see our contributing guide.
'''
"""
@pytest.fixture
def sample_text():
"""Sample plain text for chunking tests."""
return '''The quick brown fox jumps over the lazy dog. This is a sample text that we use for testing the text chunking functionality. It contains multiple sentences that should be properly split into chunks.
return """The quick brown fox jumps over the lazy dog. This is a sample text that we use for testing the text chunking functionality. It contains multiple sentences that should be properly split into chunks.
Each paragraph represents a logical unit of text. The chunker should try to respect paragraph boundaries when possible. This helps maintain context and readability.
When chunks need to be split mid-paragraph, the chunker should prefer sentence boundaries. This ensures that each chunk contains complete thoughts and is useful for retrieval.
The final paragraph tests edge cases. What happens with short paragraphs? Do they get merged with adjacent content? Let's find out!
'''
"""
@pytest.fixture

View File

@@ -1,7 +1,6 @@
"""Tests for chunking module."""
class TestBaseChunker:
"""Tests for base chunker functionality."""
@@ -149,7 +148,7 @@ class TestMarkdownChunker:
"""Test that chunker respects heading hierarchy."""
from chunking.markdown import MarkdownChunker
markdown = '''# Main Title
markdown = """# Main Title
Introduction paragraph.
@@ -164,7 +163,7 @@ More detailed content.
## Section Two
Content for section two.
'''
"""
chunker = MarkdownChunker(
chunk_size=200,
@@ -188,7 +187,7 @@ Content for section two.
"""Test handling of code blocks in markdown."""
from chunking.markdown import MarkdownChunker
markdown = '''# Code Example
markdown = """# Code Example
Here's some code:
@@ -198,7 +197,7 @@ def hello():
```
End of example.
'''
"""
chunker = MarkdownChunker(
chunk_size=500,
@@ -256,12 +255,12 @@ class TestTextChunker:
"""Test that chunker respects paragraph boundaries."""
from chunking.text import TextChunker
text = '''First paragraph with some content.
text = """First paragraph with some content.
Second paragraph with different content.
Third paragraph to test chunking behavior.
'''
"""
chunker = TextChunker(
chunk_size=100,

View File

@@ -67,10 +67,14 @@ class TestCollectionManager:
assert result.embeddings_generated == 0
@pytest.mark.asyncio
async def test_ingest_error_handling(self, collection_manager, sample_ingest_request):
async def test_ingest_error_handling(
self, collection_manager, sample_ingest_request
):
"""Test ingest error handling."""
# Make embedding generation fail
collection_manager._embeddings.generate_batch.side_effect = Exception("Embedding error")
collection_manager._embeddings.generate_batch.side_effect = Exception(
"Embedding error"
)
result = await collection_manager.ingest(sample_ingest_request)
@@ -182,7 +186,9 @@ class TestCollectionManager:
)
collection_manager._database.get_collection_stats.return_value = expected_stats
stats = await collection_manager.get_collection_stats("proj-123", "test-collection")
stats = await collection_manager.get_collection_stats(
"proj-123", "test-collection"
)
assert stats.chunk_count == 100
assert stats.unique_sources == 10

View File

@@ -17,19 +17,15 @@ class TestEmbeddingGenerator:
response.raise_for_status = MagicMock()
response.json.return_value = {
"result": {
"content": [
{
"text": json.dumps({
"embeddings": [[0.1] * 1536]
})
}
]
"content": [{"text": json.dumps({"embeddings": [[0.1] * 1536]})}]
}
}
return response
@pytest.mark.asyncio
async def test_generate_single_embedding(self, settings, mock_redis, mock_http_response):
async def test_generate_single_embedding(
self, settings, mock_redis, mock_http_response
):
"""Test generating a single embedding."""
from embeddings import EmbeddingGenerator
@@ -67,9 +63,9 @@ class TestEmbeddingGenerator:
"result": {
"content": [
{
"text": json.dumps({
"embeddings": [[0.1] * 1536, [0.2] * 1536, [0.3] * 1536]
})
"text": json.dumps(
{"embeddings": [[0.1] * 1536, [0.2] * 1536, [0.3] * 1536]}
)
}
]
}
@@ -166,9 +162,11 @@ class TestEmbeddingGenerator:
"result": {
"content": [
{
"text": json.dumps({
"embeddings": [[0.1] * 768] # Wrong dimension
})
"text": json.dumps(
{
"embeddings": [[0.1] * 768] # Wrong dimension
}
)
}
]
}

View File

@@ -1,7 +1,6 @@
"""Tests for exception classes."""
class TestErrorCode:
"""Tests for ErrorCode enum."""
@@ -10,8 +9,13 @@ class TestErrorCode:
from exceptions import ErrorCode
assert ErrorCode.UNKNOWN_ERROR.value == "KB_UNKNOWN_ERROR"
assert ErrorCode.DATABASE_CONNECTION_ERROR.value == "KB_DATABASE_CONNECTION_ERROR"
assert ErrorCode.EMBEDDING_GENERATION_ERROR.value == "KB_EMBEDDING_GENERATION_ERROR"
assert (
ErrorCode.DATABASE_CONNECTION_ERROR.value == "KB_DATABASE_CONNECTION_ERROR"
)
assert (
ErrorCode.EMBEDDING_GENERATION_ERROR.value
== "KB_EMBEDDING_GENERATION_ERROR"
)
assert ErrorCode.CHUNKING_ERROR.value == "KB_CHUNKING_ERROR"
assert ErrorCode.SEARCH_ERROR.value == "KB_SEARCH_ERROR"
assert ErrorCode.COLLECTION_NOT_FOUND.value == "KB_COLLECTION_NOT_FOUND"

View File

@@ -59,7 +59,9 @@ class TestSearchEngine:
]
@pytest.mark.asyncio
async def test_semantic_search(self, search_engine, sample_search_request, sample_db_results):
async def test_semantic_search(
self, search_engine, sample_search_request, sample_db_results
):
"""Test semantic search."""
from models import SearchType
@@ -74,7 +76,9 @@ class TestSearchEngine:
search_engine._database.semantic_search.assert_called_once()
@pytest.mark.asyncio
async def test_keyword_search(self, search_engine, sample_search_request, sample_db_results):
async def test_keyword_search(
self, search_engine, sample_search_request, sample_db_results
):
"""Test keyword search."""
from models import SearchType
@@ -88,7 +92,9 @@ class TestSearchEngine:
search_engine._database.keyword_search.assert_called_once()
@pytest.mark.asyncio
async def test_hybrid_search(self, search_engine, sample_search_request, sample_db_results):
async def test_hybrid_search(
self, search_engine, sample_search_request, sample_db_results
):
"""Test hybrid search."""
from models import SearchType
@@ -105,7 +111,9 @@ class TestSearchEngine:
assert len(response.results) >= 1
@pytest.mark.asyncio
async def test_search_with_collection_filter(self, search_engine, sample_search_request, sample_db_results):
async def test_search_with_collection_filter(
self, search_engine, sample_search_request, sample_db_results
):
"""Test search with collection filter."""
from models import SearchType
@@ -120,7 +128,9 @@ class TestSearchEngine:
assert call_args.kwargs["collection"] == "specific-collection"
@pytest.mark.asyncio
async def test_search_with_file_type_filter(self, search_engine, sample_search_request, sample_db_results):
async def test_search_with_file_type_filter(
self, search_engine, sample_search_request, sample_db_results
):
"""Test search with file type filter."""
from models import FileType, SearchType
@@ -135,7 +145,9 @@ class TestSearchEngine:
assert call_args.kwargs["file_types"] == [FileType.PYTHON]
@pytest.mark.asyncio
async def test_search_respects_limit(self, search_engine, sample_search_request, sample_db_results):
async def test_search_respects_limit(
self, search_engine, sample_search_request, sample_db_results
):
"""Test that search respects result limit."""
from models import SearchType
@@ -148,7 +160,9 @@ class TestSearchEngine:
assert len(response.results) <= 1
@pytest.mark.asyncio
async def test_search_records_time(self, search_engine, sample_search_request, sample_db_results):
async def test_search_records_time(
self, search_engine, sample_search_request, sample_db_results
):
"""Test that search records time."""
from models import SearchType
@@ -203,13 +217,21 @@ class TestReciprocalRankFusion:
from models import SearchResult
semantic = [
SearchResult(id="a", content="A", score=0.9, chunk_type="code", collection="default"),
SearchResult(id="b", content="B", score=0.8, chunk_type="code", collection="default"),
SearchResult(
id="a", content="A", score=0.9, chunk_type="code", collection="default"
),
SearchResult(
id="b", content="B", score=0.8, chunk_type="code", collection="default"
),
]
keyword = [
SearchResult(id="b", content="B", score=0.85, chunk_type="code", collection="default"),
SearchResult(id="c", content="C", score=0.7, chunk_type="code", collection="default"),
SearchResult(
id="b", content="B", score=0.85, chunk_type="code", collection="default"
),
SearchResult(
id="c", content="C", score=0.7, chunk_type="code", collection="default"
),
]
fused = search_engine._reciprocal_rank_fusion(semantic, keyword)
@@ -230,19 +252,23 @@ class TestReciprocalRankFusion:
# Same results in same order
results = [
SearchResult(id="a", content="A", score=0.9, chunk_type="code", collection="default"),
SearchResult(
id="a", content="A", score=0.9, chunk_type="code", collection="default"
),
]
# High semantic weight
fused_semantic_heavy = search_engine._reciprocal_rank_fusion(
results, [],
results,
[],
semantic_weight=0.9,
keyword_weight=0.1,
)
# High keyword weight
fused_keyword_heavy = search_engine._reciprocal_rank_fusion(
[], results,
[],
results,
semantic_weight=0.1,
keyword_weight=0.9,
)
@@ -256,12 +282,18 @@ class TestReciprocalRankFusion:
from models import SearchResult
semantic = [
SearchResult(id="a", content="A", score=0.9, chunk_type="code", collection="default"),
SearchResult(id="b", content="B", score=0.8, chunk_type="code", collection="default"),
SearchResult(
id="a", content="A", score=0.9, chunk_type="code", collection="default"
),
SearchResult(
id="b", content="B", score=0.8, chunk_type="code", collection="default"
),
]
keyword = [
SearchResult(id="c", content="C", score=0.7, chunk_type="code", collection="default"),
SearchResult(
id="c", content="C", score=0.7, chunk_type="code", collection="default"
),
]
fused = search_engine._reciprocal_rank_fusion(semantic, keyword)

View File

@@ -1,4 +1,8 @@
.PHONY: help install install-dev lint lint-fix format type-check test test-cov validate clean run
.PHONY: help install install-dev lint lint-fix format format-check type-check test test-cov validate clean run
# Ensure commands in this project don't inherit an external Python virtualenv
# (prevents uv warnings about mismatched VIRTUAL_ENV when running from repo root)
unexport VIRTUAL_ENV
# Default target
help:
@@ -12,6 +16,7 @@ help:
@echo " make lint - Run Ruff linter"
@echo " make lint-fix - Run Ruff linter with auto-fix"
@echo " make format - Format code with Ruff"
@echo " make format-check - Check if code is formatted"
@echo " make type-check - Run mypy type checker"
@echo ""
@echo "Testing:"
@@ -19,7 +24,7 @@ help:
@echo " make test-cov - Run pytest with coverage"
@echo ""
@echo "All-in-one:"
@echo " make validate - Run lint, type-check, and tests"
@echo " make validate - Run all checks (lint + format + types)"
@echo ""
@echo "Running:"
@echo " make run - Run the server locally"
@@ -49,6 +54,10 @@ format:
@echo "Formatting code..."
@uv run ruff format .
format-check:
@echo "Checking code formatting..."
@uv run ruff format --check .
type-check:
@echo "Running mypy..."
@uv run mypy . --ignore-missing-imports
@@ -63,7 +72,7 @@ test-cov:
@uv run pytest tests/ -v --cov=. --cov-report=term-missing --cov-report=html
# All-in-one validation
validate: lint type-check test
validate: lint format-check type-check
@echo "All validations passed!"
# Running

View File

@@ -111,7 +111,10 @@ class CircuitBreaker:
if self._state == CircuitState.OPEN:
time_in_open = time.time() - self._stats.state_changed_at
# Double-check state after time calculation (for thread safety)
if time_in_open >= self.recovery_timeout and self._state == CircuitState.OPEN:
if (
time_in_open >= self.recovery_timeout
and self._state == CircuitState.OPEN
):
self._transition_to(CircuitState.HALF_OPEN)
logger.info(
f"Circuit {self.name} transitioned to HALF_OPEN "