forked from cardosofelipe/fast-next-template
Compare commits
3 Commits
4ad3d20cf2
...
76d7de5334
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
76d7de5334 | ||
|
|
1779239c07 | ||
|
|
9dfa76aa41 |
13
Makefile
13
Makefile
@@ -47,6 +47,7 @@ help:
|
|||||||
@echo " cd backend && make help - Backend-specific commands"
|
@echo " cd backend && make help - Backend-specific commands"
|
||||||
@echo " cd mcp-servers/llm-gateway && make - LLM Gateway commands"
|
@echo " cd mcp-servers/llm-gateway && make - LLM Gateway commands"
|
||||||
@echo " cd mcp-servers/knowledge-base && make - Knowledge Base commands"
|
@echo " cd mcp-servers/knowledge-base && make - Knowledge Base commands"
|
||||||
|
@echo " cd mcp-servers/git-ops && make - Git Operations commands"
|
||||||
@echo " cd frontend && npm run - Frontend-specific commands"
|
@echo " cd frontend && npm run - Frontend-specific commands"
|
||||||
|
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
@@ -138,6 +139,9 @@ test-mcp:
|
|||||||
@echo ""
|
@echo ""
|
||||||
@echo "=== Knowledge Base ==="
|
@echo "=== Knowledge Base ==="
|
||||||
@cd mcp-servers/knowledge-base && uv run pytest tests/ -v
|
@cd mcp-servers/knowledge-base && uv run pytest tests/ -v
|
||||||
|
@echo ""
|
||||||
|
@echo "=== Git Operations ==="
|
||||||
|
@cd mcp-servers/git-ops && IS_TEST=True uv run pytest tests/ -v
|
||||||
|
|
||||||
test-frontend:
|
test-frontend:
|
||||||
@echo "Running frontend tests..."
|
@echo "Running frontend tests..."
|
||||||
@@ -158,6 +162,9 @@ test-cov:
|
|||||||
@echo ""
|
@echo ""
|
||||||
@echo "=== Knowledge Base Coverage ==="
|
@echo "=== Knowledge Base Coverage ==="
|
||||||
@cd mcp-servers/knowledge-base && uv run pytest tests/ -v --cov=. --cov-report=term-missing
|
@cd mcp-servers/knowledge-base && uv run pytest tests/ -v --cov=. --cov-report=term-missing
|
||||||
|
@echo ""
|
||||||
|
@echo "=== Git Operations Coverage ==="
|
||||||
|
@cd mcp-servers/git-ops && IS_TEST=True uv run pytest tests/ -v --cov=. --cov-report=term-missing
|
||||||
|
|
||||||
test-integration:
|
test-integration:
|
||||||
@echo "Running MCP integration tests..."
|
@echo "Running MCP integration tests..."
|
||||||
@@ -178,6 +185,9 @@ format-all:
|
|||||||
@echo "Formatting Knowledge Base..."
|
@echo "Formatting Knowledge Base..."
|
||||||
@cd mcp-servers/knowledge-base && make format
|
@cd mcp-servers/knowledge-base && make format
|
||||||
@echo ""
|
@echo ""
|
||||||
|
@echo "Formatting Git Operations..."
|
||||||
|
@cd mcp-servers/git-ops && make format
|
||||||
|
@echo ""
|
||||||
@echo "Formatting frontend..."
|
@echo "Formatting frontend..."
|
||||||
@cd frontend && npm run format
|
@cd frontend && npm run format
|
||||||
@echo ""
|
@echo ""
|
||||||
@@ -197,6 +207,9 @@ validate:
|
|||||||
@echo "Validating Knowledge Base..."
|
@echo "Validating Knowledge Base..."
|
||||||
@cd mcp-servers/knowledge-base && make validate
|
@cd mcp-servers/knowledge-base && make validate
|
||||||
@echo ""
|
@echo ""
|
||||||
|
@echo "Validating Git Operations..."
|
||||||
|
@cd mcp-servers/git-ops && make validate
|
||||||
|
@echo ""
|
||||||
@echo "All validations passed!"
|
@echo "All validations passed!"
|
||||||
|
|
||||||
validate-all: validate
|
validate-all: validate
|
||||||
|
|||||||
@@ -96,6 +96,38 @@ services:
|
|||||||
- app-network
|
- app-network
|
||||||
restart: unless-stopped
|
restart: unless-stopped
|
||||||
|
|
||||||
|
mcp-git-ops:
|
||||||
|
build:
|
||||||
|
context: ./mcp-servers/git-ops
|
||||||
|
dockerfile: Dockerfile
|
||||||
|
ports:
|
||||||
|
- "8003:8003"
|
||||||
|
env_file:
|
||||||
|
- .env
|
||||||
|
environment:
|
||||||
|
# GIT_OPS_ prefix required by pydantic-settings config
|
||||||
|
- GIT_OPS_HOST=0.0.0.0
|
||||||
|
- GIT_OPS_PORT=8003
|
||||||
|
- GIT_OPS_REDIS_URL=redis://redis:6379/3
|
||||||
|
- GIT_OPS_GITEA_BASE_URL=${GITEA_BASE_URL}
|
||||||
|
- GIT_OPS_GITEA_TOKEN=${GITEA_TOKEN}
|
||||||
|
- GIT_OPS_GITHUB_TOKEN=${GITHUB_TOKEN}
|
||||||
|
- ENVIRONMENT=development
|
||||||
|
volumes:
|
||||||
|
- git_workspaces_dev:/workspaces
|
||||||
|
depends_on:
|
||||||
|
redis:
|
||||||
|
condition: service_healthy
|
||||||
|
healthcheck:
|
||||||
|
test: ["CMD", "python", "-c", "import httpx; httpx.get('http://localhost:8003/health').raise_for_status()"]
|
||||||
|
interval: 30s
|
||||||
|
timeout: 10s
|
||||||
|
retries: 3
|
||||||
|
start_period: 10s
|
||||||
|
networks:
|
||||||
|
- app-network
|
||||||
|
restart: unless-stopped
|
||||||
|
|
||||||
backend:
|
backend:
|
||||||
build:
|
build:
|
||||||
context: ./backend
|
context: ./backend
|
||||||
@@ -119,6 +151,7 @@ services:
|
|||||||
# MCP Server URLs
|
# MCP Server URLs
|
||||||
- LLM_GATEWAY_URL=http://mcp-llm-gateway:8001
|
- LLM_GATEWAY_URL=http://mcp-llm-gateway:8001
|
||||||
- KNOWLEDGE_BASE_URL=http://mcp-knowledge-base:8002
|
- KNOWLEDGE_BASE_URL=http://mcp-knowledge-base:8002
|
||||||
|
- GIT_OPS_URL=http://mcp-git-ops:8003
|
||||||
depends_on:
|
depends_on:
|
||||||
db:
|
db:
|
||||||
condition: service_healthy
|
condition: service_healthy
|
||||||
@@ -128,6 +161,8 @@ services:
|
|||||||
condition: service_healthy
|
condition: service_healthy
|
||||||
mcp-knowledge-base:
|
mcp-knowledge-base:
|
||||||
condition: service_healthy
|
condition: service_healthy
|
||||||
|
mcp-git-ops:
|
||||||
|
condition: service_healthy
|
||||||
healthcheck:
|
healthcheck:
|
||||||
test: ["CMD", "curl", "-f", "http://localhost:8000/health"]
|
test: ["CMD", "curl", "-f", "http://localhost:8000/health"]
|
||||||
interval: 10s
|
interval: 10s
|
||||||
@@ -155,6 +190,7 @@ services:
|
|||||||
# MCP Server URLs (agents need access to MCP)
|
# MCP Server URLs (agents need access to MCP)
|
||||||
- LLM_GATEWAY_URL=http://mcp-llm-gateway:8001
|
- LLM_GATEWAY_URL=http://mcp-llm-gateway:8001
|
||||||
- KNOWLEDGE_BASE_URL=http://mcp-knowledge-base:8002
|
- KNOWLEDGE_BASE_URL=http://mcp-knowledge-base:8002
|
||||||
|
- GIT_OPS_URL=http://mcp-git-ops:8003
|
||||||
depends_on:
|
depends_on:
|
||||||
db:
|
db:
|
||||||
condition: service_healthy
|
condition: service_healthy
|
||||||
@@ -164,6 +200,8 @@ services:
|
|||||||
condition: service_healthy
|
condition: service_healthy
|
||||||
mcp-knowledge-base:
|
mcp-knowledge-base:
|
||||||
condition: service_healthy
|
condition: service_healthy
|
||||||
|
mcp-git-ops:
|
||||||
|
condition: service_healthy
|
||||||
networks:
|
networks:
|
||||||
- app-network
|
- app-network
|
||||||
command: ["celery", "-A", "app.celery_app", "worker", "-Q", "agent", "-l", "info", "-c", "4"]
|
command: ["celery", "-A", "app.celery_app", "worker", "-Q", "agent", "-l", "info", "-c", "4"]
|
||||||
@@ -181,11 +219,14 @@ services:
|
|||||||
- DATABASE_URL=${DATABASE_URL}
|
- DATABASE_URL=${DATABASE_URL}
|
||||||
- REDIS_URL=redis://redis:6379/0
|
- REDIS_URL=redis://redis:6379/0
|
||||||
- CELERY_QUEUE=git
|
- CELERY_QUEUE=git
|
||||||
|
- GIT_OPS_URL=http://mcp-git-ops:8003
|
||||||
depends_on:
|
depends_on:
|
||||||
db:
|
db:
|
||||||
condition: service_healthy
|
condition: service_healthy
|
||||||
redis:
|
redis:
|
||||||
condition: service_healthy
|
condition: service_healthy
|
||||||
|
mcp-git-ops:
|
||||||
|
condition: service_healthy
|
||||||
networks:
|
networks:
|
||||||
- app-network
|
- app-network
|
||||||
command: ["celery", "-A", "app.celery_app", "worker", "-Q", "git", "-l", "info", "-c", "2"]
|
command: ["celery", "-A", "app.celery_app", "worker", "-Q", "git", "-l", "info", "-c", "2"]
|
||||||
@@ -260,6 +301,7 @@ services:
|
|||||||
volumes:
|
volumes:
|
||||||
postgres_data_dev:
|
postgres_data_dev:
|
||||||
redis_data_dev:
|
redis_data_dev:
|
||||||
|
git_workspaces_dev:
|
||||||
frontend_dev_modules:
|
frontend_dev_modules:
|
||||||
frontend_dev_next:
|
frontend_dev_next:
|
||||||
|
|
||||||
|
|||||||
67
mcp-servers/git-ops/Dockerfile
Normal file
67
mcp-servers/git-ops/Dockerfile
Normal file
@@ -0,0 +1,67 @@
|
|||||||
|
# Git Operations MCP Server Dockerfile
|
||||||
|
# Multi-stage build for smaller production image
|
||||||
|
|
||||||
|
FROM python:3.12-slim AS builder
|
||||||
|
|
||||||
|
# Install build dependencies
|
||||||
|
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||||
|
build-essential \
|
||||||
|
git \
|
||||||
|
&& rm -rf /var/lib/apt/lists/*
|
||||||
|
|
||||||
|
# Install uv for fast package management
|
||||||
|
RUN pip install --no-cache-dir uv
|
||||||
|
|
||||||
|
# Create app directory
|
||||||
|
WORKDIR /app
|
||||||
|
|
||||||
|
# Copy dependency files
|
||||||
|
COPY pyproject.toml .
|
||||||
|
|
||||||
|
# Install dependencies with uv
|
||||||
|
RUN uv pip install --system --no-cache .
|
||||||
|
|
||||||
|
# Production stage
|
||||||
|
FROM python:3.12-slim
|
||||||
|
|
||||||
|
# Install runtime dependencies
|
||||||
|
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||||
|
git \
|
||||||
|
openssh-client \
|
||||||
|
&& rm -rf /var/lib/apt/lists/*
|
||||||
|
|
||||||
|
# Create non-root user
|
||||||
|
RUN useradd --create-home --shell /bin/bash syndarix
|
||||||
|
|
||||||
|
# Create workspace directory
|
||||||
|
RUN mkdir -p /var/syndarix/workspaces && chown -R syndarix:syndarix /var/syndarix
|
||||||
|
|
||||||
|
# Create app directory
|
||||||
|
WORKDIR /app
|
||||||
|
|
||||||
|
# Copy installed packages from builder
|
||||||
|
COPY --from=builder /usr/local/lib/python3.12/site-packages /usr/local/lib/python3.12/site-packages
|
||||||
|
COPY --from=builder /usr/local/bin /usr/local/bin
|
||||||
|
|
||||||
|
# Copy application code
|
||||||
|
COPY --chown=syndarix:syndarix . .
|
||||||
|
|
||||||
|
# Set Python path
|
||||||
|
ENV PYTHONPATH=/app
|
||||||
|
ENV PYTHONUNBUFFERED=1
|
||||||
|
|
||||||
|
# Configure git for the container
|
||||||
|
RUN git config --global --add safe.directory '*'
|
||||||
|
|
||||||
|
# Switch to non-root user
|
||||||
|
USER syndarix
|
||||||
|
|
||||||
|
# Expose port
|
||||||
|
EXPOSE 8003
|
||||||
|
|
||||||
|
# Health check
|
||||||
|
HEALTHCHECK --interval=30s --timeout=10s --start-period=5s --retries=3 \
|
||||||
|
CMD python -c "import httpx; httpx.get('http://localhost:8003/health').raise_for_status()" || exit 1
|
||||||
|
|
||||||
|
# Run the server
|
||||||
|
CMD ["python", "server.py"]
|
||||||
88
mcp-servers/git-ops/Makefile
Normal file
88
mcp-servers/git-ops/Makefile
Normal file
@@ -0,0 +1,88 @@
|
|||||||
|
.PHONY: help install install-dev lint lint-fix format format-check type-check test test-cov validate clean run
|
||||||
|
|
||||||
|
# Ensure commands in this project don't inherit an external Python virtualenv
|
||||||
|
# (prevents uv warnings about mismatched VIRTUAL_ENV when running from repo root)
|
||||||
|
unexport VIRTUAL_ENV
|
||||||
|
|
||||||
|
# Default target
|
||||||
|
help:
|
||||||
|
@echo "Git Operations MCP Server - Development Commands"
|
||||||
|
@echo ""
|
||||||
|
@echo "Setup:"
|
||||||
|
@echo " make install - Install production dependencies"
|
||||||
|
@echo " make install-dev - Install development dependencies"
|
||||||
|
@echo ""
|
||||||
|
@echo "Quality Checks:"
|
||||||
|
@echo " make lint - Run Ruff linter"
|
||||||
|
@echo " make lint-fix - Run Ruff linter with auto-fix"
|
||||||
|
@echo " make format - Format code with Ruff"
|
||||||
|
@echo " make format-check - Check if code is formatted"
|
||||||
|
@echo " make type-check - Run mypy type checker"
|
||||||
|
@echo ""
|
||||||
|
@echo "Testing:"
|
||||||
|
@echo " make test - Run pytest"
|
||||||
|
@echo " make test-cov - Run pytest with coverage"
|
||||||
|
@echo ""
|
||||||
|
@echo "All-in-one:"
|
||||||
|
@echo " make validate - Run all checks (lint + format + types)"
|
||||||
|
@echo ""
|
||||||
|
@echo "Running:"
|
||||||
|
@echo " make run - Run the server locally"
|
||||||
|
@echo ""
|
||||||
|
@echo "Cleanup:"
|
||||||
|
@echo " make clean - Remove cache and build artifacts"
|
||||||
|
|
||||||
|
# Setup
|
||||||
|
install:
|
||||||
|
@echo "Installing production dependencies..."
|
||||||
|
@uv pip install -e .
|
||||||
|
|
||||||
|
install-dev:
|
||||||
|
@echo "Installing development dependencies..."
|
||||||
|
@uv pip install -e ".[dev]"
|
||||||
|
|
||||||
|
# Quality checks
|
||||||
|
lint:
|
||||||
|
@echo "Running Ruff linter..."
|
||||||
|
@uv run ruff check .
|
||||||
|
|
||||||
|
lint-fix:
|
||||||
|
@echo "Running Ruff linter with auto-fix..."
|
||||||
|
@uv run ruff check --fix .
|
||||||
|
|
||||||
|
format:
|
||||||
|
@echo "Formatting code..."
|
||||||
|
@uv run ruff format .
|
||||||
|
|
||||||
|
format-check:
|
||||||
|
@echo "Checking code formatting..."
|
||||||
|
@uv run ruff format --check .
|
||||||
|
|
||||||
|
type-check:
|
||||||
|
@echo "Running mypy..."
|
||||||
|
@uv run python -m mypy server.py config.py models.py exceptions.py git_wrapper.py workspace.py providers/ --explicit-package-bases
|
||||||
|
|
||||||
|
# Testing
|
||||||
|
test:
|
||||||
|
@echo "Running tests..."
|
||||||
|
@IS_TEST=True uv run pytest tests/ -v
|
||||||
|
|
||||||
|
test-cov:
|
||||||
|
@echo "Running tests with coverage..."
|
||||||
|
@IS_TEST=True uv run pytest tests/ -v --cov=. --cov-report=term-missing --cov-report=html
|
||||||
|
|
||||||
|
# All-in-one validation
|
||||||
|
validate: lint format-check type-check
|
||||||
|
@echo "All validations passed!"
|
||||||
|
|
||||||
|
# Running
|
||||||
|
run:
|
||||||
|
@echo "Starting Git Operations server..."
|
||||||
|
@uv run python server.py
|
||||||
|
|
||||||
|
# Cleanup
|
||||||
|
clean:
|
||||||
|
@echo "Cleaning up..."
|
||||||
|
@rm -rf __pycache__ .pytest_cache .mypy_cache .ruff_cache .coverage htmlcov
|
||||||
|
@find . -type d -name "__pycache__" -exec rm -rf {} + 2>/dev/null || true
|
||||||
|
@find . -type f -name "*.pyc" -delete 2>/dev/null || true
|
||||||
179
mcp-servers/git-ops/__init__.py
Normal file
179
mcp-servers/git-ops/__init__.py
Normal file
@@ -0,0 +1,179 @@
|
|||||||
|
"""
|
||||||
|
Git Operations MCP Server.
|
||||||
|
|
||||||
|
Provides git repository management, branching, commits, and PR workflows
|
||||||
|
for Syndarix AI agents.
|
||||||
|
"""
|
||||||
|
|
||||||
|
__version__ = "0.1.0"
|
||||||
|
|
||||||
|
from config import Settings, get_settings, is_test_mode, reset_settings
|
||||||
|
from exceptions import (
|
||||||
|
APIError,
|
||||||
|
AuthenticationError,
|
||||||
|
BranchExistsError,
|
||||||
|
BranchNotFoundError,
|
||||||
|
CheckoutError,
|
||||||
|
CloneError,
|
||||||
|
CommitError,
|
||||||
|
CredentialError,
|
||||||
|
CredentialNotFoundError,
|
||||||
|
DirtyWorkspaceError,
|
||||||
|
ErrorCode,
|
||||||
|
GitError,
|
||||||
|
GitOpsError,
|
||||||
|
InvalidRefError,
|
||||||
|
MergeConflictError,
|
||||||
|
PRError,
|
||||||
|
PRNotFoundError,
|
||||||
|
ProviderError,
|
||||||
|
ProviderNotFoundError,
|
||||||
|
PullError,
|
||||||
|
PushError,
|
||||||
|
WorkspaceError,
|
||||||
|
WorkspaceLockedError,
|
||||||
|
WorkspaceNotFoundError,
|
||||||
|
WorkspaceSizeExceededError,
|
||||||
|
)
|
||||||
|
from models import (
|
||||||
|
BranchInfo,
|
||||||
|
BranchRequest,
|
||||||
|
BranchResult,
|
||||||
|
CheckoutRequest,
|
||||||
|
CheckoutResult,
|
||||||
|
CloneRequest,
|
||||||
|
CloneResult,
|
||||||
|
CommitInfo,
|
||||||
|
CommitRequest,
|
||||||
|
CommitResult,
|
||||||
|
CreatePRRequest,
|
||||||
|
CreatePRResult,
|
||||||
|
DiffHunk,
|
||||||
|
DiffRequest,
|
||||||
|
DiffResult,
|
||||||
|
FileChange,
|
||||||
|
FileChangeType,
|
||||||
|
FileDiff,
|
||||||
|
GetPRRequest,
|
||||||
|
GetPRResult,
|
||||||
|
GetWorkspaceRequest,
|
||||||
|
GetWorkspaceResult,
|
||||||
|
HealthStatus,
|
||||||
|
ListBranchesRequest,
|
||||||
|
ListBranchesResult,
|
||||||
|
ListPRsRequest,
|
||||||
|
ListPRsResult,
|
||||||
|
LockWorkspaceRequest,
|
||||||
|
LockWorkspaceResult,
|
||||||
|
LogRequest,
|
||||||
|
LogResult,
|
||||||
|
MergePRRequest,
|
||||||
|
MergePRResult,
|
||||||
|
MergeStrategy,
|
||||||
|
PRInfo,
|
||||||
|
ProviderStatus,
|
||||||
|
ProviderType,
|
||||||
|
PRState,
|
||||||
|
PullRequest,
|
||||||
|
PullResult,
|
||||||
|
PushRequest,
|
||||||
|
PushResult,
|
||||||
|
StatusRequest,
|
||||||
|
StatusResult,
|
||||||
|
UnlockWorkspaceRequest,
|
||||||
|
UnlockWorkspaceResult,
|
||||||
|
UpdatePRRequest,
|
||||||
|
UpdatePRResult,
|
||||||
|
WorkspaceInfo,
|
||||||
|
WorkspaceState,
|
||||||
|
)
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
# Version
|
||||||
|
"__version__",
|
||||||
|
# Config
|
||||||
|
"Settings",
|
||||||
|
"get_settings",
|
||||||
|
"reset_settings",
|
||||||
|
"is_test_mode",
|
||||||
|
# Error codes
|
||||||
|
"ErrorCode",
|
||||||
|
# Exceptions
|
||||||
|
"GitOpsError",
|
||||||
|
"WorkspaceError",
|
||||||
|
"WorkspaceNotFoundError",
|
||||||
|
"WorkspaceLockedError",
|
||||||
|
"WorkspaceSizeExceededError",
|
||||||
|
"GitError",
|
||||||
|
"CloneError",
|
||||||
|
"CheckoutError",
|
||||||
|
"CommitError",
|
||||||
|
"PushError",
|
||||||
|
"PullError",
|
||||||
|
"MergeConflictError",
|
||||||
|
"BranchExistsError",
|
||||||
|
"BranchNotFoundError",
|
||||||
|
"InvalidRefError",
|
||||||
|
"DirtyWorkspaceError",
|
||||||
|
"ProviderError",
|
||||||
|
"AuthenticationError",
|
||||||
|
"ProviderNotFoundError",
|
||||||
|
"PRError",
|
||||||
|
"PRNotFoundError",
|
||||||
|
"APIError",
|
||||||
|
"CredentialError",
|
||||||
|
"CredentialNotFoundError",
|
||||||
|
# Enums
|
||||||
|
"FileChangeType",
|
||||||
|
"MergeStrategy",
|
||||||
|
"PRState",
|
||||||
|
"ProviderType",
|
||||||
|
"WorkspaceState",
|
||||||
|
# Dataclasses
|
||||||
|
"FileChange",
|
||||||
|
"BranchInfo",
|
||||||
|
"CommitInfo",
|
||||||
|
"DiffHunk",
|
||||||
|
"FileDiff",
|
||||||
|
"PRInfo",
|
||||||
|
"WorkspaceInfo",
|
||||||
|
# Request/Response models
|
||||||
|
"CloneRequest",
|
||||||
|
"CloneResult",
|
||||||
|
"StatusRequest",
|
||||||
|
"StatusResult",
|
||||||
|
"BranchRequest",
|
||||||
|
"BranchResult",
|
||||||
|
"ListBranchesRequest",
|
||||||
|
"ListBranchesResult",
|
||||||
|
"CheckoutRequest",
|
||||||
|
"CheckoutResult",
|
||||||
|
"CommitRequest",
|
||||||
|
"CommitResult",
|
||||||
|
"PushRequest",
|
||||||
|
"PushResult",
|
||||||
|
"PullRequest",
|
||||||
|
"PullResult",
|
||||||
|
"DiffRequest",
|
||||||
|
"DiffResult",
|
||||||
|
"LogRequest",
|
||||||
|
"LogResult",
|
||||||
|
"CreatePRRequest",
|
||||||
|
"CreatePRResult",
|
||||||
|
"GetPRRequest",
|
||||||
|
"GetPRResult",
|
||||||
|
"ListPRsRequest",
|
||||||
|
"ListPRsResult",
|
||||||
|
"MergePRRequest",
|
||||||
|
"MergePRResult",
|
||||||
|
"UpdatePRRequest",
|
||||||
|
"UpdatePRResult",
|
||||||
|
"GetWorkspaceRequest",
|
||||||
|
"GetWorkspaceResult",
|
||||||
|
"LockWorkspaceRequest",
|
||||||
|
"LockWorkspaceResult",
|
||||||
|
"UnlockWorkspaceRequest",
|
||||||
|
"UnlockWorkspaceResult",
|
||||||
|
"HealthStatus",
|
||||||
|
"ProviderStatus",
|
||||||
|
]
|
||||||
155
mcp-servers/git-ops/config.py
Normal file
155
mcp-servers/git-ops/config.py
Normal file
@@ -0,0 +1,155 @@
|
|||||||
|
"""
|
||||||
|
Configuration for Git Operations MCP Server.
|
||||||
|
|
||||||
|
Uses pydantic-settings for environment variable loading.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from pydantic import Field
|
||||||
|
from pydantic_settings import BaseSettings
|
||||||
|
|
||||||
|
|
||||||
|
class Settings(BaseSettings):
|
||||||
|
"""Application settings loaded from environment."""
|
||||||
|
|
||||||
|
# Server settings
|
||||||
|
host: str = Field(default="0.0.0.0", description="Server host")
|
||||||
|
port: int = Field(default=8003, description="Server port")
|
||||||
|
debug: bool = Field(default=False, description="Debug mode")
|
||||||
|
|
||||||
|
# Workspace settings
|
||||||
|
workspace_base_path: Path = Field(
|
||||||
|
default=Path("/var/syndarix/workspaces"),
|
||||||
|
description="Base path for git workspaces",
|
||||||
|
)
|
||||||
|
workspace_max_size_gb: float = Field(
|
||||||
|
default=10.0,
|
||||||
|
description="Maximum size per workspace in GB",
|
||||||
|
)
|
||||||
|
workspace_stale_days: int = Field(
|
||||||
|
default=7,
|
||||||
|
description="Days after which unused workspace is considered stale",
|
||||||
|
)
|
||||||
|
workspace_lock_timeout: int = Field(
|
||||||
|
default=300,
|
||||||
|
description="Workspace lock timeout in seconds",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Git settings
|
||||||
|
git_timeout: int = Field(
|
||||||
|
default=120,
|
||||||
|
description="Default timeout for git operations in seconds",
|
||||||
|
)
|
||||||
|
git_clone_timeout: int = Field(
|
||||||
|
default=600,
|
||||||
|
description="Timeout for clone operations in seconds",
|
||||||
|
)
|
||||||
|
git_author_name: str = Field(
|
||||||
|
default="Syndarix Agent",
|
||||||
|
description="Default author name for commits",
|
||||||
|
)
|
||||||
|
git_author_email: str = Field(
|
||||||
|
default="agent@syndarix.ai",
|
||||||
|
description="Default author email for commits",
|
||||||
|
)
|
||||||
|
git_max_diff_lines: int = Field(
|
||||||
|
default=10000,
|
||||||
|
description="Maximum lines in diff output",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Redis settings (for distributed locking)
|
||||||
|
redis_url: str = Field(
|
||||||
|
default="redis://localhost:6379/0",
|
||||||
|
description="Redis connection URL",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Provider settings
|
||||||
|
gitea_base_url: str = Field(
|
||||||
|
default="",
|
||||||
|
description="Gitea API base URL (e.g., https://gitea.example.com)",
|
||||||
|
)
|
||||||
|
gitea_token: str = Field(
|
||||||
|
default="",
|
||||||
|
description="Gitea API token",
|
||||||
|
)
|
||||||
|
github_token: str = Field(
|
||||||
|
default="",
|
||||||
|
description="GitHub API token",
|
||||||
|
)
|
||||||
|
github_api_url: str = Field(
|
||||||
|
default="https://api.github.com",
|
||||||
|
description="GitHub API URL (for Enterprise)",
|
||||||
|
)
|
||||||
|
gitlab_token: str = Field(
|
||||||
|
default="",
|
||||||
|
description="GitLab API token",
|
||||||
|
)
|
||||||
|
gitlab_url: str = Field(
|
||||||
|
default="https://gitlab.com",
|
||||||
|
description="GitLab URL (for self-hosted)",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Rate limiting
|
||||||
|
rate_limit_requests: int = Field(
|
||||||
|
default=100,
|
||||||
|
description="Max API requests per minute per provider",
|
||||||
|
)
|
||||||
|
rate_limit_window: int = Field(
|
||||||
|
default=60,
|
||||||
|
description="Rate limit window in seconds",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Retry settings
|
||||||
|
retry_attempts: int = Field(
|
||||||
|
default=3,
|
||||||
|
description="Number of retry attempts for failed operations",
|
||||||
|
)
|
||||||
|
retry_delay: float = Field(
|
||||||
|
default=1.0,
|
||||||
|
description="Initial retry delay in seconds",
|
||||||
|
)
|
||||||
|
retry_max_delay: float = Field(
|
||||||
|
default=30.0,
|
||||||
|
description="Maximum retry delay in seconds",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Security settings
|
||||||
|
allowed_hosts: list[str] = Field(
|
||||||
|
default_factory=list,
|
||||||
|
description="Allowed git host domains (empty = all)",
|
||||||
|
)
|
||||||
|
max_clone_size_mb: int = Field(
|
||||||
|
default=500,
|
||||||
|
description="Maximum repository size for clone in MB",
|
||||||
|
)
|
||||||
|
enable_force_push: bool = Field(
|
||||||
|
default=False,
|
||||||
|
description="Allow force push operations",
|
||||||
|
)
|
||||||
|
|
||||||
|
model_config = {"env_prefix": "GIT_OPS_", "env_file": ".env", "extra": "ignore"}
|
||||||
|
|
||||||
|
|
||||||
|
# Global settings instance (lazy initialization)
|
||||||
|
_settings: Settings | None = None
|
||||||
|
|
||||||
|
|
||||||
|
def get_settings() -> Settings:
|
||||||
|
"""Get the global settings instance."""
|
||||||
|
global _settings
|
||||||
|
if _settings is None:
|
||||||
|
_settings = Settings()
|
||||||
|
return _settings
|
||||||
|
|
||||||
|
|
||||||
|
def reset_settings() -> None:
|
||||||
|
"""Reset the global settings (for testing)."""
|
||||||
|
global _settings
|
||||||
|
_settings = None
|
||||||
|
|
||||||
|
|
||||||
|
def is_test_mode() -> bool:
|
||||||
|
"""Check if running in test mode."""
|
||||||
|
return os.getenv("IS_TEST", "").lower() in ("true", "1", "yes")
|
||||||
359
mcp-servers/git-ops/exceptions.py
Normal file
359
mcp-servers/git-ops/exceptions.py
Normal file
@@ -0,0 +1,359 @@
|
|||||||
|
"""
|
||||||
|
Exception hierarchy for Git Operations MCP Server.
|
||||||
|
|
||||||
|
Provides structured error handling with error codes for MCP responses.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from enum import Enum
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
|
||||||
|
class ErrorCode(str, Enum):
|
||||||
|
"""Error codes for Git Operations errors."""
|
||||||
|
|
||||||
|
# General errors (1xxx)
|
||||||
|
INTERNAL_ERROR = "GIT_1000"
|
||||||
|
INVALID_REQUEST = "GIT_1001"
|
||||||
|
NOT_FOUND = "GIT_1002"
|
||||||
|
PERMISSION_DENIED = "GIT_1003"
|
||||||
|
TIMEOUT = "GIT_1004"
|
||||||
|
RATE_LIMITED = "GIT_1005"
|
||||||
|
|
||||||
|
# Workspace errors (2xxx)
|
||||||
|
WORKSPACE_NOT_FOUND = "GIT_2000"
|
||||||
|
WORKSPACE_LOCKED = "GIT_2001"
|
||||||
|
WORKSPACE_SIZE_EXCEEDED = "GIT_2002"
|
||||||
|
WORKSPACE_CREATE_FAILED = "GIT_2003"
|
||||||
|
WORKSPACE_DELETE_FAILED = "GIT_2004"
|
||||||
|
|
||||||
|
# Git operation errors (3xxx)
|
||||||
|
CLONE_FAILED = "GIT_3000"
|
||||||
|
CHECKOUT_FAILED = "GIT_3001"
|
||||||
|
COMMIT_FAILED = "GIT_3002"
|
||||||
|
PUSH_FAILED = "GIT_3003"
|
||||||
|
PULL_FAILED = "GIT_3004"
|
||||||
|
MERGE_CONFLICT = "GIT_3005"
|
||||||
|
BRANCH_EXISTS = "GIT_3006"
|
||||||
|
BRANCH_NOT_FOUND = "GIT_3007"
|
||||||
|
INVALID_REF = "GIT_3008"
|
||||||
|
DIRTY_WORKSPACE = "GIT_3009"
|
||||||
|
UNCOMMITTED_CHANGES = "GIT_3010"
|
||||||
|
FETCH_FAILED = "GIT_3011"
|
||||||
|
RESET_FAILED = "GIT_3012"
|
||||||
|
|
||||||
|
# Provider errors (4xxx)
|
||||||
|
PROVIDER_ERROR = "GIT_4000"
|
||||||
|
PROVIDER_AUTH_FAILED = "GIT_4001"
|
||||||
|
PROVIDER_NOT_FOUND = "GIT_4002"
|
||||||
|
PR_CREATE_FAILED = "GIT_4003"
|
||||||
|
PR_MERGE_FAILED = "GIT_4004"
|
||||||
|
PR_NOT_FOUND = "GIT_4005"
|
||||||
|
API_ERROR = "GIT_4006"
|
||||||
|
|
||||||
|
# Credential errors (5xxx)
|
||||||
|
CREDENTIAL_ERROR = "GIT_5000"
|
||||||
|
CREDENTIAL_NOT_FOUND = "GIT_5001"
|
||||||
|
CREDENTIAL_INVALID = "GIT_5002"
|
||||||
|
SSH_KEY_ERROR = "GIT_5003"
|
||||||
|
|
||||||
|
|
||||||
|
class GitOpsError(Exception):
|
||||||
|
"""Base exception for Git Operations errors."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
message: str,
|
||||||
|
code: ErrorCode = ErrorCode.INTERNAL_ERROR,
|
||||||
|
details: dict[str, Any] | None = None,
|
||||||
|
) -> None:
|
||||||
|
super().__init__(message)
|
||||||
|
self.message = message
|
||||||
|
self.code = code
|
||||||
|
self.details = details or {}
|
||||||
|
|
||||||
|
def to_dict(self) -> dict[str, Any]:
|
||||||
|
"""Convert to dictionary for MCP response."""
|
||||||
|
result: dict[str, Any] = {
|
||||||
|
"error": self.message,
|
||||||
|
"code": self.code.value,
|
||||||
|
}
|
||||||
|
if self.details:
|
||||||
|
result["details"] = self.details
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
# Workspace Errors
|
||||||
|
|
||||||
|
|
||||||
|
class WorkspaceError(GitOpsError):
|
||||||
|
"""Base exception for workspace-related errors."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
message: str,
|
||||||
|
code: ErrorCode = ErrorCode.WORKSPACE_NOT_FOUND,
|
||||||
|
details: dict[str, Any] | None = None,
|
||||||
|
) -> None:
|
||||||
|
super().__init__(message, code, details)
|
||||||
|
|
||||||
|
|
||||||
|
class WorkspaceNotFoundError(WorkspaceError):
|
||||||
|
"""Workspace does not exist."""
|
||||||
|
|
||||||
|
def __init__(self, project_id: str) -> None:
|
||||||
|
super().__init__(
|
||||||
|
f"Workspace not found for project: {project_id}",
|
||||||
|
ErrorCode.WORKSPACE_NOT_FOUND,
|
||||||
|
{"project_id": project_id},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class WorkspaceLockedError(WorkspaceError):
|
||||||
|
"""Workspace is locked by another operation."""
|
||||||
|
|
||||||
|
def __init__(self, project_id: str, holder: str | None = None) -> None:
|
||||||
|
details: dict[str, Any] = {"project_id": project_id}
|
||||||
|
if holder:
|
||||||
|
details["locked_by"] = holder
|
||||||
|
super().__init__(
|
||||||
|
f"Workspace is locked for project: {project_id}",
|
||||||
|
ErrorCode.WORKSPACE_LOCKED,
|
||||||
|
details,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class WorkspaceSizeExceededError(WorkspaceError):
|
||||||
|
"""Workspace size limit exceeded."""
|
||||||
|
|
||||||
|
def __init__(self, project_id: str, current_size: float, max_size: float) -> None:
|
||||||
|
super().__init__(
|
||||||
|
f"Workspace size limit exceeded for project: {project_id}",
|
||||||
|
ErrorCode.WORKSPACE_SIZE_EXCEEDED,
|
||||||
|
{
|
||||||
|
"project_id": project_id,
|
||||||
|
"current_size_gb": current_size,
|
||||||
|
"max_size_gb": max_size,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# Git Operation Errors
|
||||||
|
|
||||||
|
|
||||||
|
class GitError(GitOpsError):
|
||||||
|
"""Base exception for git operation errors."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
message: str,
|
||||||
|
code: ErrorCode = ErrorCode.INTERNAL_ERROR,
|
||||||
|
details: dict[str, Any] | None = None,
|
||||||
|
) -> None:
|
||||||
|
super().__init__(message, code, details)
|
||||||
|
|
||||||
|
|
||||||
|
class CloneError(GitError):
|
||||||
|
"""Failed to clone repository."""
|
||||||
|
|
||||||
|
def __init__(self, repo_url: str, reason: str) -> None:
|
||||||
|
super().__init__(
|
||||||
|
f"Failed to clone repository: {reason}",
|
||||||
|
ErrorCode.CLONE_FAILED,
|
||||||
|
{"repo_url": repo_url, "reason": reason},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class CheckoutError(GitError):
|
||||||
|
"""Failed to checkout branch or ref."""
|
||||||
|
|
||||||
|
def __init__(self, ref: str, reason: str) -> None:
|
||||||
|
super().__init__(
|
||||||
|
f"Failed to checkout '{ref}': {reason}",
|
||||||
|
ErrorCode.CHECKOUT_FAILED,
|
||||||
|
{"ref": ref, "reason": reason},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class CommitError(GitError):
|
||||||
|
"""Failed to commit changes."""
|
||||||
|
|
||||||
|
def __init__(self, reason: str) -> None:
|
||||||
|
super().__init__(
|
||||||
|
f"Failed to commit: {reason}",
|
||||||
|
ErrorCode.COMMIT_FAILED,
|
||||||
|
{"reason": reason},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class PushError(GitError):
|
||||||
|
"""Failed to push to remote."""
|
||||||
|
|
||||||
|
def __init__(self, branch: str, reason: str) -> None:
|
||||||
|
super().__init__(
|
||||||
|
f"Failed to push branch '{branch}': {reason}",
|
||||||
|
ErrorCode.PUSH_FAILED,
|
||||||
|
{"branch": branch, "reason": reason},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class PullError(GitError):
|
||||||
|
"""Failed to pull from remote."""
|
||||||
|
|
||||||
|
def __init__(self, branch: str, reason: str) -> None:
|
||||||
|
super().__init__(
|
||||||
|
f"Failed to pull branch '{branch}': {reason}",
|
||||||
|
ErrorCode.PULL_FAILED,
|
||||||
|
{"branch": branch, "reason": reason},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class MergeConflictError(GitError):
|
||||||
|
"""Merge conflict detected."""
|
||||||
|
|
||||||
|
def __init__(self, conflicting_files: list[str]) -> None:
|
||||||
|
super().__init__(
|
||||||
|
f"Merge conflict detected in {len(conflicting_files)} files",
|
||||||
|
ErrorCode.MERGE_CONFLICT,
|
||||||
|
{"conflicting_files": conflicting_files},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class BranchExistsError(GitError):
    """Raised when creating a branch whose name is already taken."""

    def __init__(self, branch_name: str) -> None:
        message = f"Branch already exists: {branch_name}"
        super().__init__(message, ErrorCode.BRANCH_EXISTS, {"branch": branch_name})
|
||||||
|
|
||||||
|
|
||||||
|
class BranchNotFoundError(GitError):
    """Raised when a named branch does not exist."""

    def __init__(self, branch_name: str) -> None:
        message = f"Branch not found: {branch_name}"
        super().__init__(message, ErrorCode.BRANCH_NOT_FOUND, {"branch": branch_name})
|
||||||
|
|
||||||
|
|
||||||
|
class InvalidRefError(GitError):
    """Raised when a git reference cannot be resolved."""

    def __init__(self, ref: str) -> None:
        message = f"Invalid git reference: {ref}"
        super().__init__(message, ErrorCode.INVALID_REF, {"ref": ref})
|
||||||
|
|
||||||
|
|
||||||
|
class DirtyWorkspaceError(GitError):
    """Raised when an operation requires a clean working tree."""

    def __init__(self, modified_files: list[str]) -> None:
        # The message reports the full count, but the detail payload is capped
        # at the first ten paths to keep serialized errors small.
        details = {"modified_files": modified_files[:10]}
        message = f"Workspace has {len(modified_files)} uncommitted changes"
        super().__init__(message, ErrorCode.DIRTY_WORKSPACE, details)
|
||||||
|
|
||||||
|
|
||||||
|
# Provider Errors
|
||||||
|
|
||||||
|
|
||||||
|
class ProviderError(GitOpsError):
    """Base exception for provider-related errors."""

    # Thin pass-through constructor: defaults the error code to
    # PROVIDER_ERROR so subclasses and ad-hoc raises share one catchable base.
    def __init__(
        self,
        message: str,
        code: ErrorCode = ErrorCode.PROVIDER_ERROR,
        details: dict[str, Any] | None = None,
    ) -> None:
        super().__init__(message, code, details)
|
||||||
|
|
||||||
|
|
||||||
|
class AuthenticationError(ProviderError):
    """Raised when authenticating against a git provider fails."""

    def __init__(self, provider: str, reason: str) -> None:
        details = {"provider": provider, "reason": reason}
        message = f"Authentication failed with {provider}: {reason}"
        super().__init__(message, ErrorCode.PROVIDER_AUTH_FAILED, details)
|
||||||
|
|
||||||
|
|
||||||
|
class ProviderNotFoundError(ProviderError):
    """Raised when a provider is unknown or not configured."""

    def __init__(self, provider: str) -> None:
        message = f"Provider not found or not configured: {provider}"
        super().__init__(message, ErrorCode.PROVIDER_NOT_FOUND, {"provider": provider})
|
||||||
|
|
||||||
|
|
||||||
|
class PRError(ProviderError):
    """Pull request operation failed."""

    # Thin pass-through constructor: defaults the error code to
    # PR_CREATE_FAILED; subclasses override it with a more specific code.
    def __init__(
        self,
        message: str,
        code: ErrorCode = ErrorCode.PR_CREATE_FAILED,
        details: dict[str, Any] | None = None,
    ) -> None:
        super().__init__(message, code, details)
|
||||||
|
|
||||||
|
|
||||||
|
class PRNotFoundError(PRError):
    """Raised when a pull request number does not exist in a repository."""

    def __init__(self, pr_number: int, repo: str) -> None:
        details = {"pr_number": pr_number, "repo": repo}
        message = f"Pull request #{pr_number} not found in {repo}"
        super().__init__(message, ErrorCode.PR_NOT_FOUND, details)
|
||||||
|
|
||||||
|
|
||||||
|
class APIError(ProviderError):
    """Raised when a provider API call returns an error response."""

    def __init__(self, provider: str, status_code: int, message: str) -> None:
        # Keep the raw status code and message available for programmatic handling.
        details = {"provider": provider, "status_code": status_code, "message": message}
        text = f"{provider} API error ({status_code}): {message}"
        super().__init__(text, ErrorCode.API_ERROR, details)
|
||||||
|
|
||||||
|
|
||||||
|
# Credential Errors
|
||||||
|
|
||||||
|
|
||||||
|
class CredentialError(GitOpsError):
    """Base exception for credential-related errors."""

    # Thin pass-through constructor: defaults the error code to
    # CREDENTIAL_ERROR so all credential failures share one catchable base.
    def __init__(
        self,
        message: str,
        code: ErrorCode = ErrorCode.CREDENTIAL_ERROR,
        details: dict[str, Any] | None = None,
    ) -> None:
        super().__init__(message, code, details)
|
||||||
|
|
||||||
|
|
||||||
|
class CredentialNotFoundError(CredentialError):
    """Raised when a required credential cannot be located."""

    def __init__(self, credential_type: str, identifier: str) -> None:
        message = f"{credential_type} credential not found: {identifier}"
        details = {"type": credential_type, "identifier": identifier}
        super().__init__(message, ErrorCode.CREDENTIAL_NOT_FOUND, details)
|
||||||
1170
mcp-servers/git-ops/git_wrapper.py
Normal file
1170
mcp-servers/git-ops/git_wrapper.py
Normal file
File diff suppressed because it is too large
Load Diff
690
mcp-servers/git-ops/models.py
Normal file
690
mcp-servers/git-ops/models.py
Normal file
@@ -0,0 +1,690 @@
|
|||||||
|
"""
|
||||||
|
Data models for Git Operations MCP Server.
|
||||||
|
|
||||||
|
Defines data structures for git operations, workspace management,
|
||||||
|
and provider interactions.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from datetime import UTC, datetime
|
||||||
|
from enum import Enum
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
|
||||||
|
class FileChangeType(str, Enum):
    """Types of file changes in git."""

    # Members subclass str, so they compare equal to their plain-string
    # values and serialize directly.
    ADDED = "added"
    MODIFIED = "modified"
    DELETED = "deleted"
    RENAMED = "renamed"
    COPIED = "copied"
    UNTRACKED = "untracked"
    IGNORED = "ignored"
|
||||||
|
|
||||||
|
|
||||||
|
class MergeStrategy(str, Enum):
    """Merge strategies for pull requests."""

    # str-valued members; compare equal to their plain-string values.
    MERGE = "merge"  # Create a merge commit
    SQUASH = "squash"  # Squash and merge
    REBASE = "rebase"  # Rebase and merge
|
||||||
|
|
||||||
|
|
||||||
|
class PRState(str, Enum):
    """Pull request states."""

    # str-valued members; compare equal to their plain-string values.
    OPEN = "open"
    CLOSED = "closed"
    MERGED = "merged"
|
||||||
|
|
||||||
|
|
||||||
|
class ProviderType(str, Enum):
    """Supported git providers."""

    # str-valued members; compare equal to their plain-string values.
    GITEA = "gitea"
    GITHUB = "github"
    GITLAB = "gitlab"
|
||||||
|
|
||||||
|
|
||||||
|
class WorkspaceState(str, Enum):
    """Workspace lifecycle states."""

    # str-valued members; compare equal to their plain-string values.
    INITIALIZING = "initializing"
    READY = "ready"
    LOCKED = "locked"
    STALE = "stale"
    DELETED = "deleted"
|
||||||
|
|
||||||
|
|
||||||
|
# Dataclasses for internal data structures
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class FileChange:
    """A single file change reported by git status."""

    path: str
    change_type: FileChangeType
    old_path: str | None = None  # For renames
    additions: int = 0
    deletions: int = 0

    def to_dict(self) -> dict[str, Any]:
        """Serialize to a plain dictionary; the enum flattens to its value."""
        return dict(
            path=self.path,
            change_type=self.change_type.value,
            old_path=self.old_path,
            additions=self.additions,
            deletions=self.deletions,
        )
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class BranchInfo:
|
||||||
|
"""Information about a git branch."""
|
||||||
|
|
||||||
|
name: str
|
||||||
|
is_current: bool = False
|
||||||
|
is_remote: bool = False
|
||||||
|
tracking_branch: str | None = None
|
||||||
|
commit_sha: str | None = None
|
||||||
|
commit_message: str | None = None
|
||||||
|
ahead: int = 0
|
||||||
|
behind: int = 0
|
||||||
|
|
||||||
|
def to_dict(self) -> dict[str, Any]:
|
||||||
|
"""Convert to dictionary."""
|
||||||
|
return {
|
||||||
|
"name": self.name,
|
||||||
|
"is_current": self.is_current,
|
||||||
|
"is_remote": self.is_remote,
|
||||||
|
"tracking_branch": self.tracking_branch,
|
||||||
|
"commit_sha": self.commit_sha,
|
||||||
|
"commit_message": self.commit_message,
|
||||||
|
"ahead": self.ahead,
|
||||||
|
"behind": self.behind,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class CommitInfo:
    """Information about a git commit."""

    sha: str
    short_sha: str
    message: str
    author_name: str
    author_email: str
    authored_date: datetime
    committer_name: str
    committer_email: str
    committed_date: datetime
    parents: list[str] = field(default_factory=list)

    def to_dict(self) -> dict[str, Any]:
        """Serialize to a plain dictionary; dates become ISO-8601 strings."""
        return dict(
            sha=self.sha,
            short_sha=self.short_sha,
            message=self.message,
            author_name=self.author_name,
            author_email=self.author_email,
            authored_date=self.authored_date.isoformat(),
            committer_name=self.committer_name,
            committer_email=self.committer_email,
            committed_date=self.committed_date.isoformat(),
            parents=self.parents,
        )
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class DiffHunk:
    """One contiguous hunk within a file diff."""

    old_start: int
    old_lines: int
    new_start: int
    new_lines: int
    content: str

    def to_dict(self) -> dict[str, Any]:
        """Serialize the hunk to a plain dictionary."""
        return dict(
            old_start=self.old_start,
            old_lines=self.old_lines,
            new_start=self.new_start,
            new_lines=self.new_lines,
            content=self.content,
        )
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class FileDiff:
    """Diff for a single file."""

    path: str
    change_type: FileChangeType
    old_path: str | None = None
    hunks: list[DiffHunk] = field(default_factory=list)
    additions: int = 0
    deletions: int = 0
    is_binary: bool = False

    def to_dict(self) -> dict[str, Any]:
        """Serialize to a plain dictionary; hunks are serialized recursively."""
        return dict(
            path=self.path,
            change_type=self.change_type.value,
            old_path=self.old_path,
            hunks=[hunk.to_dict() for hunk in self.hunks],
            additions=self.additions,
            deletions=self.deletions,
            is_binary=self.is_binary,
        )
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class PRInfo:
    """Information about a pull request."""

    number: int
    title: str
    body: str
    state: PRState
    source_branch: str
    target_branch: str
    author: str
    created_at: datetime
    updated_at: datetime
    merged_at: datetime | None = None
    closed_at: datetime | None = None
    url: str | None = None
    labels: list[str] = field(default_factory=list)
    assignees: list[str] = field(default_factory=list)
    reviewers: list[str] = field(default_factory=list)
    mergeable: bool | None = None
    draft: bool = False

    def to_dict(self) -> dict[str, Any]:
        """Serialize to a plain dictionary; datetimes become ISO strings."""

        def iso(ts: datetime | None) -> str | None:
            # Optional timestamps map to None instead of raising.
            return ts.isoformat() if ts else None

        return dict(
            number=self.number,
            title=self.title,
            body=self.body,
            state=self.state.value,
            source_branch=self.source_branch,
            target_branch=self.target_branch,
            author=self.author,
            created_at=self.created_at.isoformat(),
            updated_at=self.updated_at.isoformat(),
            merged_at=iso(self.merged_at),
            closed_at=iso(self.closed_at),
            url=self.url,
            labels=self.labels,
            assignees=self.assignees,
            reviewers=self.reviewers,
            mergeable=self.mergeable,
            draft=self.draft,
        )
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class WorkspaceInfo:
    """Information about a project workspace."""

    project_id: str
    path: str
    state: WorkspaceState
    repo_url: str | None = None
    current_branch: str | None = None
    last_accessed: datetime = field(default_factory=lambda: datetime.now(UTC))
    size_bytes: int = 0
    lock_holder: str | None = None
    lock_expires: datetime | None = None

    def to_dict(self) -> dict[str, Any]:
        """Serialize to a plain dictionary; datetimes become ISO strings."""
        # Hoist the optional expiry conversion out of the literal for clarity.
        expires = self.lock_expires.isoformat() if self.lock_expires else None
        return dict(
            project_id=self.project_id,
            path=self.path,
            state=self.state.value,
            repo_url=self.repo_url,
            current_branch=self.current_branch,
            last_accessed=self.last_accessed.isoformat(),
            size_bytes=self.size_bytes,
            lock_holder=self.lock_holder,
            lock_expires=expires,
        )
|
||||||
|
|
||||||
|
|
||||||
|
# Pydantic Request/Response Models
|
||||||
|
|
||||||
|
|
||||||
|
class CloneRequest(BaseModel):
    """Request to clone a repository."""

    # Field(...) marks a field as required in pydantic.
    project_id: str = Field(..., description="Project ID for scoping")
    agent_id: str = Field(..., description="Agent ID making the request")
    repo_url: str = Field(..., description="Repository URL to clone")
    branch: str | None = Field(
        default=None, description="Branch to checkout after clone"
    )
    # ge=1 rejects non-positive depths at validation time.
    depth: int | None = Field(
        default=None, ge=1, description="Shallow clone depth (None = full clone)"
    )
|
||||||
|
|
||||||
|
|
||||||
|
class CloneResult(BaseModel):
    """Result of a clone operation."""

    # error is only populated when success is False.
    success: bool = Field(..., description="Whether clone succeeded")
    project_id: str = Field(..., description="Project ID")
    workspace_path: str = Field(..., description="Path to cloned workspace")
    branch: str = Field(..., description="Current branch after clone")
    commit_sha: str = Field(..., description="HEAD commit SHA")
    error: str | None = Field(default=None, description="Error message if failed")
|
||||||
|
|
||||||
|
|
||||||
|
class StatusRequest(BaseModel):
    """Request for git status."""

    # All status requests are scoped by project and attributed to an agent.
    project_id: str = Field(..., description="Project ID for scoping")
    agent_id: str = Field(..., description="Agent ID making the request")
    include_untracked: bool = Field(default=True, description="Include untracked files")
|
||||||
|
|
||||||
|
|
||||||
|
class StatusResult(BaseModel):
    """Result of a status operation."""

    project_id: str = Field(..., description="Project ID")
    branch: str = Field(..., description="Current branch")
    commit_sha: str = Field(..., description="HEAD commit SHA")
    is_clean: bool = Field(..., description="Whether working tree is clean")
    # staged/unstaged entries are serialized FileChange dicts — TODO confirm
    # against the producer in git_wrapper.py.
    staged: list[dict[str, Any]] = Field(
        default_factory=list, description="Staged changes"
    )
    unstaged: list[dict[str, Any]] = Field(
        default_factory=list, description="Unstaged changes"
    )
    untracked: list[str] = Field(default_factory=list, description="Untracked files")
    ahead: int = Field(default=0, description="Commits ahead of upstream")
    behind: int = Field(default=0, description="Commits behind upstream")
|
||||||
|
|
||||||
|
|
||||||
|
class BranchRequest(BaseModel):
    """Request for branch operations."""

    project_id: str = Field(..., description="Project ID for scoping")
    agent_id: str = Field(..., description="Agent ID making the request")
    branch_name: str = Field(..., description="Branch name")
    # None lets the server pick its default start point for the new branch.
    from_ref: str | None = Field(
        default=None, description="Reference to create branch from"
    )
    checkout: bool = Field(default=True, description="Checkout after creation")
|
||||||
|
|
||||||
|
|
||||||
|
class BranchResult(BaseModel):
    """Result of a branch operation."""

    # error is only populated when success is False.
    success: bool = Field(..., description="Whether operation succeeded")
    branch: str = Field(..., description="Branch name")
    commit_sha: str | None = Field(default=None, description="HEAD commit SHA")
    is_current: bool = Field(default=False, description="Whether branch is checked out")
    error: str | None = Field(default=None, description="Error message if failed")
|
||||||
|
|
||||||
|
|
||||||
|
class ListBranchesRequest(BaseModel):
    """Request to list branches."""

    project_id: str = Field(..., description="Project ID for scoping")
    agent_id: str = Field(..., description="Agent ID making the request")
    # Remote branches are excluded by default to keep responses small.
    include_remote: bool = Field(default=False, description="Include remote branches")
|
||||||
|
|
||||||
|
|
||||||
|
class ListBranchesResult(BaseModel):
    """Result of listing branches."""

    project_id: str = Field(..., description="Project ID")
    current_branch: str = Field(..., description="Currently checked out branch")
    # Entries are serialized BranchInfo dicts — presumably; verify against producer.
    local_branches: list[dict[str, Any]] = Field(
        default_factory=list, description="Local branches"
    )
    remote_branches: list[dict[str, Any]] = Field(
        default_factory=list, description="Remote branches"
    )
|
||||||
|
|
||||||
|
|
||||||
|
class CheckoutRequest(BaseModel):
    """Request to checkout a branch or ref."""

    project_id: str = Field(..., description="Project ID for scoping")
    agent_id: str = Field(..., description="Agent ID making the request")
    ref: str = Field(..., description="Branch, tag, or commit to checkout")
    create_branch: bool = Field(default=False, description="Create new branch")
    # force discards local modifications; defaults to the safe choice.
    force: bool = Field(default=False, description="Force checkout (discard changes)")
|
||||||
|
|
||||||
|
|
||||||
|
class CheckoutResult(BaseModel):
    """Result of a checkout operation."""

    # error is only populated when success is False.
    success: bool = Field(..., description="Whether checkout succeeded")
    ref: str = Field(..., description="Checked out reference")
    commit_sha: str | None = Field(default=None, description="HEAD commit SHA")
    error: str | None = Field(default=None, description="Error message if failed")
|
||||||
|
|
||||||
|
|
||||||
|
class CommitRequest(BaseModel):
    """Request to create a commit."""

    project_id: str = Field(..., description="Project ID for scoping")
    agent_id: str = Field(..., description="Agent ID making the request")
    message: str = Field(..., description="Commit message")
    # None commits whatever is already staged rather than a specific file set.
    files: list[str] | None = Field(
        default=None, description="Files to commit (None = all staged)"
    )
    author_name: str | None = Field(default=None, description="Author name override")
    author_email: str | None = Field(default=None, description="Author email override")
    allow_empty: bool = Field(default=False, description="Allow empty commit")
|
||||||
|
|
||||||
|
|
||||||
|
class CommitResult(BaseModel):
    """Result of a commit operation."""

    # On failure only success/error are meaningful; stat fields stay at defaults.
    success: bool = Field(..., description="Whether commit succeeded")
    commit_sha: str | None = Field(default=None, description="New commit SHA")
    short_sha: str | None = Field(default=None, description="Short commit SHA")
    message: str | None = Field(default=None, description="Commit message")
    files_changed: int = Field(default=0, description="Number of files changed")
    insertions: int = Field(default=0, description="Lines added")
    deletions: int = Field(default=0, description="Lines removed")
    error: str | None = Field(default=None, description="Error message if failed")
|
||||||
|
|
||||||
|
|
||||||
|
class PushRequest(BaseModel):
    """Request to push to remote."""

    project_id: str = Field(..., description="Project ID for scoping")
    agent_id: str = Field(..., description="Agent ID making the request")
    # None pushes the currently checked-out branch.
    branch: str | None = Field(
        default=None, description="Branch to push (None = current)"
    )
    remote: str = Field(default="origin", description="Remote name")
    force: bool = Field(default=False, description="Force push")
    set_upstream: bool = Field(default=True, description="Set upstream tracking")
|
||||||
|
|
||||||
|
|
||||||
|
class PushResult(BaseModel):
    """Result of a push operation."""

    # error is only populated when success is False.
    success: bool = Field(..., description="Whether push succeeded")
    branch: str = Field(..., description="Pushed branch")
    remote: str = Field(..., description="Remote name")
    commits_pushed: int = Field(default=0, description="Number of commits pushed")
    error: str | None = Field(default=None, description="Error message if failed")
|
||||||
|
|
||||||
|
|
||||||
|
class PullRequest(BaseModel):
    """Request to pull from remote."""

    # NOTE: this models the *pull operation* request, not a provider PR
    # (see CreatePRRequest and friends for those).
    project_id: str = Field(..., description="Project ID for scoping")
    agent_id: str = Field(..., description="Agent ID making the request")
    # None pulls into the currently checked-out branch.
    branch: str | None = Field(
        default=None, description="Branch to pull (None = current)"
    )
    remote: str = Field(default="origin", description="Remote name")
    rebase: bool = Field(default=False, description="Rebase instead of merge")
|
||||||
|
|
||||||
|
|
||||||
|
class PullResult(BaseModel):
    """Result of a pull operation."""

    success: bool = Field(..., description="Whether pull succeeded")
    branch: str = Field(..., description="Pulled branch")
    commits_received: int = Field(default=0, description="New commits received")
    fast_forward: bool = Field(default=False, description="Was fast-forward")
    # Populated when the pull hit merge conflicts.
    conflicts: list[str] = Field(
        default_factory=list, description="Conflicting files if any"
    )
    error: str | None = Field(default=None, description="Error message if failed")
|
||||||
|
|
||||||
|
|
||||||
|
class DiffRequest(BaseModel):
    """Request for diff."""

    project_id: str = Field(..., description="Project ID for scoping")
    agent_id: str = Field(..., description="Agent ID making the request")
    # base=None diffs against the working tree; head=None means HEAD.
    base: str | None = Field(
        default=None, description="Base reference (None = working tree)"
    )
    head: str | None = Field(default=None, description="Head reference (None = HEAD)")
    files: list[str] | None = Field(default=None, description="Specific files to diff")
    context_lines: int = Field(default=3, ge=0, description="Context lines")
|
||||||
|
|
||||||
|
|
||||||
|
class DiffResult(BaseModel):
    """Result of a diff operation."""

    project_id: str = Field(..., description="Project ID")
    base: str | None = Field(default=None, description="Base reference")
    head: str | None = Field(default=None, description="Head reference")
    # Entries are serialized FileDiff dicts — presumably; verify against producer.
    files: list[dict[str, Any]] = Field(default_factory=list, description="File diffs")
    total_additions: int = Field(default=0, description="Total lines added")
    total_deletions: int = Field(default=0, description="Total lines removed")
    files_changed: int = Field(default=0, description="Number of files changed")
|
||||||
|
|
||||||
|
|
||||||
|
class LogRequest(BaseModel):
    """Request for commit log."""

    project_id: str = Field(..., description="Project ID for scoping")
    agent_id: str = Field(..., description="Agent ID making the request")
    ref: str | None = Field(default=None, description="Reference to start from")
    # limit/skip implement simple offset pagination; limit is capped at 100.
    limit: int = Field(default=20, ge=1, le=100, description="Max commits to return")
    skip: int = Field(default=0, ge=0, description="Commits to skip")
    path: str | None = Field(default=None, description="Filter by path")
|
||||||
|
|
||||||
|
|
||||||
|
class LogResult(BaseModel):
    """Result of a log operation."""

    project_id: str = Field(..., description="Project ID")
    # Entries are serialized CommitInfo dicts — presumably; verify against producer.
    commits: list[dict[str, Any]] = Field(
        default_factory=list, description="Commit history"
    )
    total_commits: int = Field(default=0, description="Total commits in range")
|
||||||
|
|
||||||
|
|
||||||
|
# PR Operations
|
||||||
|
|
||||||
|
|
||||||
|
class CreatePRRequest(BaseModel):
    """Request to create a pull request."""

    project_id: str = Field(..., description="Project ID for scoping")
    agent_id: str = Field(..., description="Agent ID making the request")
    title: str = Field(..., description="PR title")
    body: str = Field(default="", description="PR description")
    source_branch: str = Field(..., description="Source branch")
    target_branch: str = Field(default="main", description="Target branch")
    draft: bool = Field(default=False, description="Create as draft")
    labels: list[str] = Field(default_factory=list, description="Labels to add")
    assignees: list[str] = Field(default_factory=list, description="Assignees")
    reviewers: list[str] = Field(default_factory=list, description="Reviewers")
|
||||||
|
|
||||||
|
|
||||||
|
class CreatePRResult(BaseModel):
    """Result of creating a pull request."""

    # error is only populated when success is False.
    success: bool = Field(..., description="Whether creation succeeded")
    pr_number: int | None = Field(default=None, description="PR number")
    pr_url: str | None = Field(default=None, description="PR URL")
    error: str | None = Field(default=None, description="Error message if failed")
|
||||||
|
|
||||||
|
|
||||||
|
class GetPRRequest(BaseModel):
    """Request to get a pull request."""

    project_id: str = Field(..., description="Project ID for scoping")
    agent_id: str = Field(..., description="Agent ID making the request")
    pr_number: int = Field(..., description="PR number")
|
||||||
|
|
||||||
|
|
||||||
|
class GetPRResult(BaseModel):
    """Result of getting a pull request."""

    success: bool = Field(..., description="Whether fetch succeeded")
    # A serialized PRInfo dict — presumably; verify against producer.
    pr: dict[str, Any] | None = Field(default=None, description="PR info")
    error: str | None = Field(default=None, description="Error message if failed")
|
||||||
|
|
||||||
|
|
||||||
|
class ListPRsRequest(BaseModel):
    """Request to list pull requests."""

    project_id: str = Field(..., description="Project ID for scoping")
    agent_id: str = Field(..., description="Agent ID making the request")
    # None filters mean "no filter" rather than matching None values.
    state: PRState | None = Field(default=None, description="Filter by state")
    author: str | None = Field(default=None, description="Filter by author")
    limit: int = Field(default=20, ge=1, le=100, description="Max PRs to return")
|
||||||
|
|
||||||
|
|
||||||
|
class ListPRsResult(BaseModel):
    """Result of listing pull requests."""

    success: bool = Field(..., description="Whether list succeeded")
    pull_requests: list[dict[str, Any]] = Field(default_factory=list, description="PRs")
    # total_count may exceed len(pull_requests) when results are limited.
    total_count: int = Field(default=0, description="Total matching PRs")
    error: str | None = Field(default=None, description="Error message if failed")
|
||||||
|
|
||||||
|
|
||||||
|
class MergePRRequest(BaseModel):
    """Request to merge a pull request."""

    project_id: str = Field(..., description="Project ID for scoping")
    agent_id: str = Field(..., description="Agent ID making the request")
    pr_number: int = Field(..., description="PR number")
    merge_strategy: MergeStrategy = Field(
        default=MergeStrategy.MERGE, description="Merge strategy"
    )
    # None lets the provider generate its default merge commit message.
    commit_message: str | None = Field(
        default=None, description="Custom merge commit message"
    )
    delete_branch: bool = Field(
        default=True, description="Delete source branch after merge"
    )
|
||||||
|
|
||||||
|
|
||||||
|
class MergePRResult(BaseModel):
    """Result of merging a pull request."""

    # error is only populated when success is False.
    success: bool = Field(..., description="Whether merge succeeded")
    merge_commit_sha: str | None = Field(default=None, description="Merge commit SHA")
    branch_deleted: bool = Field(
        default=False, description="Whether branch was deleted"
    )
    error: str | None = Field(default=None, description="Error message if failed")
|
||||||
|
|
||||||
|
|
||||||
|
class UpdatePRRequest(BaseModel):
    """Request to update a pull request."""

    project_id: str = Field(..., description="Project ID for scoping")
    agent_id: str = Field(..., description="Agent ID making the request")
    pr_number: int = Field(..., description="PR number")
    # None means "leave unchanged"; lists replace the existing values wholesale.
    title: str | None = Field(default=None, description="New title")
    body: str | None = Field(default=None, description="New description")
    state: PRState | None = Field(default=None, description="New state")
    labels: list[str] | None = Field(default=None, description="Replace labels")
    assignees: list[str] | None = Field(default=None, description="Replace assignees")
|
||||||
|
|
||||||
|
|
||||||
|
class UpdatePRResult(BaseModel):
    """Result of updating a pull request."""

    success: bool = Field(..., description="Whether update succeeded")
    # A serialized PRInfo dict — presumably; verify against producer.
    pr: dict[str, Any] | None = Field(default=None, description="Updated PR info")
    error: str | None = Field(default=None, description="Error message if failed")
|
||||||
|
|
||||||
|
|
||||||
|
# Workspace Operations
|
||||||
|
|
||||||
|
|
||||||
|
class GetWorkspaceRequest(BaseModel):
    """Request to get or create workspace."""

    project_id: str = Field(..., description="Project ID")
    agent_id: str = Field(..., description="Agent ID making the request")
|
||||||
|
|
||||||
|
|
||||||
|
class GetWorkspaceResult(BaseModel):
    """Result of getting workspace."""

    success: bool = Field(..., description="Whether operation succeeded")
    # A serialized WorkspaceInfo dict — presumably; verify against producer.
    workspace: dict[str, Any] | None = Field(default=None, description="Workspace info")
    error: str | None = Field(default=None, description="Error message if failed")
|
||||||
|
|
||||||
|
|
||||||
|
class LockWorkspaceRequest(BaseModel):
    """Request to lock a workspace."""

    project_id: str = Field(..., description="Project ID")
    agent_id: str = Field(..., description="Agent ID requesting lock")
    # Lock lifetime is bounded to 10s–1h by validation.
    timeout: int = Field(
        default=300, ge=10, le=3600, description="Lock timeout seconds"
    )
|
||||||
|
|
||||||
|
|
||||||
|
class LockWorkspaceResult(BaseModel):
    """Result of locking workspace."""

    success: bool = Field(..., description="Whether lock acquired")
    # On contention, reports who currently holds the lock and until when.
    lock_holder: str | None = Field(default=None, description="Current lock holder")
    lock_expires: str | None = Field(
        default=None, description="Lock expiry ISO timestamp"
    )
    error: str | None = Field(default=None, description="Error message if failed")
|
||||||
|
|
||||||
|
|
||||||
|
class UnlockWorkspaceRequest(BaseModel):
    """Request to unlock a workspace."""

    project_id: str = Field(..., description="Project ID")
    agent_id: str = Field(..., description="Agent ID releasing lock")
    # Allows breaking another agent's lock; described as admin-only.
    force: bool = Field(default=False, description="Force unlock (admin only)")
|
||||||
|
|
||||||
|
|
||||||
|
class UnlockWorkspaceResult(BaseModel):
    """Result of unlocking workspace."""

    success: bool = Field(..., description="Whether unlock succeeded")
    error: str | None = Field(default=None, description="Error message if failed")
|
||||||
|
|
||||||
|
|
||||||
|
# Health and Status
|
||||||
|
|
||||||
|
|
||||||
|
class HealthStatus(BaseModel):
    """Health status response."""

    status: str = Field(..., description="Health status")
    version: str = Field(..., description="Server version")
    workspace_count: int = Field(default=0, description="Active workspaces")
    # Per-backend connectivity flags; all default False so an empty/partial
    # probe reads as "not connected" rather than raising.
    gitea_connected: bool = Field(default=False, description="Gitea connectivity")
    github_connected: bool = Field(default=False, description="GitHub connectivity")
    gitlab_connected: bool = Field(default=False, description="GitLab connectivity")
    redis_connected: bool = Field(default=False, description="Redis connectivity")
|
||||||
|
|
||||||
|
|
||||||
|
class ProviderStatus(BaseModel):
    """Provider connection status."""

    provider: str = Field(..., description="Provider name")
    connected: bool = Field(..., description="Connection status")
    url: str | None = Field(default=None, description="Provider URL")
    user: str | None = Field(default=None, description="Authenticated user")
    error: str | None = Field(default=None, description="Error if not connected")
|
||||||
11
mcp-servers/git-ops/providers/__init__.py
Normal file
11
mcp-servers/git-ops/providers/__init__.py
Normal file
@@ -0,0 +1,11 @@
|
|||||||
|
"""
|
||||||
|
Git provider implementations.
|
||||||
|
|
||||||
|
Provides adapters for different git hosting platforms (Gitea, GitHub, GitLab).
|
||||||
|
"""
|
||||||
|
|
||||||
|
from .base import BaseProvider
|
||||||
|
from .gitea import GiteaProvider
|
||||||
|
from .github import GitHubProvider
|
||||||
|
|
||||||
|
__all__ = ["BaseProvider", "GiteaProvider", "GitHubProvider"]
|
||||||
376
mcp-servers/git-ops/providers/base.py
Normal file
376
mcp-servers/git-ops/providers/base.py
Normal file
@@ -0,0 +1,376 @@
|
|||||||
|
"""
|
||||||
|
Base provider interface for git hosting platforms.
|
||||||
|
|
||||||
|
Defines the abstract interface that all git providers must implement.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from models import (
|
||||||
|
CreatePRResult,
|
||||||
|
GetPRResult,
|
||||||
|
ListPRsResult,
|
||||||
|
MergePRResult,
|
||||||
|
MergeStrategy,
|
||||||
|
PRState,
|
||||||
|
UpdatePRResult,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class BaseProvider(ABC):
    """
    Abstract base class for git hosting providers.

    All providers (Gitea, GitHub, GitLab) must implement this interface.
    """

    @property
    @abstractmethod
    def name(self) -> str:
        """Return the provider name (e.g., 'gitea', 'github')."""
        ...

    @abstractmethod
    async def is_connected(self) -> bool:
        """Check if the provider is connected and authenticated."""
        ...

    @abstractmethod
    async def get_authenticated_user(self) -> str | None:
        """Get the username of the authenticated user."""
        ...

    # Repository operations

    @abstractmethod
    async def get_repo_info(self, owner: str, repo: str) -> dict[str, Any]:
        """
        Get repository information.

        Args:
            owner: Repository owner/organization
            repo: Repository name

        Returns:
            Repository info dict
        """
        ...

    @abstractmethod
    async def get_default_branch(self, owner: str, repo: str) -> str:
        """
        Get the default branch for a repository.

        Args:
            owner: Repository owner/organization
            repo: Repository name

        Returns:
            Default branch name
        """
        ...

    # Pull Request operations

    @abstractmethod
    async def create_pr(
        self,
        owner: str,
        repo: str,
        title: str,
        body: str,
        source_branch: str,
        target_branch: str,
        draft: bool = False,
        labels: list[str] | None = None,
        assignees: list[str] | None = None,
        reviewers: list[str] | None = None,
    ) -> CreatePRResult:
        """
        Create a pull request.

        Args:
            owner: Repository owner
            repo: Repository name
            title: PR title
            body: PR description
            source_branch: Source branch name
            target_branch: Target branch name
            draft: Whether to create as draft
            labels: Labels to add
            assignees: Users to assign
            reviewers: Users to request review from

        Returns:
            CreatePRResult with PR number and URL
        """
        ...

    @abstractmethod
    async def get_pr(self, owner: str, repo: str, pr_number: int) -> GetPRResult:
        """
        Get a pull request by number.

        Args:
            owner: Repository owner
            repo: Repository name
            pr_number: Pull request number

        Returns:
            GetPRResult with PR details
        """
        ...

    @abstractmethod
    async def list_prs(
        self,
        owner: str,
        repo: str,
        state: PRState | None = None,
        author: str | None = None,
        limit: int = 20,
    ) -> ListPRsResult:
        """
        List pull requests.

        Args:
            owner: Repository owner
            repo: Repository name
            state: Filter by state (open, closed, merged)
            author: Filter by author
            limit: Maximum PRs to return

        Returns:
            ListPRsResult with list of PRs
        """
        ...

    @abstractmethod
    async def merge_pr(
        self,
        owner: str,
        repo: str,
        pr_number: int,
        merge_strategy: MergeStrategy = MergeStrategy.MERGE,
        commit_message: str | None = None,
        delete_branch: bool = True,
    ) -> MergePRResult:
        """
        Merge a pull request.

        Args:
            owner: Repository owner
            repo: Repository name
            pr_number: Pull request number
            merge_strategy: Merge strategy to use
            commit_message: Custom merge commit message
            delete_branch: Whether to delete source branch

        Returns:
            MergePRResult with merge status
        """
        ...

    @abstractmethod
    async def update_pr(
        self,
        owner: str,
        repo: str,
        pr_number: int,
        title: str | None = None,
        body: str | None = None,
        state: PRState | None = None,
        labels: list[str] | None = None,
        assignees: list[str] | None = None,
    ) -> UpdatePRResult:
        """
        Update a pull request.

        Args:
            owner: Repository owner
            repo: Repository name
            pr_number: Pull request number
            title: New title
            body: New description
            state: New state (open, closed)
            labels: Replace labels
            assignees: Replace assignees

        Returns:
            UpdatePRResult with updated PR info
        """
        ...

    @abstractmethod
    async def close_pr(self, owner: str, repo: str, pr_number: int) -> UpdatePRResult:
        """
        Close a pull request without merging.

        Args:
            owner: Repository owner
            repo: Repository name
            pr_number: Pull request number

        Returns:
            UpdatePRResult with updated PR info
        """
        ...

    # Branch operations via API (for operations that need to bypass local git)

    @abstractmethod
    async def delete_remote_branch(self, owner: str, repo: str, branch: str) -> bool:
        """
        Delete a remote branch via API.

        Args:
            owner: Repository owner
            repo: Repository name
            branch: Branch name to delete

        Returns:
            True if deleted, False otherwise
        """
        ...

    @abstractmethod
    async def get_branch(
        self, owner: str, repo: str, branch: str
    ) -> dict[str, Any] | None:
        """
        Get branch information via API.

        Args:
            owner: Repository owner
            repo: Repository name
            branch: Branch name

        Returns:
            Branch info dict or None if not found
        """
        ...

    # Comment operations

    @abstractmethod
    async def add_pr_comment(
        self, owner: str, repo: str, pr_number: int, body: str
    ) -> dict[str, Any]:
        """
        Add a comment to a pull request.

        Args:
            owner: Repository owner
            repo: Repository name
            pr_number: Pull request number
            body: Comment body

        Returns:
            Created comment info
        """
        ...

    @abstractmethod
    async def list_pr_comments(
        self, owner: str, repo: str, pr_number: int
    ) -> list[dict[str, Any]]:
        """
        List comments on a pull request.

        Args:
            owner: Repository owner
            repo: Repository name
            pr_number: Pull request number

        Returns:
            List of comments
        """
        ...

    # Label operations

    @abstractmethod
    async def add_labels(
        self, owner: str, repo: str, pr_number: int, labels: list[str]
    ) -> list[str]:
        """
        Add labels to a pull request.

        Args:
            owner: Repository owner
            repo: Repository name
            pr_number: Pull request number
            labels: Labels to add

        Returns:
            Updated list of labels
        """
        ...

    @abstractmethod
    async def remove_label(
        self, owner: str, repo: str, pr_number: int, label: str
    ) -> list[str]:
        """
        Remove a label from a pull request.

        Args:
            owner: Repository owner
            repo: Repository name
            pr_number: Pull request number
            label: Label to remove

        Returns:
            Updated list of labels
        """
        ...

    # Reviewer operations

    @abstractmethod
    async def request_review(
        self, owner: str, repo: str, pr_number: int, reviewers: list[str]
    ) -> list[str]:
        """
        Request review from users.

        Args:
            owner: Repository owner
            repo: Repository name
            pr_number: Pull request number
            reviewers: Usernames to request review from

        Returns:
            List of reviewers requested
        """
        ...

    # Utility methods

    def parse_repo_url(self, repo_url: str) -> tuple[str, str]:
        """
        Parse repository URL to extract owner and repo name.

        Args:
            repo_url: Repository URL (HTTPS or SSH)

        Returns:
            Tuple of (owner, repo)

        Raises:
            ValueError: If URL cannot be parsed
        """
        import re

        # Handle SSH URLs: git@host:owner/repo.git
        # NOTE(review): ssh:// scheme URLs (e.g. ssh://git@host:2222/owner/repo.git)
        # are not matched by either pattern — confirm callers never pass them.
        ssh_match = re.match(r"git@[^:]+:([^/]+)/([^/]+?)(?:\.git)?$", repo_url)
        if ssh_match:
            return ssh_match.group(1), ssh_match.group(2)

        # Handle HTTPS URLs: https://host/owner/repo.git
        # NOTE(review): nested group paths (GitLab subgroups, host/a/b/repo)
        # will not parse — verify that is acceptable for GitLab support.
        https_match = re.match(r"https?://[^/]+/([^/]+)/([^/]+?)(?:\.git)?$", repo_url)
        if https_match:
            return https_match.group(1), https_match.group(2)

        raise ValueError(f"Unable to parse repository URL: {repo_url}")
|
||||||
723
mcp-servers/git-ops/providers/gitea.py
Normal file
723
mcp-servers/git-ops/providers/gitea.py
Normal file
@@ -0,0 +1,723 @@
|
|||||||
|
"""
|
||||||
|
Gitea provider implementation.
|
||||||
|
|
||||||
|
Implements the BaseProvider interface for Gitea API operations.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from datetime import UTC, datetime
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
|
||||||
|
from config import Settings, get_settings
|
||||||
|
from exceptions import (
|
||||||
|
APIError,
|
||||||
|
AuthenticationError,
|
||||||
|
PRNotFoundError,
|
||||||
|
)
|
||||||
|
from models import (
|
||||||
|
CreatePRResult,
|
||||||
|
GetPRResult,
|
||||||
|
ListPRsResult,
|
||||||
|
MergePRResult,
|
||||||
|
MergeStrategy,
|
||||||
|
PRInfo,
|
||||||
|
PRState,
|
||||||
|
UpdatePRResult,
|
||||||
|
)
|
||||||
|
|
||||||
|
from .base import BaseProvider
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class GiteaProvider(BaseProvider):
|
||||||
|
"""
|
||||||
|
Gitea API provider implementation.
|
||||||
|
|
||||||
|
Supports all PR operations, branch operations, and repository queries.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
    self,
    base_url: str | None = None,
    token: str | None = None,
    settings: Settings | None = None,
) -> None:
    """
    Initialize Gitea provider.

    Args:
        base_url: Gitea server URL (e.g., https://gitea.example.com)
        token: API token
        settings: Optional settings override
    """
    self.settings = settings or get_settings()
    # Explicit arguments win over settings; trailing slash stripped so the
    # "/api/v1" join in _get_client never produces "//".
    self.base_url = (base_url or self.settings.gitea_base_url).rstrip("/")
    self.token = token or self.settings.gitea_token
    self._client: httpx.AsyncClient | None = None  # created lazily by _get_client
    self._user: str | None = None  # cached authenticated username
|
||||||
|
|
||||||
|
@property
def name(self) -> str:
    """Return the provider name."""
    # Constant identifier for this provider implementation.
    return "gitea"
|
||||||
|
|
||||||
|
async def _get_client(self) -> httpx.AsyncClient:
    """Get or create HTTP client."""
    # Lazily build one shared AsyncClient; reused across all requests
    # until close() resets it.
    if self._client is None:
        headers = {
            "Accept": "application/json",
            "Content-Type": "application/json",
        }
        if self.token:
            # Gitea uses the "token" authorization scheme (not "Bearer").
            headers["Authorization"] = f"token {self.token}"

        self._client = httpx.AsyncClient(
            base_url=f"{self.base_url}/api/v1",
            headers=headers,
            timeout=30.0,
        )
    return self._client
|
||||||
|
|
||||||
|
async def close(self) -> None:
    """Release the underlying HTTP client, if one was created."""
    client = self._client
    if client:
        await client.aclose()
        # Reset so a later call to _get_client builds a fresh client.
        self._client = None
|
||||||
|
|
||||||
|
async def _request(
    self,
    method: str,
    path: str,
    **kwargs: Any,
) -> Any:
    """
    Make an API request.

    Args:
        method: HTTP method
        path: API path
        **kwargs: Additional request arguments

    Returns:
        Parsed JSON response, or None for 404 / 204 responses

    Raises:
        APIError: On API errors
        AuthenticationError: On auth failures
    """
    client = await self._get_client()

    try:
        response = await client.request(method, path, **kwargs)

        if response.status_code == 401:
            raise AuthenticationError("gitea", "Invalid or expired token")

        if response.status_code == 403:
            raise AuthenticationError(
                "gitea", "Insufficient permissions for this operation"
            )

        # 404 is mapped to None rather than an exception so callers can
        # treat "missing" as a normal outcome.
        if response.status_code == 404:
            return None

        if response.status_code >= 400:
            error_msg = response.text
            try:
                error_data = response.json()
                error_msg = error_data.get("message", error_msg)
            except Exception:
                pass  # non-JSON error body; keep the raw text
            raise APIError("gitea", response.status_code, error_msg)

        # 204 No Content has no body to parse.
        if response.status_code == 204:
            return None

        return response.json()

    except httpx.RequestError as e:
        # Chain the transport error so the original cause is preserved
        # in tracebacks (fixes B904: raise-without-from in except).
        raise APIError("gitea", 0, f"Request failed: {e}") from e
|
||||||
|
|
||||||
|
async def is_connected(self) -> bool:
    """Return True when credentials are configured and /user responds."""
    # Without a URL and token there is nothing to probe.
    if not (self.base_url and self.token):
        return False

    try:
        return await self._request("GET", "/user") is not None
    except Exception:
        # Any failure (auth, transport, parsing) means "not connected".
        return False
|
||||||
|
|
||||||
|
async def get_authenticated_user(self) -> str | None:
    """Get the authenticated user's username."""
    # Serve from cache after the first successful lookup.
    if self._user:
        return self._user

    try:
        result = await self._request("GET", "/user")
        if result:
            # Gitea returns "login"; "username" kept as a fallback key.
            self._user = result.get("login") or result.get("username")
            return self._user
    except Exception:
        # Best-effort lookup: any failure yields None instead of raising.
        pass
    return None
|
||||||
|
|
||||||
|
# Repository operations
|
||||||
|
|
||||||
|
async def get_repo_info(self, owner: str, repo: str) -> dict[str, Any]:
    """Get repository information."""
    result = await self._request("GET", f"/repos/{owner}/{repo}")
    if result is None:
        # _request maps 404 to None; surface it as an explicit error here.
        raise APIError("gitea", 404, f"Repository not found: {owner}/{repo}")
    return result
|
||||||
|
|
||||||
|
async def get_default_branch(self, owner: str, repo: str) -> str:
    """Return the repository's default branch name ('main' if absent)."""
    info = await self.get_repo_info(owner, repo)
    return info.get("default_branch", "main")
|
||||||
|
|
||||||
|
# Pull Request operations
|
||||||
|
|
||||||
|
async def create_pr(
    self,
    owner: str,
    repo: str,
    title: str,
    body: str,
    source_branch: str,
    target_branch: str,
    draft: bool = False,
    labels: list[str] | None = None,
    assignees: list[str] | None = None,
    reviewers: list[str] | None = None,
) -> CreatePRResult:
    """Create a pull request.

    Creates the PR first, then applies labels, assignees, and reviewer
    requests as separate follow-up API calls. The ``draft`` flag is
    currently accepted but not sent to the API (see note below).
    """
    try:
        data: dict[str, Any] = {
            "title": title,
            "body": body,
            "head": source_branch,
            "base": target_branch,
        }

        # Note: Gitea doesn't have draft PR support in all versions
        # Draft support was added in Gitea 1.14+

        result = await self._request(
            "POST",
            f"/repos/{owner}/{repo}/pulls",
            json=data,
        )

        if result is None:
            return CreatePRResult(
                success=False,
                error="Failed to create pull request",
            )

        pr_number = result["number"]

        # Add labels if specified
        if labels:
            await self.add_labels(owner, repo, pr_number, labels)

        # Add assignees if specified (via issue update)
        # (Gitea exposes PR metadata through the issues endpoint.)
        if assignees:
            await self._request(
                "PATCH",
                f"/repos/{owner}/{repo}/issues/{pr_number}",
                json={"assignees": assignees},
            )

        # Request reviewers if specified
        if reviewers:
            await self.request_review(owner, repo, pr_number, reviewers)

        return CreatePRResult(
            success=True,
            pr_number=pr_number,
            pr_url=result.get("html_url"),
        )

    except APIError as e:
        # NOTE: if a follow-up step fails the PR itself already exists,
        # but the caller only sees the failure result.
        return CreatePRResult(
            success=False,
            error=str(e),
        )
|
||||||
|
|
||||||
|
async def get_pr(self, owner: str, repo: str, pr_number: int) -> GetPRResult:
    """Get a pull request by number."""
    try:
        result = await self._request(
            "GET",
            f"/repos/{owner}/{repo}/pulls/{pr_number}",
        )

        # _request returns None on 404.
        if result is None:
            raise PRNotFoundError(pr_number, f"{owner}/{repo}")

        pr_info = self._parse_pr(result)

        return GetPRResult(
            success=True,
            pr=pr_info.to_dict(),
        )

    except PRNotFoundError:
        return GetPRResult(
            success=False,
            error=f"Pull request #{pr_number} not found",
        )
    except APIError as e:
        return GetPRResult(
            success=False,
            error=str(e),
        )
|
||||||
|
|
||||||
|
async def list_prs(
    self,
    owner: str,
    repo: str,
    state: PRState | None = None,
    author: str | None = None,
    limit: int = 20,
) -> ListPRsResult:
    """List pull requests.

    Author and merged-state filtering happen client-side, after the API
    returns at most ``limit`` PRs — so fewer than ``limit`` matches may
    be returned even when more exist on the server.
    """
    try:
        params: dict[str, Any] = {
            "limit": limit,
        }

        if state:
            # Gitea uses different state names
            # (no native "merged" filter; merged PRs are "closed").
            if state == PRState.OPEN:
                params["state"] = "open"
            elif state == PRState.CLOSED or state == PRState.MERGED:
                params["state"] = "closed"
            else:
                params["state"] = "all"

        result = await self._request(
            "GET",
            f"/repos/{owner}/{repo}/pulls",
            params=params,
        )

        if result is None:
            return ListPRsResult(
                success=True,
                pull_requests=[],
                total_count=0,
            )

        prs = []
        for pr_data in result:
            # Filter by author if specified (case-insensitive, client-side)
            if author:
                pr_author = pr_data.get("user", {}).get("login", "")
                if pr_author.lower() != author.lower():
                    continue

            # Filter merged PRs if looking specifically for merged
            if state == PRState.MERGED:
                if not pr_data.get("merged"):
                    continue

            pr_info = self._parse_pr(pr_data)
            prs.append(pr_info.to_dict())

        return ListPRsResult(
            success=True,
            pull_requests=prs,
            total_count=len(prs),
        )

    except APIError as e:
        return ListPRsResult(
            success=False,
            error=str(e),
        )
|
||||||
|
|
||||||
|
async def merge_pr(
    self,
    owner: str,
    repo: str,
    pr_number: int,
    merge_strategy: MergeStrategy = MergeStrategy.MERGE,
    commit_message: str | None = None,
    delete_branch: bool = True,
) -> MergePRResult:
    """Merge a pull request."""
    try:
        # Map merge strategy to Gitea's "Do" values
        do_map = {
            MergeStrategy.MERGE: "merge",
            MergeStrategy.SQUASH: "squash",
            MergeStrategy.REBASE: "rebase",
        }

        data: dict[str, Any] = {
            "Do": do_map[merge_strategy],
            "delete_branch_after_merge": delete_branch,
        }

        if commit_message:
            # Gitea splits the merge commit into a title field (first line)
            # and a message field (the remainder).
            data["MergeTitleField"] = commit_message.split("\n")[0]
            if "\n" in commit_message:
                data["MergeMessageField"] = "\n".join(
                    commit_message.split("\n")[1:]
                )

        result = await self._request(
            "POST",
            f"/repos/{owner}/{repo}/pulls/{pr_number}/merge",
            json=data,
        )

        if result is None:
            # Check if PR was actually merged
            # (some Gitea versions return an empty body on success).
            pr_result = await self.get_pr(owner, repo, pr_number)
            if pr_result.success and pr_result.pr:
                if pr_result.pr.get("state") == "merged":
                    return MergePRResult(
                        success=True,
                        branch_deleted=delete_branch,
                    )

            return MergePRResult(
                success=False,
                error="Failed to merge pull request",
            )

        return MergePRResult(
            success=True,
            merge_commit_sha=result.get("sha"),
            branch_deleted=delete_branch,
        )

    except APIError as e:
        return MergePRResult(
            success=False,
            error=str(e),
        )
|
||||||
|
|
||||||
|
async def update_pr(
    self,
    owner: str,
    repo: str,
    pr_number: int,
    title: str | None = None,
    body: str | None = None,
    state: PRState | None = None,
    labels: list[str] | None = None,
    assignees: list[str] | None = None,
) -> UpdatePRResult:
    """Update a pull request.

    Title/body/state go through the pulls endpoint; labels and assignees
    go through the issues endpoint (Gitea models PRs as issues there).
    Passing ``labels=[]`` clears all labels; ``labels=None`` leaves them
    untouched.
    """
    try:
        data: dict[str, Any] = {}

        if title is not None:
            data["title"] = title
        if body is not None:
            data["body"] = body
        if state is not None:
            if state == PRState.OPEN:
                data["state"] = "open"
            elif state == PRState.CLOSED:
                data["state"] = "closed"

        # Update PR if there's data
        if data:
            await self._request(
                "PATCH",
                f"/repos/{owner}/{repo}/pulls/{pr_number}",
                json=data,
            )

        # Update labels via issue endpoint
        if labels is not None:
            # First clear existing labels
            await self._request(
                "DELETE",
                f"/repos/{owner}/{repo}/issues/{pr_number}/labels",
            )
            # Then add new labels
            if labels:
                await self.add_labels(owner, repo, pr_number, labels)

        # Update assignees via issue endpoint
        if assignees is not None:
            await self._request(
                "PATCH",
                f"/repos/{owner}/{repo}/issues/{pr_number}",
                json={"assignees": assignees},
            )

        # Fetch updated PR so the caller sees the post-update state.
        result = await self.get_pr(owner, repo, pr_number)
        return UpdatePRResult(
            success=result.success,
            pr=result.pr,
            error=result.error,
        )

    except APIError as e:
        return UpdatePRResult(
            success=False,
            error=str(e),
        )
|
||||||
|
|
||||||
|
async def close_pr(
    self,
    owner: str,
    repo: str,
    pr_number: int,
) -> UpdatePRResult:
    """Close a pull request without merging (delegates to update_pr)."""
    return await self.update_pr(owner, repo, pr_number, state=PRState.CLOSED)
|
||||||
|
|
||||||
|
# Branch operations
|
||||||
|
|
||||||
|
async def delete_remote_branch(
    self,
    owner: str,
    repo: str,
    branch: str,
) -> bool:
    """Delete a remote branch."""
    try:
        await self._request(
            "DELETE",
            f"/repos/{owner}/{repo}/branches/{branch}",
        )
        return True
    except APIError:
        # Any API failure (permissions, protected branch, etc.) is
        # reported as "not deleted" instead of raising.
        return False
|
||||||
|
|
||||||
|
async def get_branch(
    self,
    owner: str,
    repo: str,
    branch: str,
) -> dict[str, Any] | None:
    """Look up branch info; returns None when the branch does not exist."""
    path = f"/repos/{owner}/{repo}/branches/{branch}"
    return await self._request("GET", path)
|
||||||
|
|
||||||
|
# Comment operations
|
||||||
|
|
||||||
|
async def add_pr_comment(
    self,
    owner: str,
    repo: str,
    pr_number: int,
    body: str,
) -> dict[str, Any]:
    """Post a comment on a pull request (via the issues endpoint)."""
    created = await self._request(
        "POST",
        f"/repos/{owner}/{repo}/issues/{pr_number}/comments",
        json={"body": body},
    )
    # Normalize a missing response to an empty dict.
    return created if created else {}
|
||||||
|
|
||||||
|
async def list_pr_comments(
    self,
    owner: str,
    repo: str,
    pr_number: int,
) -> list[dict[str, Any]]:
    """Fetch all comments on a pull request (via the issues endpoint)."""
    comments = await self._request(
        "GET",
        f"/repos/{owner}/{repo}/issues/{pr_number}/comments",
    )
    # Normalize a missing response to an empty list.
    return comments if comments else []
|
||||||
|
|
||||||
|
# Label operations
|
||||||
|
|
||||||
|
async def add_labels(
    self,
    owner: str,
    repo: str,
    pr_number: int,
    labels: list[str],
) -> list[str]:
    """Add labels to a pull request.

    Gitea's label-add endpoint takes numeric label IDs, so each name is
    resolved (and created if missing) first. Names that cannot be
    resolved or created are silently skipped.
    """
    # First, get or create label IDs
    label_ids = []
    for label_name in labels:
        label_id = await self._get_or_create_label(owner, repo, label_name)
        if label_id:
            label_ids.append(label_id)

    if label_ids:
        await self._request(
            "POST",
            f"/repos/{owner}/{repo}/issues/{pr_number}/labels",
            json={"labels": label_ids},
        )

    # Return current labels
    issue = await self._request(
        "GET",
        f"/repos/{owner}/{repo}/issues/{pr_number}",
    )
    if issue:
        return [lbl["name"] for lbl in issue.get("labels", [])]
    # Fall back to the requested names when the issue can't be re-read.
    return labels
|
||||||
|
|
||||||
|
async def remove_label(
    self,
    owner: str,
    repo: str,
    pr_number: int,
    label: str,
) -> list[str]:
    """Remove a label from a pull request.

    The delete endpoint takes a numeric label ID, so the repo's labels
    are fetched and matched by name first. Returns the PR's remaining
    label names (empty when the PR cannot be re-fetched).
    """
    # Fetch candidate labels. NOTE(review): the previous code assumed the
    # ?name= query filtered server-side and deleted label_info[0]
    # unconditionally; if the listing is unfiltered that removes the
    # wrong label. Match the name explicitly (case-insensitively, mirroring
    # _get_or_create_label) before deleting.
    label_info = await self._request(
        "GET",
        f"/repos/{owner}/{repo}/labels?name={label}",
    )

    if label_info:
        label_id = next(
            (
                lbl["id"]
                for lbl in label_info
                if lbl.get("name", "").lower() == label.lower()
            ),
            None,
        )
        if label_id is not None:
            await self._request(
                "DELETE",
                f"/repos/{owner}/{repo}/issues/{pr_number}/labels/{label_id}",
            )

    # Return remaining labels
    issue = await self._request(
        "GET",
        f"/repos/{owner}/{repo}/issues/{pr_number}",
    )
    if issue:
        return [lbl["name"] for lbl in issue.get("labels", [])]
    return []
|
||||||
|
|
||||||
|
async def _get_or_create_label(
    self,
    owner: str,
    repo: str,
    label_name: str,
) -> int | None:
    """Get or create a label and return its numeric ID.

    The lookup is case-insensitive; when the label is missing it is
    created with a default color. Returns None when both the lookup and
    the creation fail.
    """
    existing = await self._request(
        "GET",
        f"/repos/{owner}/{repo}/labels",
    )
    for lbl in existing or []:
        if lbl["name"].lower() == label_name.lower():
            return lbl["id"]

    # Not found: create it (best-effort; creation may be forbidden).
    try:
        created = await self._request(
            "POST",
            f"/repos/{owner}/{repo}/labels",
            json={
                "name": label_name,
                "color": "#3B82F6",  # Default blue
            },
        )
    except APIError:
        return None
    return created["id"] if created else None
|
||||||
|
|
||||||
|
# Reviewer operations
|
||||||
|
|
||||||
|
async def request_review(
    self,
    owner: str,
    repo: str,
    pr_number: int,
    reviewers: list[str],
) -> list[str]:
    """Request review from the given users and echo the list back."""
    endpoint = f"/repos/{owner}/{repo}/pulls/{pr_number}/requested_reviewers"
    await self._request("POST", endpoint, json={"reviewers": reviewers})
    return reviewers
|
||||||
|
|
||||||
|
# Helper methods
|
||||||
|
|
||||||
|
def _parse_pr(self, data: dict[str, Any]) -> PRInfo:
    """Convert a raw PR payload from the API into a PRInfo model."""
    # This provider reports merge status via the boolean "merged" flag.
    if data.get("merged"):
        pr_state = PRState.MERGED
    elif data.get("state") == "closed":
        pr_state = PRState.CLOSED
    else:
        pr_state = PRState.OPEN

    # Requested reviewers are only present on some responses.
    requested_reviewers: list[str] = []
    if "requested_reviewers" in data:
        requested_reviewers = [r["login"] for r in data["requested_reviewers"]]

    head = data.get("head", {})
    base = data.get("base", {})

    return PRInfo(
        number=data["number"],
        title=data["title"],
        body=data.get("body", ""),
        state=pr_state,
        source_branch=head.get("ref", ""),
        target_branch=base.get("ref", ""),
        author=data.get("user", {}).get("login", ""),
        # Missing timestamps fall back to "now" inside the helper.
        created_at=self._parse_datetime(data.get("created_at")),
        updated_at=self._parse_datetime(data.get("updated_at")),
        merged_at=self._parse_datetime(data.get("merged_at")),
        closed_at=self._parse_datetime(data.get("closed_at")),
        url=data.get("html_url"),
        labels=[lbl["name"] for lbl in data.get("labels", [])],
        assignees=[a["login"] for a in data.get("assignees", [])],
        reviewers=requested_reviewers,
        mergeable=data.get("mergeable"),
        draft=data.get("draft", False),
    )
|
||||||
|
|
||||||
|
def _parse_datetime(self, value: str | None) -> datetime:
|
||||||
|
"""Parse datetime string from API."""
|
||||||
|
if not value:
|
||||||
|
return datetime.now(UTC)
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Handle Gitea's datetime format
|
||||||
|
if value.endswith("Z"):
|
||||||
|
value = value[:-1] + "+00:00"
|
||||||
|
return datetime.fromisoformat(value)
|
||||||
|
except ValueError:
|
||||||
|
return datetime.now(UTC)
|
||||||
675
mcp-servers/git-ops/providers/github.py
Normal file
675
mcp-servers/git-ops/providers/github.py
Normal file
@@ -0,0 +1,675 @@
|
|||||||
|
"""
|
||||||
|
GitHub provider implementation.
|
||||||
|
|
||||||
|
Implements the BaseProvider interface for GitHub API operations.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from datetime import UTC, datetime
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
|
||||||
|
from config import Settings, get_settings
|
||||||
|
from exceptions import (
|
||||||
|
APIError,
|
||||||
|
AuthenticationError,
|
||||||
|
PRNotFoundError,
|
||||||
|
)
|
||||||
|
from models import (
|
||||||
|
CreatePRResult,
|
||||||
|
GetPRResult,
|
||||||
|
ListPRsResult,
|
||||||
|
MergePRResult,
|
||||||
|
MergeStrategy,
|
||||||
|
PRInfo,
|
||||||
|
PRState,
|
||||||
|
UpdatePRResult,
|
||||||
|
)
|
||||||
|
|
||||||
|
from .base import BaseProvider
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class GitHubProvider(BaseProvider):
|
||||||
|
"""
|
||||||
|
GitHub API provider implementation.
|
||||||
|
|
||||||
|
Supports all PR operations, branch operations, and repository queries.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
    self,
    token: str | None = None,
    settings: Settings | None = None,
) -> None:
    """
    Initialize the GitHub provider.

    Args:
        token: GitHub personal access token or fine-grained token;
            defaults to the configured github_token.
        settings: Optional settings override (defaults to the global
            settings singleton).
    """
    self.settings = settings or get_settings()
    # Fall back to the configured token when none is supplied.
    self.token = token or self.settings.github_token
    # Lazily-created shared HTTP client and cached username.
    self._client: httpx.AsyncClient | None = None
    self._user: str | None = None
|
||||||
|
|
||||||
|
@property
|
||||||
|
def name(self) -> str:
|
||||||
|
"""Return the provider name."""
|
||||||
|
return "github"
|
||||||
|
|
||||||
|
async def _get_client(self) -> httpx.AsyncClient:
    """Return the shared HTTP client, creating it on first use."""
    if self._client is not None:
        return self._client

    request_headers = {
        "Accept": "application/vnd.github+json",
        "X-GitHub-Api-Version": "2022-11-28",
    }
    if self.token:
        request_headers["Authorization"] = f"Bearer {self.token}"

    self._client = httpx.AsyncClient(
        base_url="https://api.github.com",
        headers=request_headers,
        timeout=30.0,
    )
    return self._client
|
||||||
|
|
||||||
|
async def close(self) -> None:
    """Close the HTTP client and drop the cached reference."""
    client, self._client = self._client, None
    if client is not None:
        await client.aclose()
|
||||||
|
|
||||||
|
async def _request(
    self,
    method: str,
    path: str,
    **kwargs: Any,
) -> Any:
    """
    Make an API request against api.github.com.

    Args:
        method: HTTP method
        path: API path
        **kwargs: Additional request arguments passed to httpx

    Returns:
        Parsed JSON response; None for 404 (missing resource) and 204
        (no content) responses.

    Raises:
        APIError: On API errors, rate limiting, or transport failures
        AuthenticationError: On auth failures (401/403)
    """
    client = await self._get_client()

    try:
        response = await client.request(method, path, **kwargs)
    except httpx.RequestError as e:
        raise APIError("github", 0, f"Request failed: {e}")

    status = response.status_code

    if status == 401:
        raise AuthenticationError("github", "Invalid or expired token")

    if status == 403:
        # 403 is either rate limiting or a permissions problem.
        if "rate limit" in response.text.lower():
            raise APIError("github", 403, "GitHub API rate limit exceeded")
        raise AuthenticationError(
            "github", "Insufficient permissions for this operation"
        )

    if status == 404:
        # Missing resources are reported as None rather than an exception.
        return None

    if status >= 400:
        message = response.text
        try:
            # Prefer GitHub's structured error message when present.
            message = response.json().get("message", message)
        except Exception:
            pass
        raise APIError("github", status, message)

    if status == 204:
        return None

    return response.json()
|
||||||
|
|
||||||
|
async def is_connected(self) -> bool:
    """Check whether the token authenticates against GitHub."""
    if not self.token:
        return False
    try:
        return await self._request("GET", "/user") is not None
    except Exception:
        # Any failure (auth, transport, ...) counts as "not connected".
        return False
|
||||||
|
|
||||||
|
async def get_authenticated_user(self) -> str | None:
    """Get the authenticated user's login name (cached after first call)."""
    if self._user:
        return self._user

    try:
        profile = await self._request("GET", "/user")
    except Exception:
        return None

    if profile:
        self._user = profile.get("login")
        return self._user
    return None
|
||||||
|
|
||||||
|
# Repository operations
|
||||||
|
|
||||||
|
async def get_repo_info(self, owner: str, repo: str) -> dict[str, Any]:
    """Get repository information.

    Raises:
        APIError: with status 404 when the repository does not exist.
    """
    info = await self._request("GET", f"/repos/{owner}/{repo}")
    if info is None:
        raise APIError("github", 404, f"Repository not found: {owner}/{repo}")
    return info
|
||||||
|
|
||||||
|
async def get_default_branch(self, owner: str, repo: str) -> str:
    """Get the repository's default branch (falls back to "main")."""
    info = await self.get_repo_info(owner, repo)
    return info.get("default_branch", "main")
|
||||||
|
|
||||||
|
# Pull Request operations
|
||||||
|
|
||||||
|
async def create_pr(
    self,
    owner: str,
    repo: str,
    title: str,
    body: str,
    source_branch: str,
    target_branch: str,
    draft: bool = False,
    labels: list[str] | None = None,
    assignees: list[str] | None = None,
    reviewers: list[str] | None = None,
) -> CreatePRResult:
    """Create a pull request.

    Labels, assignees, and review requests are attached with follow-up
    API calls after the PR itself is created. API failures surface as an
    unsuccessful result rather than an exception.
    """
    payload: dict[str, Any] = {
        "title": title,
        "body": body,
        "head": source_branch,
        "base": target_branch,
        "draft": draft,
    }

    try:
        created = await self._request(
            "POST",
            f"/repos/{owner}/{repo}/pulls",
            json=payload,
        )

        if created is None:
            return CreatePRResult(
                success=False,
                error="Failed to create pull request",
            )

        pr_number = created["number"]

        # Post-creation extras, each a separate API call.
        if labels:
            await self.add_labels(owner, repo, pr_number, labels)
        if assignees:
            await self._request(
                "POST",
                f"/repos/{owner}/{repo}/issues/{pr_number}/assignees",
                json={"assignees": assignees},
            )
        if reviewers:
            await self.request_review(owner, repo, pr_number, reviewers)

        return CreatePRResult(
            success=True,
            pr_number=pr_number,
            pr_url=created.get("html_url"),
        )

    except APIError as e:
        return CreatePRResult(success=False, error=str(e))
|
||||||
|
|
||||||
|
async def get_pr(self, owner: str, repo: str, pr_number: int) -> GetPRResult:
    """Get a pull request by number as a serialized PRInfo dict."""
    try:
        raw = await self._request(
            "GET",
            f"/repos/{owner}/{repo}/pulls/{pr_number}",
        )
        if raw is None:
            raise PRNotFoundError(pr_number, f"{owner}/{repo}")

        return GetPRResult(
            success=True,
            pr=self._parse_pr(raw).to_dict(),
        )

    except PRNotFoundError:
        return GetPRResult(
            success=False,
            error=f"Pull request #{pr_number} not found",
        )
    except APIError as e:
        return GetPRResult(success=False, error=str(e))
|
||||||
|
|
||||||
|
async def list_prs(
    self,
    owner: str,
    repo: str,
    state: PRState | None = None,
    author: str | None = None,
    limit: int = 20,
) -> ListPRsResult:
    """List pull requests, optionally filtered by state and author.

    GitHub has no native "merged" list state: merged PRs are fetched as
    "closed" and filtered on merged_at client-side. Author filtering is
    also client-side, so fewer than `limit` PRs may be returned.
    """
    try:
        params: dict[str, Any] = {
            "per_page": min(limit, 100),  # GitHub max is 100
        }

        if state == PRState.OPEN:
            params["state"] = "open"
        elif state in (PRState.CLOSED, PRState.MERGED):
            params["state"] = "closed"
        elif state:
            params["state"] = "all"

        raw = await self._request(
            "GET",
            f"/repos/{owner}/{repo}/pulls",
            params=params,
        )

        if raw is None:
            return ListPRsResult(
                success=True,
                pull_requests=[],
                total_count=0,
            )

        matched: list[dict[str, Any]] = []
        for item in raw:
            if author:
                login = item.get("user", {}).get("login", "")
                if login.lower() != author.lower():
                    continue
            if state == PRState.MERGED and not item.get("merged_at"):
                continue
            matched.append(self._parse_pr(item).to_dict())

        return ListPRsResult(
            success=True,
            pull_requests=matched,
            total_count=len(matched),
        )

    except APIError as e:
        return ListPRsResult(success=False, error=str(e))
|
||||||
|
|
||||||
|
async def merge_pr(
    self,
    owner: str,
    repo: str,
    pr_number: int,
    merge_strategy: MergeStrategy = MergeStrategy.MERGE,
    commit_message: str | None = None,
    delete_branch: bool = True,
) -> MergePRResult:
    """Merge a pull request, optionally deleting its source branch."""
    try:
        # Translate our strategy enum to GitHub's merge_method values.
        payload: dict[str, Any] = {
            "merge_method": {
                MergeStrategy.MERGE: "merge",
                MergeStrategy.SQUASH: "squash",
                MergeStrategy.REBASE: "rebase",
            }[merge_strategy],
        }

        if commit_message:
            # First line becomes the commit title; the rest (when
            # present) becomes the commit body.
            parts = commit_message.split("\n", 1)
            payload["commit_title"] = parts[0]
            if len(parts) > 1:
                payload["commit_message"] = parts[1]

        merge_result = await self._request(
            "PUT",
            f"/repos/{owner}/{repo}/pulls/{pr_number}/merge",
            json=payload,
        )

        if merge_result is None:
            return MergePRResult(
                success=False,
                error="Failed to merge pull request",
            )

        branch_deleted = False
        if delete_branch and merge_result.get("merged"):
            # Re-fetch the PR to learn its source branch name.
            pr_result = await self.get_pr(owner, repo, pr_number)
            if pr_result.success and pr_result.pr:
                head_branch = pr_result.pr.get("source_branch")
                if head_branch:
                    branch_deleted = await self.delete_remote_branch(
                        owner, repo, head_branch
                    )

        return MergePRResult(
            success=True,
            merge_commit_sha=merge_result.get("sha"),
            branch_deleted=branch_deleted,
        )

    except APIError as e:
        return MergePRResult(success=False, error=str(e))
|
||||||
|
|
||||||
|
async def update_pr(
    self,
    owner: str,
    repo: str,
    pr_number: int,
    title: str | None = None,
    body: str | None = None,
    state: PRState | None = None,
    labels: list[str] | None = None,
    assignees: list[str] | None = None,
) -> UpdatePRResult:
    """Update a pull request.

    Only non-None arguments are touched. Labels and assignees go through
    the issues endpoints; assignees are fully replaced (cleared, then
    re-added). Returns the freshly-fetched PR on success.
    """
    try:
        changes: dict[str, Any] = {}
        if title is not None:
            changes["title"] = title
        if body is not None:
            changes["body"] = body
        if state == PRState.OPEN:
            changes["state"] = "open"
        elif state == PRState.CLOSED:
            changes["state"] = "closed"

        if changes:
            await self._request(
                "PATCH",
                f"/repos/{owner}/{repo}/pulls/{pr_number}",
                json=changes,
            )

        if labels is not None:
            # PUT replaces the full label set on the issue.
            await self._request(
                "PUT",
                f"/repos/{owner}/{repo}/issues/{pr_number}/labels",
                json={"labels": labels},
            )

        if assignees is not None:
            # Clear existing assignees, then add the new set.
            await self._request(
                "DELETE",
                f"/repos/{owner}/{repo}/issues/{pr_number}/assignees",
                json={"assignees": []},
            )
            if assignees:
                await self._request(
                    "POST",
                    f"/repos/{owner}/{repo}/issues/{pr_number}/assignees",
                    json={"assignees": assignees},
                )

        refreshed = await self.get_pr(owner, repo, pr_number)
        return UpdatePRResult(
            success=refreshed.success,
            pr=refreshed.pr,
            error=refreshed.error,
        )

    except APIError as e:
        return UpdatePRResult(success=False, error=str(e))
|
||||||
|
|
||||||
|
async def close_pr(
    self,
    owner: str,
    repo: str,
    pr_number: int,
) -> UpdatePRResult:
    """Close a pull request without merging (a state-only update)."""
    return await self.update_pr(owner, repo, pr_number, state=PRState.CLOSED)
|
||||||
|
|
||||||
|
# Branch operations
|
||||||
|
|
||||||
|
async def delete_remote_branch(
    self,
    owner: str,
    repo: str,
    branch: str,
) -> bool:
    """Delete a remote branch; True on success, False on any API error."""
    ref_path = f"/repos/{owner}/{repo}/git/refs/heads/{branch}"
    try:
        await self._request("DELETE", ref_path)
    except APIError:
        return False
    return True
|
||||||
|
|
||||||
|
async def get_branch(
    self,
    owner: str,
    repo: str,
    branch: str,
) -> dict[str, Any] | None:
    """Get branch information (None when the branch does not exist)."""
    endpoint = f"/repos/{owner}/{repo}/branches/{branch}"
    return await self._request("GET", endpoint)
|
||||||
|
|
||||||
|
# Comment operations
|
||||||
|
|
||||||
|
async def add_pr_comment(
    self,
    owner: str,
    repo: str,
    pr_number: int,
    body: str,
) -> dict[str, Any]:
    """Add a comment to a pull request.

    PR comments are posted through the issues endpoint. Returns the
    created comment payload, or {} when the API returned no body.
    """
    endpoint = f"/repos/{owner}/{repo}/issues/{pr_number}/comments"
    comment = await self._request("POST", endpoint, json={"body": body})
    return comment if comment else {}
|
||||||
|
|
||||||
|
async def list_pr_comments(
    self,
    owner: str,
    repo: str,
    pr_number: int,
) -> list[dict[str, Any]]:
    """List comments on a pull request (empty when there are none)."""
    endpoint = f"/repos/{owner}/{repo}/issues/{pr_number}/comments"
    comments = await self._request("GET", endpoint)
    return comments if comments else []
|
||||||
|
|
||||||
|
# Label operations
|
||||||
|
|
||||||
|
async def add_labels(
    self,
    owner: str,
    repo: str,
    pr_number: int,
    labels: list[str],
) -> list[str]:
    """Add labels to a pull request.

    GitHub creates missing labels automatically (unlike Gitea), so the
    names can be posted directly. Returns the label names reported back
    by the API, or the requested names when no payload came back.
    """
    attached = await self._request(
        "POST",
        f"/repos/{owner}/{repo}/issues/{pr_number}/labels",
        json={"labels": labels},
    )

    if not attached:
        return labels
    return [entry["name"] for entry in attached]
|
||||||
|
|
||||||
|
async def remove_label(
    self,
    owner: str,
    repo: str,
    pr_number: int,
    label: str,
) -> list[str]:
    """Remove a label from a pull request.

    Returns the PR's remaining label names (empty when the PR cannot be
    re-fetched afterwards).
    """
    from urllib.parse import quote

    # GitHub label names may contain spaces, slashes, or other
    # URL-unsafe characters (e.g. "good first issue", "area/api");
    # percent-encode the name so the DELETE path stays well-formed
    # instead of silently targeting the wrong route.
    await self._request(
        "DELETE",
        f"/repos/{owner}/{repo}/issues/{pr_number}/labels/{quote(label, safe='')}",
    )

    # Return remaining labels
    issue = await self._request(
        "GET",
        f"/repos/{owner}/{repo}/issues/{pr_number}",
    )
    if issue:
        return [lbl["name"] for lbl in issue.get("labels", [])]
    return []
|
||||||
|
|
||||||
|
# Reviewer operations
|
||||||
|
|
||||||
|
async def request_review(
    self,
    owner: str,
    repo: str,
    pr_number: int,
    reviewers: list[str],
) -> list[str]:
    """Request review from the given users and echo the list back."""
    endpoint = f"/repos/{owner}/{repo}/pulls/{pr_number}/requested_reviewers"
    await self._request("POST", endpoint, json={"reviewers": reviewers})
    return reviewers
|
||||||
|
|
||||||
|
# Helper methods
|
||||||
|
|
||||||
|
def _parse_pr(self, data: dict[str, Any]) -> PRInfo:
    """Convert a raw GitHub PR payload into a PRInfo model."""
    # GitHub list responses carry no "merged" boolean; a merged PR is a
    # closed PR with merged_at set.
    if data.get("merged_at"):
        pr_state = PRState.MERGED
    elif data.get("state") == "closed":
        pr_state = PRState.CLOSED
    else:
        pr_state = PRState.OPEN

    # Requested reviewers are only present on some responses.
    requested_reviewers: list[str] = []
    if "requested_reviewers" in data:
        requested_reviewers = [r["login"] for r in data["requested_reviewers"]]

    head = data.get("head", {})
    base = data.get("base", {})

    return PRInfo(
        number=data["number"],
        title=data["title"],
        # GitHub sends body as null (not "") when empty, hence the `or ""`.
        body=data.get("body", "") or "",
        state=pr_state,
        source_branch=head.get("ref", ""),
        target_branch=base.get("ref", ""),
        author=data.get("user", {}).get("login", ""),
        # Missing timestamps fall back to "now" inside the helper.
        created_at=self._parse_datetime(data.get("created_at")),
        updated_at=self._parse_datetime(data.get("updated_at")),
        merged_at=self._parse_datetime(data.get("merged_at")),
        closed_at=self._parse_datetime(data.get("closed_at")),
        url=data.get("html_url"),
        labels=[lbl["name"] for lbl in data.get("labels", [])],
        assignees=[a["login"] for a in data.get("assignees", [])],
        reviewers=requested_reviewers,
        mergeable=data.get("mergeable"),
        draft=data.get("draft", False),
    )
|
||||||
|
|
||||||
|
def _parse_datetime(self, value: str | None) -> datetime:
|
||||||
|
"""Parse datetime string from API."""
|
||||||
|
if not value:
|
||||||
|
return datetime.now(UTC)
|
||||||
|
|
||||||
|
try:
|
||||||
|
# GitHub uses ISO 8601 format with Z suffix
|
||||||
|
if value.endswith("Z"):
|
||||||
|
value = value[:-1] + "+00:00"
|
||||||
|
return datetime.fromisoformat(value)
|
||||||
|
except ValueError:
|
||||||
|
return datetime.now(UTC)
|
||||||
120
mcp-servers/git-ops/pyproject.toml
Normal file
120
mcp-servers/git-ops/pyproject.toml
Normal file
@@ -0,0 +1,120 @@
|
|||||||
|
[project]
|
||||||
|
name = "syndarix-mcp-git-ops"
|
||||||
|
version = "0.1.0"
|
||||||
|
description = "Syndarix Git Operations MCP Server - Repository management, branching, commits, and PR workflows"
|
||||||
|
requires-python = ">=3.12"
|
||||||
|
dependencies = [
|
||||||
|
"fastmcp>=2.0.0",
|
||||||
|
"gitpython>=3.1.0",
|
||||||
|
"httpx>=0.27.0",
|
||||||
|
"redis>=5.0.0",
|
||||||
|
"pydantic>=2.0.0",
|
||||||
|
"pydantic-settings>=2.0.0",
|
||||||
|
"uvicorn>=0.30.0",
|
||||||
|
"fastapi>=0.115.0",
|
||||||
|
"filelock>=3.15.0",
|
||||||
|
"aiofiles>=24.1.0",
|
||||||
|
]
|
||||||
|
|
||||||
|
[project.optional-dependencies]
|
||||||
|
dev = [
|
||||||
|
"pytest>=8.0.0",
|
||||||
|
"pytest-asyncio>=0.24.0",
|
||||||
|
"pytest-cov>=5.0.0",
|
||||||
|
"fakeredis>=2.25.0",
|
||||||
|
"ruff>=0.8.0",
|
||||||
|
"mypy>=1.11.0",
|
||||||
|
"respx>=0.21.0",
|
||||||
|
]
|
||||||
|
|
||||||
|
[project.scripts]
|
||||||
|
git-ops = "server:main"
|
||||||
|
|
||||||
|
[build-system]
|
||||||
|
requires = ["hatchling"]
|
||||||
|
build-backend = "hatchling.build"
|
||||||
|
|
||||||
|
[tool.hatch.build.targets.wheel]
|
||||||
|
packages = ["."]
|
||||||
|
exclude = ["tests/", "*.md", "Dockerfile"]
|
||||||
|
|
||||||
|
[tool.hatch.build.targets.sdist]
|
||||||
|
include = ["*.py", "pyproject.toml"]
|
||||||
|
|
||||||
|
[tool.ruff]
|
||||||
|
target-version = "py312"
|
||||||
|
line-length = 88
|
||||||
|
|
||||||
|
[tool.ruff.lint]
|
||||||
|
select = [
|
||||||
|
"E", # pycodestyle errors
|
||||||
|
"W", # pycodestyle warnings
|
||||||
|
"F", # pyflakes
|
||||||
|
"I", # isort
|
||||||
|
"B", # flake8-bugbear
|
||||||
|
"C4", # flake8-comprehensions
|
||||||
|
"UP", # pyupgrade
|
||||||
|
"ARG", # flake8-unused-arguments
|
||||||
|
"SIM", # flake8-simplify
|
||||||
|
"S", # flake8-bandit (security)
|
||||||
|
]
|
||||||
|
ignore = [
|
||||||
|
"E501", # line too long (handled by formatter)
|
||||||
|
"B008", # do not perform function calls in argument defaults
|
||||||
|
"B904", # raise from in except (too noisy)
|
||||||
|
"S104", # possible binding to all interfaces
|
||||||
|
"S110", # try-except-pass (intentional for optional operations)
|
||||||
|
"S603", # subprocess without shell=True (safe usage in git wrapper)
|
||||||
|
"S607", # starting a process with a partial path (git CLI)
|
||||||
|
"ARG002", # unused method arguments (for API compatibility)
|
||||||
|
"SIM102", # nested if statements (sometimes more readable)
|
||||||
|
"SIM105", # contextlib.suppress (sometimes more readable)
|
||||||
|
"SIM108", # ternary operator (sometimes more readable)
|
||||||
|
"SIM118", # dict.keys() (explicit is fine)
|
||||||
|
]
|
||||||
|
|
||||||
|
[tool.ruff.lint.isort]
|
||||||
|
known-first-party = ["config", "models", "exceptions", "git_wrapper", "workspace", "providers"]
|
||||||
|
|
||||||
|
[tool.ruff.lint.per-file-ignores]
|
||||||
|
"tests/**/*.py" = ["S101", "ARG001", "S105", "S106", "S108", "F841", "B007"]
|
||||||
|
|
||||||
|
[tool.pytest.ini_options]
|
||||||
|
asyncio_mode = "auto"
|
||||||
|
asyncio_default_fixture_loop_scope = "function"
|
||||||
|
testpaths = ["tests"]
|
||||||
|
addopts = "-v --tb=short"
|
||||||
|
filterwarnings = [
|
||||||
|
"ignore::DeprecationWarning",
|
||||||
|
]
|
||||||
|
|
||||||
|
[tool.coverage.run]
|
||||||
|
source = ["."]
|
||||||
|
omit = ["tests/*", "conftest.py"]
|
||||||
|
branch = true
|
||||||
|
|
||||||
|
[tool.coverage.report]
|
||||||
|
exclude_lines = [
|
||||||
|
"pragma: no cover",
|
||||||
|
"def __repr__",
|
||||||
|
"raise NotImplementedError",
|
||||||
|
"if TYPE_CHECKING:",
|
||||||
|
"if __name__ == .__main__.:",
|
||||||
|
]
|
||||||
|
fail_under = 78
|
||||||
|
show_missing = true
|
||||||
|
|
||||||
|
[tool.mypy]
|
||||||
|
python_version = "3.12"
|
||||||
|
warn_return_any = false
|
||||||
|
warn_unused_ignores = false
|
||||||
|
disallow_untyped_defs = true
|
||||||
|
ignore_missing_imports = true
|
||||||
|
plugins = ["pydantic.mypy"]
|
||||||
|
files = ["server.py", "config.py", "models.py", "exceptions.py", "git_wrapper.py", "workspace.py", "providers/"]
|
||||||
|
exclude = ["tests/"]
|
||||||
|
|
||||||
|
[[tool.mypy.overrides]]
|
||||||
|
module = "tests.*"
|
||||||
|
disallow_untyped_defs = false
|
||||||
|
ignore_errors = true
|
||||||
1674
mcp-servers/git-ops/server.py
Normal file
1674
mcp-servers/git-ops/server.py
Normal file
File diff suppressed because it is too large
Load Diff
1
mcp-servers/git-ops/tests/__init__.py
Normal file
1
mcp-servers/git-ops/tests/__init__.py
Normal file
@@ -0,0 +1 @@
|
|||||||
|
"""Tests for Git Operations MCP Server."""
|
||||||
299
mcp-servers/git-ops/tests/conftest.py
Normal file
299
mcp-servers/git-ops/tests/conftest.py
Normal file
@@ -0,0 +1,299 @@
|
|||||||
|
"""
|
||||||
|
Test configuration and fixtures for Git Operations MCP Server.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import shutil
|
||||||
|
import tempfile
|
||||||
|
from collections.abc import AsyncIterator, Iterator
|
||||||
|
from pathlib import Path
|
||||||
|
from unittest.mock import AsyncMock, MagicMock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from git import Repo as GitRepo
|
||||||
|
|
||||||
|
# Set test environment
|
||||||
|
os.environ["IS_TEST"] = "true"
|
||||||
|
os.environ["GIT_OPS_WORKSPACE_BASE_PATH"] = "/tmp/test-workspaces"
|
||||||
|
os.environ["GIT_OPS_GITEA_BASE_URL"] = "https://gitea.test.com"
|
||||||
|
os.environ["GIT_OPS_GITEA_TOKEN"] = "test-token"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="session", autouse=True)
def reset_settings_session():
    """Reset cached settings once at session start and again at teardown."""
    from config import reset_settings

    reset_settings()
    yield
    reset_settings()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def reset_settings():
    """Give a test a clean settings cache before and after it runs."""
    from config import reset_settings

    reset_settings()
    yield
    reset_settings()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def test_settings():
    """Build a Settings instance wired for the test environment."""
    from config import Settings

    overrides = dict(
        workspace_base_path=Path("/tmp/test-workspaces"),
        gitea_base_url="https://gitea.test.com",
        gitea_token="test-token",
        github_token="github-test-token",
        git_author_name="Test Agent",
        git_author_email="test@syndarix.ai",
        enable_force_push=False,
        debug=True,
    )
    return Settings(**overrides)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def temp_dir() -> Iterator[Path]:
|
||||||
|
"""Create a temporary directory for tests."""
|
||||||
|
temp_path = Path(tempfile.mkdtemp())
|
||||||
|
yield temp_path
|
||||||
|
if temp_path.exists():
|
||||||
|
shutil.rmtree(temp_path)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def temp_workspace(temp_dir: Path) -> Path:
|
||||||
|
"""Create a temporary workspace directory."""
|
||||||
|
workspace = temp_dir / "workspace"
|
||||||
|
workspace.mkdir(parents=True, exist_ok=True)
|
||||||
|
return workspace
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def git_repo(temp_workspace: Path) -> GitRepo:
|
||||||
|
"""Create a git repository in the temp workspace."""
|
||||||
|
# Initialize with main branch (Git 2.28+)
|
||||||
|
repo = GitRepo.init(temp_workspace, initial_branch="main")
|
||||||
|
|
||||||
|
# Configure git
|
||||||
|
with repo.config_writer() as cw:
|
||||||
|
cw.set_value("user", "name", "Test User")
|
||||||
|
cw.set_value("user", "email", "test@example.com")
|
||||||
|
|
||||||
|
# Create initial commit
|
||||||
|
test_file = temp_workspace / "README.md"
|
||||||
|
test_file.write_text("# Test Repository\n")
|
||||||
|
repo.index.add(["README.md"])
|
||||||
|
repo.index.commit("Initial commit")
|
||||||
|
|
||||||
|
return repo
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def git_repo_with_remote(git_repo: GitRepo, temp_dir: Path) -> tuple[GitRepo, GitRepo]:
|
||||||
|
"""Create a git repository with a 'remote' (bare repo)."""
|
||||||
|
# Create bare repo as remote
|
||||||
|
remote_path = temp_dir / "remote.git"
|
||||||
|
remote_repo = GitRepo.init(remote_path, bare=True)
|
||||||
|
|
||||||
|
# Add remote to main repo
|
||||||
|
git_repo.create_remote("origin", str(remote_path))
|
||||||
|
|
||||||
|
# Push initial commit
|
||||||
|
git_repo.remotes.origin.push("main:main")
|
||||||
|
|
||||||
|
# Set up tracking
|
||||||
|
git_repo.heads.main.set_tracking_branch(git_repo.remotes.origin.refs.main)
|
||||||
|
|
||||||
|
return git_repo, remote_repo
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def workspace_manager(temp_dir: Path, test_settings):
|
||||||
|
"""Create a WorkspaceManager with test settings."""
|
||||||
|
from workspace import WorkspaceManager
|
||||||
|
|
||||||
|
test_settings.workspace_base_path = temp_dir / "workspaces"
|
||||||
|
return WorkspaceManager(test_settings)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def git_wrapper(temp_workspace: Path, test_settings):
|
||||||
|
"""Create a GitWrapper for the temp workspace."""
|
||||||
|
from git_wrapper import GitWrapper
|
||||||
|
|
||||||
|
return GitWrapper(temp_workspace, test_settings)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def git_wrapper_with_repo(git_repo: GitRepo, test_settings):
|
||||||
|
"""Create a GitWrapper for a repo that's already initialized."""
|
||||||
|
from git_wrapper import GitWrapper
|
||||||
|
|
||||||
|
return GitWrapper(Path(git_repo.working_dir), test_settings)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_gitea_provider():
|
||||||
|
"""Create a mock Gitea provider."""
|
||||||
|
provider = AsyncMock()
|
||||||
|
provider.name = "gitea"
|
||||||
|
provider.is_connected = AsyncMock(return_value=True)
|
||||||
|
provider.get_authenticated_user = AsyncMock(return_value="test-user")
|
||||||
|
provider.parse_repo_url = MagicMock(return_value=("owner", "repo"))
|
||||||
|
return provider
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_httpx_client():
|
||||||
|
"""Create a mock httpx client for provider tests."""
|
||||||
|
from unittest.mock import AsyncMock
|
||||||
|
|
||||||
|
mock_response = AsyncMock()
|
||||||
|
mock_response.status_code = 200
|
||||||
|
mock_response.json = MagicMock(return_value={})
|
||||||
|
mock_response.text = ""
|
||||||
|
|
||||||
|
mock_client = AsyncMock()
|
||||||
|
mock_client.request = AsyncMock(return_value=mock_response)
|
||||||
|
mock_client.get = AsyncMock(return_value=mock_response)
|
||||||
|
mock_client.post = AsyncMock(return_value=mock_response)
|
||||||
|
mock_client.patch = AsyncMock(return_value=mock_response)
|
||||||
|
mock_client.delete = AsyncMock(return_value=mock_response)
|
||||||
|
|
||||||
|
return mock_client
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
async def gitea_provider(test_settings, mock_httpx_client):
|
||||||
|
"""Create a GiteaProvider with mocked HTTP client."""
|
||||||
|
from providers.gitea import GiteaProvider
|
||||||
|
|
||||||
|
provider = GiteaProvider(
|
||||||
|
base_url=test_settings.gitea_base_url,
|
||||||
|
token=test_settings.gitea_token,
|
||||||
|
settings=test_settings,
|
||||||
|
)
|
||||||
|
provider._client = mock_httpx_client
|
||||||
|
|
||||||
|
yield provider
|
||||||
|
|
||||||
|
await provider.close()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def sample_pr_data():
|
||||||
|
"""Sample PR data from Gitea API."""
|
||||||
|
return {
|
||||||
|
"number": 42,
|
||||||
|
"title": "Test PR",
|
||||||
|
"body": "This is a test pull request",
|
||||||
|
"state": "open",
|
||||||
|
"head": {"ref": "feature-branch"},
|
||||||
|
"base": {"ref": "main"},
|
||||||
|
"user": {"login": "test-user"},
|
||||||
|
"created_at": "2024-01-15T10:00:00Z",
|
||||||
|
"updated_at": "2024-01-15T12:00:00Z",
|
||||||
|
"merged_at": None,
|
||||||
|
"closed_at": None,
|
||||||
|
"html_url": "https://gitea.test.com/owner/repo/pull/42",
|
||||||
|
"labels": [{"name": "enhancement"}],
|
||||||
|
"assignees": [{"login": "assignee1"}],
|
||||||
|
"requested_reviewers": [{"login": "reviewer1"}],
|
||||||
|
"mergeable": True,
|
||||||
|
"draft": False,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def sample_commit_data():
|
||||||
|
"""Sample commit data."""
|
||||||
|
return {
|
||||||
|
"sha": "abc123def456",
|
||||||
|
"short_sha": "abc123d",
|
||||||
|
"message": "Test commit message",
|
||||||
|
"author": {
|
||||||
|
"name": "Test Author",
|
||||||
|
"email": "author@test.com",
|
||||||
|
"date": "2024-01-15T10:00:00Z",
|
||||||
|
},
|
||||||
|
"committer": {
|
||||||
|
"name": "Test Committer",
|
||||||
|
"email": "committer@test.com",
|
||||||
|
"date": "2024-01-15T10:00:00Z",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_fastapi_app():
|
||||||
|
"""Create a test FastAPI app."""
|
||||||
|
from fastapi import FastAPI
|
||||||
|
from fastapi.testclient import TestClient
|
||||||
|
|
||||||
|
app = FastAPI()
|
||||||
|
|
||||||
|
@app.get("/health")
|
||||||
|
def health():
|
||||||
|
return {"status": "healthy"}
|
||||||
|
|
||||||
|
return TestClient(app)
|
||||||
|
|
||||||
|
|
||||||
|
# Async fixtures
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
async def async_workspace_manager(
|
||||||
|
temp_dir: Path, test_settings
|
||||||
|
) -> AsyncIterator:
|
||||||
|
"""Async fixture for workspace manager."""
|
||||||
|
from workspace import WorkspaceManager
|
||||||
|
|
||||||
|
test_settings.workspace_base_path = temp_dir / "workspaces"
|
||||||
|
manager = WorkspaceManager(test_settings)
|
||||||
|
yield manager
|
||||||
|
|
||||||
|
|
||||||
|
# Test data fixtures
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def valid_project_id() -> str:
|
||||||
|
"""Valid project ID for tests."""
|
||||||
|
return "test-project-123"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def valid_agent_id() -> str:
|
||||||
|
"""Valid agent ID for tests."""
|
||||||
|
return "agent-456"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def invalid_ids() -> list[str]:
|
||||||
|
"""Invalid IDs for validation tests."""
|
||||||
|
return [
|
||||||
|
"",
|
||||||
|
" ",
|
||||||
|
"a" * 200, # Too long
|
||||||
|
"test@invalid", # Invalid character
|
||||||
|
"test!invalid",
|
||||||
|
"../path/traversal",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def sample_repo_url() -> str:
|
||||||
|
"""Sample repository URL."""
|
||||||
|
return "https://gitea.test.com/owner/repo.git"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def sample_ssh_repo_url() -> str:
|
||||||
|
"""Sample SSH repository URL."""
|
||||||
|
return "git@gitea.test.com:owner/repo.git"
|
||||||
434
mcp-servers/git-ops/tests/test_git_wrapper.py
Normal file
434
mcp-servers/git-ops/tests/test_git_wrapper.py
Normal file
@@ -0,0 +1,434 @@
|
|||||||
|
"""
|
||||||
|
Tests for the git_wrapper module.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from exceptions import (
|
||||||
|
BranchExistsError,
|
||||||
|
BranchNotFoundError,
|
||||||
|
CheckoutError,
|
||||||
|
CommitError,
|
||||||
|
GitError,
|
||||||
|
)
|
||||||
|
from git_wrapper import GitWrapper
|
||||||
|
from models import FileChangeType
|
||||||
|
|
||||||
|
|
||||||
|
class TestGitWrapperInit:
|
||||||
|
"""Tests for GitWrapper initialization."""
|
||||||
|
|
||||||
|
def test_init_with_valid_path(self, temp_workspace, test_settings):
|
||||||
|
"""Test initialization with a valid path."""
|
||||||
|
wrapper = GitWrapper(temp_workspace, test_settings)
|
||||||
|
assert wrapper.workspace_path == temp_workspace
|
||||||
|
assert wrapper.settings == test_settings
|
||||||
|
|
||||||
|
def test_repo_property_raises_on_non_git(self, temp_workspace, test_settings):
|
||||||
|
"""Test that accessing repo on non-git dir raises error."""
|
||||||
|
wrapper = GitWrapper(temp_workspace, test_settings)
|
||||||
|
with pytest.raises(GitError, match="Not a git repository"):
|
||||||
|
_ = wrapper.repo
|
||||||
|
|
||||||
|
def test_repo_property_works_on_git_dir(self, git_repo, test_settings):
|
||||||
|
"""Test that repo property works for git directory."""
|
||||||
|
wrapper = GitWrapper(Path(git_repo.working_dir), test_settings)
|
||||||
|
assert wrapper.repo is not None
|
||||||
|
assert wrapper.repo.head is not None
|
||||||
|
|
||||||
|
|
||||||
|
class TestGitWrapperStatus:
|
||||||
|
"""Tests for git status operations."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_status_clean_repo(self, git_wrapper_with_repo):
|
||||||
|
"""Test status on a clean repository."""
|
||||||
|
result = await git_wrapper_with_repo.status()
|
||||||
|
|
||||||
|
assert result.branch == "main"
|
||||||
|
assert result.is_clean is True
|
||||||
|
assert len(result.staged) == 0
|
||||||
|
assert len(result.unstaged) == 0
|
||||||
|
assert len(result.untracked) == 0
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_status_with_untracked(self, git_wrapper_with_repo, git_repo):
|
||||||
|
"""Test status with untracked files."""
|
||||||
|
# Create untracked file
|
||||||
|
untracked_file = Path(git_repo.working_dir) / "untracked.txt"
|
||||||
|
untracked_file.write_text("untracked content")
|
||||||
|
|
||||||
|
result = await git_wrapper_with_repo.status()
|
||||||
|
|
||||||
|
assert result.is_clean is False
|
||||||
|
assert "untracked.txt" in result.untracked
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_status_with_modified(self, git_wrapper_with_repo, git_repo):
|
||||||
|
"""Test status with modified files."""
|
||||||
|
# Modify existing file
|
||||||
|
readme = Path(git_repo.working_dir) / "README.md"
|
||||||
|
readme.write_text("# Modified content\n")
|
||||||
|
|
||||||
|
result = await git_wrapper_with_repo.status()
|
||||||
|
|
||||||
|
assert result.is_clean is False
|
||||||
|
assert len(result.unstaged) > 0
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_status_with_staged(self, git_wrapper_with_repo, git_repo):
|
||||||
|
"""Test status with staged changes."""
|
||||||
|
# Create and stage a file
|
||||||
|
new_file = Path(git_repo.working_dir) / "staged.txt"
|
||||||
|
new_file.write_text("staged content")
|
||||||
|
git_repo.index.add(["staged.txt"])
|
||||||
|
|
||||||
|
result = await git_wrapper_with_repo.status()
|
||||||
|
|
||||||
|
assert result.is_clean is False
|
||||||
|
assert len(result.staged) > 0
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_status_exclude_untracked(self, git_wrapper_with_repo, git_repo):
|
||||||
|
"""Test status without untracked files."""
|
||||||
|
untracked_file = Path(git_repo.working_dir) / "untracked.txt"
|
||||||
|
untracked_file.write_text("untracked")
|
||||||
|
|
||||||
|
result = await git_wrapper_with_repo.status(include_untracked=False)
|
||||||
|
|
||||||
|
assert len(result.untracked) == 0
|
||||||
|
|
||||||
|
|
||||||
|
class TestGitWrapperBranch:
|
||||||
|
"""Tests for branch operations."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_create_branch(self, git_wrapper_with_repo):
|
||||||
|
"""Test creating a new branch."""
|
||||||
|
result = await git_wrapper_with_repo.create_branch("feature-test")
|
||||||
|
|
||||||
|
assert result.success is True
|
||||||
|
assert result.branch == "feature-test"
|
||||||
|
assert result.is_current is True
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_create_branch_without_checkout(self, git_wrapper_with_repo):
|
||||||
|
"""Test creating branch without checkout."""
|
||||||
|
result = await git_wrapper_with_repo.create_branch("feature-no-checkout", checkout=False)
|
||||||
|
|
||||||
|
assert result.success is True
|
||||||
|
assert result.branch == "feature-no-checkout"
|
||||||
|
assert result.is_current is False
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_create_branch_exists_error(self, git_wrapper_with_repo):
|
||||||
|
"""Test error when branch already exists."""
|
||||||
|
await git_wrapper_with_repo.create_branch("existing-branch", checkout=False)
|
||||||
|
|
||||||
|
with pytest.raises(BranchExistsError):
|
||||||
|
await git_wrapper_with_repo.create_branch("existing-branch")
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_delete_branch(self, git_wrapper_with_repo):
|
||||||
|
"""Test deleting a branch."""
|
||||||
|
# Create branch first
|
||||||
|
await git_wrapper_with_repo.create_branch("to-delete", checkout=False)
|
||||||
|
|
||||||
|
# Delete it
|
||||||
|
result = await git_wrapper_with_repo.delete_branch("to-delete")
|
||||||
|
|
||||||
|
assert result.success is True
|
||||||
|
assert result.branch == "to-delete"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_delete_branch_not_found(self, git_wrapper_with_repo):
|
||||||
|
"""Test error when deleting non-existent branch."""
|
||||||
|
with pytest.raises(BranchNotFoundError):
|
||||||
|
await git_wrapper_with_repo.delete_branch("nonexistent")
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_delete_current_branch_error(self, git_wrapper_with_repo):
|
||||||
|
"""Test error when deleting current branch."""
|
||||||
|
with pytest.raises(GitError, match="Cannot delete current branch"):
|
||||||
|
await git_wrapper_with_repo.delete_branch("main")
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_list_branches(self, git_wrapper_with_repo):
|
||||||
|
"""Test listing branches."""
|
||||||
|
# Create some branches
|
||||||
|
await git_wrapper_with_repo.create_branch("branch-a", checkout=False)
|
||||||
|
await git_wrapper_with_repo.create_branch("branch-b", checkout=False)
|
||||||
|
|
||||||
|
result = await git_wrapper_with_repo.list_branches()
|
||||||
|
|
||||||
|
assert result.current_branch == "main"
|
||||||
|
branch_names = [b["name"] for b in result.local_branches]
|
||||||
|
assert "main" in branch_names
|
||||||
|
assert "branch-a" in branch_names
|
||||||
|
assert "branch-b" in branch_names
|
||||||
|
|
||||||
|
|
||||||
|
class TestGitWrapperCheckout:
|
||||||
|
"""Tests for checkout operations."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_checkout_existing_branch(self, git_wrapper_with_repo):
|
||||||
|
"""Test checkout of existing branch."""
|
||||||
|
# Create branch first
|
||||||
|
await git_wrapper_with_repo.create_branch("test-branch", checkout=False)
|
||||||
|
|
||||||
|
result = await git_wrapper_with_repo.checkout("test-branch")
|
||||||
|
|
||||||
|
assert result.success is True
|
||||||
|
assert result.ref == "test-branch"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_checkout_create_new(self, git_wrapper_with_repo):
|
||||||
|
"""Test checkout with branch creation."""
|
||||||
|
result = await git_wrapper_with_repo.checkout("new-branch", create_branch=True)
|
||||||
|
|
||||||
|
assert result.success is True
|
||||||
|
assert result.ref == "new-branch"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_checkout_nonexistent_error(self, git_wrapper_with_repo):
|
||||||
|
"""Test error when checking out non-existent ref."""
|
||||||
|
with pytest.raises(CheckoutError):
|
||||||
|
await git_wrapper_with_repo.checkout("nonexistent-branch")
|
||||||
|
|
||||||
|
|
||||||
|
class TestGitWrapperCommit:
|
||||||
|
"""Tests for commit operations."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_commit_staged_changes(self, git_wrapper_with_repo, git_repo):
|
||||||
|
"""Test committing staged changes."""
|
||||||
|
# Create and stage a file
|
||||||
|
new_file = Path(git_repo.working_dir) / "newfile.txt"
|
||||||
|
new_file.write_text("new content")
|
||||||
|
git_repo.index.add(["newfile.txt"])
|
||||||
|
|
||||||
|
result = await git_wrapper_with_repo.commit("Add new file")
|
||||||
|
|
||||||
|
assert result.success is True
|
||||||
|
assert result.message == "Add new file"
|
||||||
|
assert result.files_changed == 1
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_commit_all_changes(self, git_wrapper_with_repo, git_repo):
|
||||||
|
"""Test committing all changes (auto-stage)."""
|
||||||
|
# Create a file without staging
|
||||||
|
new_file = Path(git_repo.working_dir) / "unstaged.txt"
|
||||||
|
new_file.write_text("content")
|
||||||
|
|
||||||
|
result = await git_wrapper_with_repo.commit("Commit unstaged")
|
||||||
|
|
||||||
|
assert result.success is True
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_commit_nothing_to_commit(self, git_wrapper_with_repo):
|
||||||
|
"""Test error when nothing to commit."""
|
||||||
|
with pytest.raises(CommitError, match="Nothing to commit"):
|
||||||
|
await git_wrapper_with_repo.commit("Empty commit")
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_commit_with_author(self, git_wrapper_with_repo, git_repo):
|
||||||
|
"""Test commit with custom author."""
|
||||||
|
new_file = Path(git_repo.working_dir) / "authored.txt"
|
||||||
|
new_file.write_text("authored content")
|
||||||
|
|
||||||
|
result = await git_wrapper_with_repo.commit(
|
||||||
|
"Custom author commit",
|
||||||
|
author_name="Custom Author",
|
||||||
|
author_email="custom@test.com",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result.success is True
|
||||||
|
|
||||||
|
|
||||||
|
class TestGitWrapperDiff:
|
||||||
|
"""Tests for diff operations."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_diff_no_changes(self, git_wrapper_with_repo):
|
||||||
|
"""Test diff with no changes."""
|
||||||
|
result = await git_wrapper_with_repo.diff()
|
||||||
|
|
||||||
|
assert result.files_changed == 0
|
||||||
|
assert result.total_additions == 0
|
||||||
|
assert result.total_deletions == 0
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_diff_with_changes(self, git_wrapper_with_repo, git_repo):
|
||||||
|
"""Test diff with modified files."""
|
||||||
|
# Modify a file
|
||||||
|
readme = Path(git_repo.working_dir) / "README.md"
|
||||||
|
readme.write_text("# Modified\nNew line\n")
|
||||||
|
|
||||||
|
result = await git_wrapper_with_repo.diff()
|
||||||
|
|
||||||
|
assert result.files_changed > 0
|
||||||
|
|
||||||
|
|
||||||
|
class TestGitWrapperLog:
|
||||||
|
"""Tests for log operations."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_log_basic(self, git_wrapper_with_repo):
|
||||||
|
"""Test basic log."""
|
||||||
|
result = await git_wrapper_with_repo.log()
|
||||||
|
|
||||||
|
assert result.total_commits > 0
|
||||||
|
assert len(result.commits) > 0
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_log_with_limit(self, git_wrapper_with_repo, git_repo):
|
||||||
|
"""Test log with limit."""
|
||||||
|
# Create more commits
|
||||||
|
for i in range(5):
|
||||||
|
file_path = Path(git_repo.working_dir) / f"file{i}.txt"
|
||||||
|
file_path.write_text(f"content {i}")
|
||||||
|
git_repo.index.add([f"file{i}.txt"])
|
||||||
|
git_repo.index.commit(f"Commit {i}")
|
||||||
|
|
||||||
|
result = await git_wrapper_with_repo.log(limit=3)
|
||||||
|
|
||||||
|
assert len(result.commits) == 3
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_log_commit_info(self, git_wrapper_with_repo):
|
||||||
|
"""Test that log returns proper commit info."""
|
||||||
|
result = await git_wrapper_with_repo.log(limit=1)
|
||||||
|
|
||||||
|
commit = result.commits[0]
|
||||||
|
assert "sha" in commit
|
||||||
|
assert "message" in commit
|
||||||
|
assert "author_name" in commit
|
||||||
|
assert "author_email" in commit
|
||||||
|
|
||||||
|
|
||||||
|
class TestGitWrapperUtilities:
|
||||||
|
"""Tests for utility methods."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_is_valid_ref_true(self, git_wrapper_with_repo):
|
||||||
|
"""Test valid ref detection."""
|
||||||
|
is_valid = await git_wrapper_with_repo.is_valid_ref("main")
|
||||||
|
assert is_valid is True
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_is_valid_ref_false(self, git_wrapper_with_repo):
|
||||||
|
"""Test invalid ref detection."""
|
||||||
|
is_valid = await git_wrapper_with_repo.is_valid_ref("nonexistent")
|
||||||
|
assert is_valid is False
|
||||||
|
|
||||||
|
def test_diff_to_change_type(self, git_wrapper_with_repo):
|
||||||
|
"""Test change type conversion."""
|
||||||
|
wrapper = git_wrapper_with_repo
|
||||||
|
|
||||||
|
assert wrapper._diff_to_change_type("A") == FileChangeType.ADDED
|
||||||
|
assert wrapper._diff_to_change_type("M") == FileChangeType.MODIFIED
|
||||||
|
assert wrapper._diff_to_change_type("D") == FileChangeType.DELETED
|
||||||
|
assert wrapper._diff_to_change_type("R") == FileChangeType.RENAMED
|
||||||
|
|
||||||
|
|
||||||
|
class TestGitWrapperStage:
|
||||||
|
"""Tests for staging operations."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_stage_specific_files(self, git_wrapper_with_repo, git_repo):
|
||||||
|
"""Test staging specific files."""
|
||||||
|
# Create files
|
||||||
|
file1 = Path(git_repo.working_dir) / "file1.txt"
|
||||||
|
file2 = Path(git_repo.working_dir) / "file2.txt"
|
||||||
|
file1.write_text("content 1")
|
||||||
|
file2.write_text("content 2")
|
||||||
|
|
||||||
|
count = await git_wrapper_with_repo.stage(["file1.txt"])
|
||||||
|
|
||||||
|
assert count == 1
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_stage_all(self, git_wrapper_with_repo, git_repo):
|
||||||
|
"""Test staging all files."""
|
||||||
|
file1 = Path(git_repo.working_dir) / "all1.txt"
|
||||||
|
file2 = Path(git_repo.working_dir) / "all2.txt"
|
||||||
|
file1.write_text("content 1")
|
||||||
|
file2.write_text("content 2")
|
||||||
|
|
||||||
|
count = await git_wrapper_with_repo.stage()
|
||||||
|
|
||||||
|
assert count >= 2
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_unstage_files(self, git_wrapper_with_repo, git_repo):
|
||||||
|
"""Test unstaging files."""
|
||||||
|
# Create and stage file
|
||||||
|
file1 = Path(git_repo.working_dir) / "unstage.txt"
|
||||||
|
file1.write_text("to unstage")
|
||||||
|
git_repo.index.add(["unstage.txt"])
|
||||||
|
|
||||||
|
count = await git_wrapper_with_repo.unstage()
|
||||||
|
|
||||||
|
assert count >= 1
|
||||||
|
|
||||||
|
|
||||||
|
class TestGitWrapperReset:
|
||||||
|
"""Tests for reset operations."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_reset_soft(self, git_wrapper_with_repo, git_repo):
|
||||||
|
"""Test soft reset."""
|
||||||
|
# Create a commit to reset
|
||||||
|
file1 = Path(git_repo.working_dir) / "reset_soft.txt"
|
||||||
|
file1.write_text("content")
|
||||||
|
git_repo.index.add(["reset_soft.txt"])
|
||||||
|
git_repo.index.commit("Commit to reset")
|
||||||
|
|
||||||
|
result = await git_wrapper_with_repo.reset("HEAD~1", mode="soft")
|
||||||
|
|
||||||
|
assert result is True
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_reset_mixed(self, git_wrapper_with_repo, git_repo):
|
||||||
|
"""Test mixed reset (default)."""
|
||||||
|
file1 = Path(git_repo.working_dir) / "reset_mixed.txt"
|
||||||
|
file1.write_text("content")
|
||||||
|
git_repo.index.add(["reset_mixed.txt"])
|
||||||
|
git_repo.index.commit("Commit to reset")
|
||||||
|
|
||||||
|
result = await git_wrapper_with_repo.reset("HEAD~1", mode="mixed")
|
||||||
|
|
||||||
|
assert result is True
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_reset_invalid_mode(self, git_wrapper_with_repo):
|
||||||
|
"""Test error on invalid reset mode."""
|
||||||
|
with pytest.raises(GitError, match="Invalid reset mode"):
|
||||||
|
await git_wrapper_with_repo.reset("HEAD", mode="invalid")
|
||||||
|
|
||||||
|
|
||||||
|
class TestGitWrapperStash:
|
||||||
|
"""Tests for stash operations."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_stash_changes(self, git_wrapper_with_repo, git_repo):
|
||||||
|
"""Test stashing changes."""
|
||||||
|
# Make changes
|
||||||
|
readme = Path(git_repo.working_dir) / "README.md"
|
||||||
|
readme.write_text("Modified for stash")
|
||||||
|
|
||||||
|
result = await git_wrapper_with_repo.stash("Test stash")
|
||||||
|
|
||||||
|
# Result should be stash ref or None if nothing to stash
|
||||||
|
# (depends on whether changes were already staged)
|
||||||
|
assert result is None or result.startswith("stash@")
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_stash_nothing(self, git_wrapper_with_repo):
|
||||||
|
"""Test stash with no changes."""
|
||||||
|
result = await git_wrapper_with_repo.stash()
|
||||||
|
|
||||||
|
assert result is None
|
||||||
583
mcp-servers/git-ops/tests/test_github_provider.py
Normal file
583
mcp-servers/git-ops/tests/test_github_provider.py
Normal file
@@ -0,0 +1,583 @@
|
|||||||
|
"""
|
||||||
|
Tests for GitHub provider implementation.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from exceptions import APIError, AuthenticationError
|
||||||
|
from models import MergeStrategy, PRState
|
||||||
|
from providers.github import GitHubProvider
|
||||||
|
|
||||||
|
|
||||||
|
class TestGitHubProviderBasics:
|
||||||
|
"""Tests for GitHubProvider basic operations."""
|
||||||
|
|
||||||
|
def test_provider_name(self):
|
||||||
|
"""Test provider name is github."""
|
||||||
|
provider = GitHubProvider(token="test-token")
|
||||||
|
assert provider.name == "github"
|
||||||
|
|
||||||
|
def test_parse_repo_url_https(self):
|
||||||
|
"""Test parsing HTTPS repo URL."""
|
||||||
|
provider = GitHubProvider(token="test-token")
|
||||||
|
|
||||||
|
owner, repo = provider.parse_repo_url("https://github.com/owner/repo.git")
|
||||||
|
|
||||||
|
assert owner == "owner"
|
||||||
|
assert repo == "repo"
|
||||||
|
|
||||||
|
def test_parse_repo_url_https_no_git(self):
|
||||||
|
"""Test parsing HTTPS URL without .git suffix."""
|
||||||
|
provider = GitHubProvider(token="test-token")
|
||||||
|
|
||||||
|
owner, repo = provider.parse_repo_url("https://github.com/owner/repo")
|
||||||
|
|
||||||
|
assert owner == "owner"
|
||||||
|
assert repo == "repo"
|
||||||
|
|
||||||
|
def test_parse_repo_url_ssh(self):
|
||||||
|
"""Test parsing SSH repo URL."""
|
||||||
|
provider = GitHubProvider(token="test-token")
|
||||||
|
|
||||||
|
owner, repo = provider.parse_repo_url("git@github.com:owner/repo.git")
|
||||||
|
|
||||||
|
assert owner == "owner"
|
||||||
|
assert repo == "repo"
|
||||||
|
|
||||||
|
def test_parse_repo_url_invalid(self):
|
||||||
|
"""Test error on invalid URL."""
|
||||||
|
provider = GitHubProvider(token="test-token")
|
||||||
|
|
||||||
|
with pytest.raises(ValueError, match="Unable to parse"):
|
||||||
|
provider.parse_repo_url("invalid-url")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_github_httpx_client():
|
||||||
|
"""Create a mock httpx client for GitHub provider tests."""
|
||||||
|
from unittest.mock import AsyncMock
|
||||||
|
|
||||||
|
mock_response = AsyncMock()
|
||||||
|
mock_response.status_code = 200
|
||||||
|
mock_response.json = MagicMock(return_value={})
|
||||||
|
mock_response.text = ""
|
||||||
|
|
||||||
|
mock_client = AsyncMock()
|
||||||
|
mock_client.request = AsyncMock(return_value=mock_response)
|
||||||
|
mock_client.get = AsyncMock(return_value=mock_response)
|
||||||
|
mock_client.post = AsyncMock(return_value=mock_response)
|
||||||
|
mock_client.patch = AsyncMock(return_value=mock_response)
|
||||||
|
mock_client.put = AsyncMock(return_value=mock_response)
|
||||||
|
mock_client.delete = AsyncMock(return_value=mock_response)
|
||||||
|
|
||||||
|
return mock_client
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
async def github_provider(test_settings, mock_github_httpx_client):
|
||||||
|
"""Create a GitHubProvider with mocked HTTP client."""
|
||||||
|
provider = GitHubProvider(
|
||||||
|
token=test_settings.github_token,
|
||||||
|
settings=test_settings,
|
||||||
|
)
|
||||||
|
provider._client = mock_github_httpx_client
|
||||||
|
|
||||||
|
yield provider
|
||||||
|
|
||||||
|
await provider.close()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def github_pr_data():
|
||||||
|
"""Sample PR data from GitHub API."""
|
||||||
|
return {
|
||||||
|
"number": 42,
|
||||||
|
"title": "Test PR",
|
||||||
|
"body": "This is a test pull request",
|
||||||
|
"state": "open",
|
||||||
|
"head": {"ref": "feature-branch"},
|
||||||
|
"base": {"ref": "main"},
|
||||||
|
"user": {"login": "test-user"},
|
||||||
|
"created_at": "2024-01-15T10:00:00Z",
|
||||||
|
"updated_at": "2024-01-15T12:00:00Z",
|
||||||
|
"merged_at": None,
|
||||||
|
"closed_at": None,
|
||||||
|
"html_url": "https://github.com/owner/repo/pull/42",
|
||||||
|
"labels": [{"name": "enhancement"}],
|
||||||
|
"assignees": [{"login": "assignee1"}],
|
||||||
|
"requested_reviewers": [{"login": "reviewer1"}],
|
||||||
|
"mergeable": True,
|
||||||
|
"draft": False,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class TestGitHubProviderConnection:
    """Connection and authentication checks for the GitHub provider."""

    @pytest.mark.asyncio
    async def test_is_connected(self, github_provider, mock_github_httpx_client):
        """is_connected() reports True when /user resolves to a login."""
        response = mock_github_httpx_client.request.return_value
        response.json = MagicMock(return_value={"login": "test-user"})

        assert await github_provider.is_connected() is True

    @pytest.mark.asyncio
    async def test_is_connected_no_token(self, test_settings):
        """A provider constructed with an empty token reports disconnected."""
        provider = GitHubProvider(token="", settings=test_settings)

        assert await provider.is_connected() is False

        await provider.close()

    @pytest.mark.asyncio
    async def test_get_authenticated_user(self, github_provider, mock_github_httpx_client):
        """get_authenticated_user() returns the login from the /user payload."""
        response = mock_github_httpx_client.request.return_value
        response.json = MagicMock(return_value={"login": "test-user"})

        assert await github_provider.get_authenticated_user() == "test-user"
class TestGitHubProviderRepoOperations:
    """Repository-level queries against the GitHub provider."""

    @pytest.mark.asyncio
    async def test_get_repo_info(self, github_provider, mock_github_httpx_client):
        """get_repo_info() passes the repository payload through unchanged."""
        repo_payload = {
            "name": "repo",
            "full_name": "owner/repo",
            "default_branch": "main",
        }
        response = mock_github_httpx_client.request.return_value
        response.json = MagicMock(return_value=repo_payload)

        info = await github_provider.get_repo_info("owner", "repo")

        assert info["name"] == "repo"
        assert info["default_branch"] == "main"

    @pytest.mark.asyncio
    async def test_get_default_branch(self, github_provider, mock_github_httpx_client):
        """get_default_branch() extracts default_branch from the repo payload."""
        response = mock_github_httpx_client.request.return_value
        response.json = MagicMock(return_value={"default_branch": "develop"})

        assert await github_provider.get_default_branch("owner", "repo") == "develop"
class TestGitHubPROperations:
    """Pull-request lifecycle operations on the GitHub provider."""

    @pytest.mark.asyncio
    async def test_create_pr(self, github_provider, mock_github_httpx_client):
        """Creating a PR surfaces the number and URL from the API response."""
        response = mock_github_httpx_client.request.return_value
        response.json = MagicMock(
            return_value={
                "number": 42,
                "html_url": "https://github.com/owner/repo/pull/42",
            }
        )

        result = await github_provider.create_pr(
            owner="owner",
            repo="repo",
            title="Test PR",
            body="Test body",
            source_branch="feature",
            target_branch="main",
        )

        assert result.success is True
        assert result.pr_number == 42
        assert result.pr_url == "https://github.com/owner/repo/pull/42"

    @pytest.mark.asyncio
    async def test_create_pr_with_draft(self, github_provider, mock_github_httpx_client):
        """The draft flag is accepted and the PR is still created."""
        response = mock_github_httpx_client.request.return_value
        response.json = MagicMock(
            return_value={
                "number": 43,
                "html_url": "https://github.com/owner/repo/pull/43",
            }
        )

        result = await github_provider.create_pr(
            owner="owner",
            repo="repo",
            title="Draft PR",
            body="Draft body",
            source_branch="feature",
            target_branch="main",
            draft=True,
        )

        assert result.success is True
        assert result.pr_number == 43

    @pytest.mark.asyncio
    async def test_create_pr_with_options(self, github_provider, mock_github_httpx_client):
        """Labels, assignees, and reviewers trigger the follow-up API calls."""
        # Call order: create PR, add labels, add assignees, request reviewers.
        responses = [
            {"number": 44, "html_url": "https://github.com/owner/repo/pull/44"},
            [{"name": "enhancement"}],
            {},
            {},
        ]
        mock_github_httpx_client.request.return_value.json = MagicMock(side_effect=responses)

        result = await github_provider.create_pr(
            owner="owner",
            repo="repo",
            title="Test PR",
            body="Test body",
            source_branch="feature",
            target_branch="main",
            labels=["enhancement"],
            assignees=["user1"],
            reviewers=["reviewer1"],
        )

        assert result.success is True

    @pytest.mark.asyncio
    async def test_get_pr(self, github_provider, mock_github_httpx_client, github_pr_data):
        """get_pr() returns the PR payload on success."""
        response = mock_github_httpx_client.request.return_value
        response.json = MagicMock(return_value=github_pr_data)

        result = await github_provider.get_pr("owner", "repo", 42)

        assert result.success is True
        assert result.pr["number"] == 42
        assert result.pr["title"] == "Test PR"

    @pytest.mark.asyncio
    async def test_get_pr_not_found(self, github_provider, mock_github_httpx_client):
        """A 404 from the API yields an unsuccessful result."""
        response = mock_github_httpx_client.request.return_value
        response.status_code = 404
        response.json = MagicMock(return_value=None)

        result = await github_provider.get_pr("owner", "repo", 999)

        assert result.success is False

    @pytest.mark.asyncio
    async def test_list_prs(self, github_provider, mock_github_httpx_client, github_pr_data):
        """list_prs() returns every PR in the API response."""
        response = mock_github_httpx_client.request.return_value
        response.json = MagicMock(return_value=[github_pr_data, github_pr_data])

        result = await github_provider.list_prs("owner", "repo")

        assert result.success is True
        assert len(result.pull_requests) == 2

    @pytest.mark.asyncio
    async def test_list_prs_with_state_filter(self, github_provider, mock_github_httpx_client, github_pr_data):
        """list_prs() accepts a state filter."""
        response = mock_github_httpx_client.request.return_value
        response.json = MagicMock(return_value=[github_pr_data])

        result = await github_provider.list_prs("owner", "repo", state=PRState.OPEN)

        assert result.success is True

    @pytest.mark.asyncio
    async def test_merge_pr(self, github_provider, mock_github_httpx_client, github_pr_data):
        """Merging surfaces the merge commit SHA."""
        # Call order: PUT merge, GET PR (for branch info), DELETE branch.
        responses = [
            {"sha": "merge-commit-sha", "merged": True},
            github_pr_data,
            None,
        ]
        mock_github_httpx_client.request.return_value.json = MagicMock(side_effect=responses)

        result = await github_provider.merge_pr(
            "owner", "repo", 42, merge_strategy=MergeStrategy.SQUASH
        )

        assert result.success is True
        assert result.merge_commit_sha == "merge-commit-sha"

    @pytest.mark.asyncio
    async def test_merge_pr_rebase(self, github_provider, mock_github_httpx_client, github_pr_data):
        """The rebase merge strategy completes successfully."""
        responses = [
            {"sha": "rebase-commit-sha", "merged": True},  # PUT merge
            github_pr_data,                                # GET PR for branch info
            None,                                          # DELETE branch
        ]
        mock_github_httpx_client.request.return_value.json = MagicMock(side_effect=responses)

        result = await github_provider.merge_pr(
            "owner", "repo", 42, merge_strategy=MergeStrategy.REBASE
        )

        assert result.success is True

    @pytest.mark.asyncio
    async def test_update_pr(self, github_provider, mock_github_httpx_client, github_pr_data):
        """update_pr() with a new title/body succeeds."""
        response = mock_github_httpx_client.request.return_value
        response.json = MagicMock(return_value=github_pr_data)

        result = await github_provider.update_pr(
            "owner", "repo", 42, title="Updated Title", body="Updated body"
        )

        assert result.success is True

    @pytest.mark.asyncio
    async def test_close_pr(self, github_provider, mock_github_httpx_client, github_pr_data):
        """close_pr() succeeds when the API reports the PR as closed."""
        github_pr_data["state"] = "closed"
        response = mock_github_httpx_client.request.return_value
        response.json = MagicMock(return_value=github_pr_data)

        result = await github_provider.close_pr("owner", "repo", 42)

        assert result.success is True
class TestGitHubBranchOperations:
    """Branch queries and deletion on the GitHub provider."""

    @pytest.mark.asyncio
    async def test_get_branch(self, github_provider, mock_github_httpx_client):
        """get_branch() passes the branch payload through."""
        branch_payload = {"name": "main", "commit": {"sha": "abc123"}}
        response = mock_github_httpx_client.request.return_value
        response.json = MagicMock(return_value=branch_payload)

        branch = await github_provider.get_branch("owner", "repo", "main")

        assert branch["name"] == "main"

    @pytest.mark.asyncio
    async def test_delete_remote_branch(self, github_provider, mock_github_httpx_client):
        """A 204 from the API means the branch was deleted."""
        mock_github_httpx_client.request.return_value.status_code = 204

        assert await github_provider.delete_remote_branch("owner", "repo", "old-branch") is True
class TestGitHubCommentOperations:
    """PR comment creation and listing on the GitHub provider."""

    @pytest.mark.asyncio
    async def test_add_pr_comment(self, github_provider, mock_github_httpx_client):
        """add_pr_comment() returns the created comment payload."""
        response = mock_github_httpx_client.request.return_value
        response.json = MagicMock(return_value={"id": 1, "body": "Test comment"})

        comment = await github_provider.add_pr_comment("owner", "repo", 42, "Test comment")

        assert comment["body"] == "Test comment"

    @pytest.mark.asyncio
    async def test_list_pr_comments(self, github_provider, mock_github_httpx_client):
        """list_pr_comments() returns every comment in the response."""
        comments_payload = [
            {"id": 1, "body": "Comment 1"},
            {"id": 2, "body": "Comment 2"},
        ]
        response = mock_github_httpx_client.request.return_value
        response.json = MagicMock(return_value=comments_payload)

        comments = await github_provider.list_pr_comments("owner", "repo", 42)

        assert len(comments) == 2
class TestGitHubLabelOperations:
    """Label management on GitHub pull requests."""

    @pytest.mark.asyncio
    async def test_add_labels(self, github_provider, mock_github_httpx_client):
        """add_labels() returns the names applied to the PR."""
        response = mock_github_httpx_client.request.return_value
        response.json = MagicMock(return_value=[{"name": "bug"}, {"name": "urgent"}])

        labels = await github_provider.add_labels("owner", "repo", 42, ["bug", "urgent"])

        assert "bug" in labels
        assert "urgent" in labels

    @pytest.mark.asyncio
    async def test_remove_label(self, github_provider, mock_github_httpx_client):
        """remove_label() returns the remaining labels as a list."""
        # Call order: DELETE label, then GET issue for the remaining labels.
        responses = [
            None,
            {"labels": []},
        ]
        mock_github_httpx_client.request.return_value.json = MagicMock(side_effect=responses)

        remaining = await github_provider.remove_label("owner", "repo", 42, "bug")

        assert isinstance(remaining, list)
class TestGitHubReviewerOperations:
    """Review-request operations on the GitHub provider."""

    @pytest.mark.asyncio
    async def test_request_review(self, github_provider, mock_github_httpx_client):
        """request_review() echoes back the requested reviewer logins."""
        mock_github_httpx_client.request.return_value.json = MagicMock(return_value={})

        reviewers = await github_provider.request_review(
            "owner", "repo", 42, ["reviewer1", "reviewer2"]
        )

        assert reviewers == ["reviewer1", "reviewer2"]
class TestGitHubErrorHandling:
    """HTTP error-status translation in the GitHub provider's _request()."""

    @pytest.mark.asyncio
    async def test_authentication_error(self, github_provider, mock_github_httpx_client):
        """A 401 response raises AuthenticationError."""
        mock_github_httpx_client.request.return_value.status_code = 401

        with pytest.raises(AuthenticationError):
            await github_provider._request("GET", "/user")

    @pytest.mark.asyncio
    async def test_permission_denied(self, github_provider, mock_github_httpx_client):
        """A 403 without rate-limit wording raises a permissions error."""
        response = mock_github_httpx_client.request.return_value
        response.status_code = 403
        response.text = "Permission denied"

        with pytest.raises(AuthenticationError, match="Insufficient permissions"):
            await github_provider._request("GET", "/protected")

    @pytest.mark.asyncio
    async def test_rate_limit_error(self, github_provider, mock_github_httpx_client):
        """A 403 whose body mentions the rate limit raises APIError instead."""
        response = mock_github_httpx_client.request.return_value
        response.status_code = 403
        response.text = "API rate limit exceeded"

        with pytest.raises(APIError, match="rate limit"):
            await github_provider._request("GET", "/user")

    @pytest.mark.asyncio
    async def test_api_error(self, github_provider, mock_github_httpx_client):
        """Server errors (5xx) raise a generic APIError."""
        response = mock_github_httpx_client.request.return_value
        response.status_code = 500
        response.text = "Internal Server Error"
        response.json = MagicMock(return_value={"message": "Server error"})

        with pytest.raises(APIError):
            await github_provider._request("GET", "/error")
class TestGitHubPRParsing:
    """Conversion of raw GitHub API payloads into PR info objects."""

    def test_parse_pr_open(self, github_provider, github_pr_data):
        """An open payload maps number, state, title, and branch refs."""
        pr_info = github_provider._parse_pr(github_pr_data)

        assert pr_info.number == 42
        assert pr_info.state == PRState.OPEN
        assert pr_info.title == "Test PR"
        assert pr_info.source_branch == "feature-branch"
        assert pr_info.target_branch == "main"

    def test_parse_pr_merged(self, github_provider, github_pr_data):
        """A non-null merged_at timestamp yields the MERGED state."""
        github_pr_data["merged_at"] = "2024-01-16T10:00:00Z"

        assert github_provider._parse_pr(github_pr_data).state == PRState.MERGED

    def test_parse_pr_closed(self, github_provider, github_pr_data):
        """A closed, unmerged payload yields the CLOSED state."""
        github_pr_data["state"] = "closed"
        github_pr_data["closed_at"] = "2024-01-16T10:00:00Z"

        assert github_provider._parse_pr(github_pr_data).state == PRState.CLOSED

    def test_parse_pr_draft(self, github_provider, github_pr_data):
        """The draft flag is carried through from the payload."""
        github_pr_data["draft"] = True

        assert github_provider._parse_pr(github_pr_data).draft is True

    def test_parse_datetime_iso(self, github_provider):
        """ISO-8601 'Z' timestamps parse into the expected date fields."""
        dt = github_provider._parse_datetime("2024-01-15T10:30:00Z")

        assert (dt.year, dt.month, dt.day) == (2024, 1, 15)

    def test_parse_datetime_none(self, github_provider):
        """None falls back to a timezone-aware current timestamp."""
        dt = github_provider._parse_datetime(None)

        assert dt is not None
        assert dt.tzinfo is not None

    def test_parse_pr_with_null_body(self, github_provider, github_pr_data):
        """A null body is normalized to the empty string."""
        github_pr_data["body"] = None

        assert github_provider._parse_pr(github_pr_data).body == ""
484
mcp-servers/git-ops/tests/test_providers.py
Normal file
484
mcp-servers/git-ops/tests/test_providers.py
Normal file
@@ -0,0 +1,484 @@
|
|||||||
|
"""
|
||||||
|
Tests for git provider implementations.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from exceptions import APIError, AuthenticationError
|
||||||
|
from models import MergeStrategy, PRState
|
||||||
|
from providers.gitea import GiteaProvider
|
||||||
|
|
||||||
|
|
||||||
|
class TestBaseProvider:
    """Tests for the repo-URL parsing inherited from the BaseProvider interface.

    GiteaProvider is used as a concrete implementation; no HTTP requests are
    made, since parse_repo_url works purely on the URL string.
    """

    @staticmethod
    def _make_provider():
        """Build a throwaway GiteaProvider for URL parsing (no requests made)."""
        return GiteaProvider(base_url="https://gitea.test.com", token="test-token")

    def test_parse_repo_url_https(self):
        """An HTTPS clone URL with a .git suffix parses into (owner, repo).

        Note: the original version requested the mock_gitea_provider fixture
        but never used it (it built its own provider), so the unused fixture
        parameter and its stale comment have been removed.
        """
        owner, repo = self._make_provider().parse_repo_url(
            "https://gitea.test.com/owner/repo.git"
        )

        assert owner == "owner"
        assert repo == "repo"

    def test_parse_repo_url_https_no_git(self):
        """An HTTPS URL without the .git suffix also parses correctly."""
        owner, repo = self._make_provider().parse_repo_url(
            "https://gitea.test.com/owner/repo"
        )

        assert owner == "owner"
        assert repo == "repo"

    def test_parse_repo_url_ssh(self):
        """An SSH-style clone URL (git@host:owner/repo.git) parses correctly."""
        owner, repo = self._make_provider().parse_repo_url(
            "git@gitea.test.com:owner/repo.git"
        )

        assert owner == "owner"
        assert repo == "repo"

    def test_parse_repo_url_invalid(self):
        """A string that is not a recognizable repo URL raises ValueError."""
        with pytest.raises(ValueError, match="Unable to parse"):
            self._make_provider().parse_repo_url("invalid-url")
class TestGiteaProvider:
    """Connection and repository queries against the Gitea provider."""

    @pytest.mark.asyncio
    async def test_is_connected(self, gitea_provider, mock_httpx_client):
        """is_connected() reports True when /user resolves to a login."""
        response = mock_httpx_client.request.return_value
        response.json = MagicMock(return_value={"login": "test-user"})

        assert await gitea_provider.is_connected() is True

    @pytest.mark.asyncio
    async def test_is_connected_no_token(self, test_settings):
        """A provider constructed with an empty token reports disconnected."""
        provider = GiteaProvider(
            base_url="https://gitea.test.com",
            token="",
            settings=test_settings,
        )

        assert await provider.is_connected() is False

        await provider.close()

    @pytest.mark.asyncio
    async def test_get_authenticated_user(self, gitea_provider, mock_httpx_client):
        """get_authenticated_user() returns the login from the /user payload."""
        response = mock_httpx_client.request.return_value
        response.json = MagicMock(return_value={"login": "test-user"})

        assert await gitea_provider.get_authenticated_user() == "test-user"

    @pytest.mark.asyncio
    async def test_get_repo_info(self, gitea_provider, mock_httpx_client):
        """get_repo_info() passes the repository payload through unchanged."""
        repo_payload = {
            "name": "repo",
            "full_name": "owner/repo",
            "default_branch": "main",
        }
        response = mock_httpx_client.request.return_value
        response.json = MagicMock(return_value=repo_payload)

        info = await gitea_provider.get_repo_info("owner", "repo")

        assert info["name"] == "repo"
        assert info["default_branch"] == "main"

    @pytest.mark.asyncio
    async def test_get_default_branch(self, gitea_provider, mock_httpx_client):
        """get_default_branch() extracts default_branch from the repo payload."""
        response = mock_httpx_client.request.return_value
        response.json = MagicMock(return_value={"default_branch": "develop"})

        assert await gitea_provider.get_default_branch("owner", "repo") == "develop"
class TestGiteaPROperations:
    """Pull-request lifecycle operations on the Gitea provider."""

    @pytest.mark.asyncio
    async def test_create_pr(self, gitea_provider, mock_httpx_client):
        """Creating a PR surfaces the number and URL from the API response."""
        response = mock_httpx_client.request.return_value
        response.json = MagicMock(
            return_value={
                "number": 42,
                "html_url": "https://gitea.test.com/owner/repo/pull/42",
            }
        )

        result = await gitea_provider.create_pr(
            owner="owner",
            repo="repo",
            title="Test PR",
            body="Test body",
            source_branch="feature",
            target_branch="main",
        )

        assert result.success is True
        assert result.pr_number == 42
        assert result.pr_url == "https://gitea.test.com/owner/repo/pull/42"

    @pytest.mark.asyncio
    async def test_create_pr_with_options(self, gitea_provider, mock_httpx_client):
        """Labels, assignees, and reviewers trigger the follow-up API calls."""
        # Call order expected by the provider:
        #   1. POST create PR
        #   2. GET labels ("enhancement" lookup in add_labels -> _get_or_create_label)
        #   3. POST add labels to PR
        #   4. GET issue to return current labels
        #   5. PATCH add assignees
        #   6. POST request reviewers
        responses = [
            {"number": 43, "html_url": "https://gitea.test.com/owner/repo/pull/43"},
            [{"id": 1, "name": "enhancement"}],
            {},
            {"labels": [{"name": "enhancement"}]},
            {},
            {},
        ]
        mock_httpx_client.request.return_value.json = MagicMock(side_effect=responses)

        result = await gitea_provider.create_pr(
            owner="owner",
            repo="repo",
            title="Test PR",
            body="Test body",
            source_branch="feature",
            target_branch="main",
            labels=["enhancement"],
            assignees=["user1"],
            reviewers=["reviewer1"],
        )

        assert result.success is True

    @pytest.mark.asyncio
    async def test_get_pr(self, gitea_provider, mock_httpx_client, sample_pr_data):
        """get_pr() returns the PR payload on success."""
        response = mock_httpx_client.request.return_value
        response.json = MagicMock(return_value=sample_pr_data)

        result = await gitea_provider.get_pr("owner", "repo", 42)

        assert result.success is True
        assert result.pr["number"] == 42
        assert result.pr["title"] == "Test PR"

    @pytest.mark.asyncio
    async def test_get_pr_not_found(self, gitea_provider, mock_httpx_client):
        """A 404 from the API yields an unsuccessful result."""
        response = mock_httpx_client.request.return_value
        response.status_code = 404
        response.json = MagicMock(return_value=None)

        result = await gitea_provider.get_pr("owner", "repo", 999)

        assert result.success is False

    @pytest.mark.asyncio
    async def test_list_prs(self, gitea_provider, mock_httpx_client, sample_pr_data):
        """list_prs() returns every PR in the API response."""
        response = mock_httpx_client.request.return_value
        response.json = MagicMock(return_value=[sample_pr_data, sample_pr_data])

        result = await gitea_provider.list_prs("owner", "repo")

        assert result.success is True
        assert len(result.pull_requests) == 2

    @pytest.mark.asyncio
    async def test_list_prs_with_state_filter(self, gitea_provider, mock_httpx_client, sample_pr_data):
        """list_prs() accepts a state filter."""
        response = mock_httpx_client.request.return_value
        response.json = MagicMock(return_value=[sample_pr_data])

        result = await gitea_provider.list_prs("owner", "repo", state=PRState.OPEN)

        assert result.success is True

    @pytest.mark.asyncio
    async def test_merge_pr(self, gitea_provider, mock_httpx_client):
        """Merging surfaces the merge commit SHA."""
        response = mock_httpx_client.request.return_value
        response.json = MagicMock(return_value={"sha": "merge-commit-sha"})

        result = await gitea_provider.merge_pr(
            "owner", "repo", 42, merge_strategy=MergeStrategy.SQUASH
        )

        assert result.success is True
        assert result.merge_commit_sha == "merge-commit-sha"

    @pytest.mark.asyncio
    async def test_update_pr(self, gitea_provider, mock_httpx_client, sample_pr_data):
        """update_pr() with a new title/body succeeds."""
        response = mock_httpx_client.request.return_value
        response.json = MagicMock(return_value=sample_pr_data)

        result = await gitea_provider.update_pr(
            "owner", "repo", 42, title="Updated Title", body="Updated body"
        )

        assert result.success is True

    @pytest.mark.asyncio
    async def test_close_pr(self, gitea_provider, mock_httpx_client, sample_pr_data):
        """close_pr() succeeds when the API reports the PR as closed."""
        sample_pr_data["state"] = "closed"
        response = mock_httpx_client.request.return_value
        response.json = MagicMock(return_value=sample_pr_data)

        result = await gitea_provider.close_pr("owner", "repo", 42)

        assert result.success is True
class TestGiteaBranchOperations:
    """Branch queries and deletion on the Gitea provider."""

    @pytest.mark.asyncio
    async def test_get_branch(self, gitea_provider, mock_httpx_client):
        """get_branch() passes the branch payload through."""
        branch_payload = {"name": "main", "commit": {"sha": "abc123"}}
        response = mock_httpx_client.request.return_value
        response.json = MagicMock(return_value=branch_payload)

        branch = await gitea_provider.get_branch("owner", "repo", "main")

        assert branch["name"] == "main"

    @pytest.mark.asyncio
    async def test_delete_remote_branch(self, gitea_provider, mock_httpx_client):
        """A 204 from the API means the branch was deleted."""
        mock_httpx_client.request.return_value.status_code = 204

        assert await gitea_provider.delete_remote_branch("owner", "repo", "old-branch") is True
class TestGiteaCommentOperations:
    """PR comment creation and listing on the Gitea provider."""

    @pytest.mark.asyncio
    async def test_add_pr_comment(self, gitea_provider, mock_httpx_client):
        """add_pr_comment() returns the created comment payload."""
        response = mock_httpx_client.request.return_value
        response.json = MagicMock(return_value={"id": 1, "body": "Test comment"})

        comment = await gitea_provider.add_pr_comment("owner", "repo", 42, "Test comment")

        assert comment["body"] == "Test comment"

    @pytest.mark.asyncio
    async def test_list_pr_comments(self, gitea_provider, mock_httpx_client):
        """list_pr_comments() returns every comment in the response."""
        comments_payload = [
            {"id": 1, "body": "Comment 1"},
            {"id": 2, "body": "Comment 2"},
        ]
        response = mock_httpx_client.request.return_value
        response.json = MagicMock(return_value=comments_payload)

        comments = await gitea_provider.list_pr_comments("owner", "repo", 42)

        assert len(comments) == 2
class TestGiteaLabelOperations:
    """Label management on Gitea pull requests."""

    @pytest.mark.asyncio
    async def test_add_labels(self, gitea_provider, mock_httpx_client):
        """add_labels() creates missing labels and returns the final list."""
        # Call order expected by the provider:
        #   1. GET labels ("bug" lookup — not found)
        #   2. POST create "bug"
        #   3. GET labels ("urgent" lookup — not found)
        #   4. POST create "urgent"
        #   5. POST add labels to PR
        #   6. GET issue to return the final label set
        responses = [
            [{"id": 1, "name": "existing"}],
            {"id": 2, "name": "bug"},
            [{"id": 1, "name": "existing"}, {"id": 2, "name": "bug"}],
            {"id": 3, "name": "urgent"},
            {},
            {"labels": [{"name": "bug"}, {"name": "urgent"}]},
        ]
        mock_httpx_client.request.return_value.json = MagicMock(side_effect=responses)

        labels = await gitea_provider.add_labels("owner", "repo", 42, ["bug", "urgent"])

        # The provider returns the updated label list.
        assert isinstance(labels, list)

    @pytest.mark.asyncio
    async def test_remove_label(self, gitea_provider, mock_httpx_client):
        """remove_label() resolves the label ID, deletes it, and lists the rest."""
        # Call order: GET labels (find ID), DELETE label, GET issue (remaining).
        responses = [
            [{"id": 1, "name": "bug"}],
            {},
            {"labels": []},
        ]
        mock_httpx_client.request.return_value.json = MagicMock(side_effect=responses)

        remaining = await gitea_provider.remove_label("owner", "repo", 42, "bug")

        assert isinstance(remaining, list)
class TestGiteaReviewerOperations:
    """Tests for Gitea reviewer operations."""

    @pytest.mark.asyncio
    async def test_request_review(self, gitea_provider, mock_httpx_client):
        """Requesting reviews echoes back the requested reviewer logins."""
        mock_httpx_client.request.return_value.json = MagicMock(return_value={})

        reviewers = await gitea_provider.request_review(
            "owner", "repo", 42, ["reviewer1", "reviewer2"]
        )

        assert reviewers == ["reviewer1", "reviewer2"]
|
||||||
|
|
||||||
|
|
||||||
|
class TestGiteaErrorHandling:
    """Tests for error handling in Gitea provider."""

    @pytest.mark.asyncio
    async def test_authentication_error(self, gitea_provider, mock_httpx_client):
        """A 401 response surfaces as AuthenticationError."""
        mock_httpx_client.request.return_value.status_code = 401

        with pytest.raises(AuthenticationError):
            await gitea_provider._request("GET", "/user")

    @pytest.mark.asyncio
    async def test_permission_denied(self, gitea_provider, mock_httpx_client):
        """A 403 response surfaces as AuthenticationError mentioning permissions."""
        mock_httpx_client.request.return_value.status_code = 403

        with pytest.raises(AuthenticationError, match="Insufficient permissions"):
            await gitea_provider._request("GET", "/protected")

    @pytest.mark.asyncio
    async def test_api_error(self, gitea_provider, mock_httpx_client):
        """A 5xx response surfaces as a generic APIError."""
        response = mock_httpx_client.request.return_value
        response.status_code = 500
        response.text = "Internal Server Error"
        response.json = MagicMock(return_value={"message": "Server error"})

        with pytest.raises(APIError):
            await gitea_provider._request("GET", "/error")
|
||||||
|
|
||||||
|
|
||||||
|
class TestGiteaPRParsing:
    """Tests for PR data parsing."""

    def test_parse_pr_open(self, gitea_provider, sample_pr_data):
        """An open PR payload maps onto the parsed PR's fields."""
        parsed = gitea_provider._parse_pr(sample_pr_data)

        assert parsed.number == 42
        assert parsed.state == PRState.OPEN
        assert parsed.title == "Test PR"
        assert parsed.source_branch == "feature-branch"
        assert parsed.target_branch == "main"

    def test_parse_pr_merged(self, gitea_provider, sample_pr_data):
        """A PR flagged as merged parses to the MERGED state."""
        sample_pr_data["merged"] = True
        sample_pr_data["merged_at"] = "2024-01-16T10:00:00Z"

        parsed = gitea_provider._parse_pr(sample_pr_data)

        assert parsed.state == PRState.MERGED

    def test_parse_pr_closed(self, gitea_provider, sample_pr_data):
        """A closed, unmerged PR parses to the CLOSED state."""
        sample_pr_data["state"] = "closed"
        sample_pr_data["closed_at"] = "2024-01-16T10:00:00Z"

        parsed = gitea_provider._parse_pr(sample_pr_data)

        assert parsed.state == PRState.CLOSED

    def test_parse_datetime_iso(self, gitea_provider):
        """ISO-8601 timestamps parse into their calendar components."""
        parsed = gitea_provider._parse_datetime("2024-01-15T10:30:00Z")

        assert parsed.year == 2024
        assert parsed.month == 1
        assert parsed.day == 15

    def test_parse_datetime_none(self, gitea_provider):
        """None falls back to a timezone-aware 'now'."""
        parsed = gitea_provider._parse_datetime(None)

        assert parsed is not None
        assert parsed.tzinfo is not None
|
||||||
514
mcp-servers/git-ops/tests/test_server.py
Normal file
514
mcp-servers/git-ops/tests/test_server.py
Normal file
@@ -0,0 +1,514 @@
|
|||||||
|
"""
|
||||||
|
Tests for the MCP server and tools.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from exceptions import ErrorCode
|
||||||
|
|
||||||
|
|
||||||
|
class TestInputValidation:
    """Tests for input validation functions."""

    def test_validate_id_valid(self):
        """Well-formed IDs produce no error message."""
        from server import _validate_id

        for value, field in (
            ("test-123", "project_id"),
            ("my_project", "project_id"),
            ("Agent-001", "agent_id"),
        ):
            assert _validate_id(value, field) is None

    def test_validate_id_empty(self):
        """An empty ID is rejected with a 'required' message."""
        from server import _validate_id

        message = _validate_id("", "project_id")
        assert message is not None
        assert "required" in message.lower()

    def test_validate_id_too_long(self):
        """IDs over the length cap are rejected, citing the allowed range."""
        from server import _validate_id

        message = _validate_id("a" * 200, "project_id")
        assert message is not None
        assert "1-128" in message

    def test_validate_id_invalid_chars(self):
        """IDs containing disallowed characters are rejected."""
        from server import _validate_id

        for candidate in ("test@invalid", "test!project", "test project"):
            assert _validate_id(candidate, "project_id") is not None

    def test_validate_branch_valid(self):
        """Common branch-name shapes pass validation."""
        from server import _validate_branch

        for branch in ("main", "feature/new-thing", "release-1.0.0", "hotfix.urgent"):
            assert _validate_branch(branch) is None

    def test_validate_branch_invalid(self):
        """Empty or oversized branch names are rejected."""
        from server import _validate_branch

        assert _validate_branch("") is not None
        assert _validate_branch("a" * 300) is not None

    def test_validate_url_valid(self):
        """HTTPS and SSH repository URLs pass validation."""
        from server import _validate_url

        for url in (
            "https://github.com/owner/repo.git",
            "https://gitea.example.com/owner/repo",
            "git@github.com:owner/repo.git",
        ):
            assert _validate_url(url) is None

    def test_validate_url_invalid(self):
        """Empty, malformed, or unsupported-scheme URLs are rejected."""
        from server import _validate_url

        for url in ("", "not-a-url", "ftp://invalid.com/repo"):
            assert _validate_url(url) is not None
|
||||||
|
|
||||||
|
|
||||||
|
class TestHealthCheck:
    """Tests for health check endpoint."""

    @pytest.mark.asyncio
    async def test_health_check_structure(self):
        """The health payload carries all required top-level keys."""
        from server import health_check

        with patch("server._gitea_provider", None), patch("server._workspace_manager", None):
            payload = await health_check()

        for key in ("status", "service", "version", "timestamp", "dependencies"):
            assert key in payload

    @pytest.mark.asyncio
    async def test_health_check_no_providers(self):
        """Without a Gitea provider the dependency is reported as unconfigured."""
        from server import health_check

        with patch("server._gitea_provider", None), patch("server._workspace_manager", None):
            payload = await health_check()

        assert payload["dependencies"]["gitea"] == "not configured"
|
||||||
|
|
||||||
|
|
||||||
|
class TestToolRegistry:
    """Tests for tool registration."""

    def test_tool_registry_populated(self):
        """The registry is non-empty and contains the core git tools."""
        from server import _tool_registry

        assert len(_tool_registry) > 0
        # Spot-check the essential tool names.
        for tool in ("clone_repository", "git_status", "create_branch", "commit"):
            assert tool in _tool_registry

    def test_tool_schema_structure(self):
        """Every registered tool exposes a callable and a JSON-Schema object."""
        from server import _tool_registry

        # Only the registry entries are inspected, so iterate .values()
        # instead of .items() with an unused key (ruff PERF102).
        for info in _tool_registry.values():
            assert "func" in info
            assert "description" in info
            assert "schema" in info
            assert info["schema"]["type"] == "object"
            assert "properties" in info["schema"]
|
||||||
|
|
||||||
|
|
||||||
|
class TestCloneRepository:
    """Tests for clone_repository tool."""

    @pytest.mark.asyncio
    async def test_clone_invalid_project_id(self):
        """A malformed project ID is rejected before any git work happens."""
        from server import clone_repository

        # The MCP decorator wraps the tool; the raw coroutine lives on .fn.
        outcome = await clone_repository.fn(
            project_id="invalid@id",
            agent_id="agent-1",
            repo_url="https://github.com/owner/repo.git",
        )

        assert outcome["success"] is False
        assert "project_id" in outcome["error"].lower()

    @pytest.mark.asyncio
    async def test_clone_invalid_repo_url(self):
        """A malformed repository URL is rejected."""
        from server import clone_repository

        outcome = await clone_repository.fn(
            project_id="valid-project",
            agent_id="agent-1",
            repo_url="not-a-valid-url",
        )

        assert outcome["success"] is False
        assert "url" in outcome["error"].lower()
|
||||||
|
|
||||||
|
|
||||||
|
class TestGitStatus:
    """Tests for git_status tool."""

    @pytest.mark.asyncio
    async def test_status_workspace_not_found(self):
        """Status against a missing workspace fails with WORKSPACE_NOT_FOUND."""
        from server import git_status

        manager = AsyncMock()
        manager.get_workspace = AsyncMock(return_value=None)

        with patch("server._workspace_manager", manager):
            outcome = await git_status.fn(
                project_id="nonexistent",
                agent_id="agent-1",
            )

        assert outcome["success"] is False
        assert outcome["code"] == ErrorCode.WORKSPACE_NOT_FOUND.value
|
||||||
|
|
||||||
|
|
||||||
|
class TestBranchOperations:
    """Tests for branch operation tools."""

    @pytest.mark.asyncio
    async def test_create_branch_invalid_name(self):
        """An empty branch name is rejected up front."""
        from server import create_branch

        outcome = await create_branch.fn(
            project_id="test-project",
            agent_id="agent-1",
            branch_name="",  # invalid: fails branch-name validation
        )

        assert outcome["success"] is False

    @pytest.mark.asyncio
    async def test_list_branches_workspace_not_found(self):
        """Listing branches fails when the workspace does not exist."""
        from server import list_branches

        manager = AsyncMock()
        manager.get_workspace = AsyncMock(return_value=None)

        with patch("server._workspace_manager", manager):
            outcome = await list_branches.fn(
                project_id="nonexistent",
                agent_id="agent-1",
            )

        assert outcome["success"] is False

    @pytest.mark.asyncio
    async def test_checkout_invalid_project(self):
        """Checkout rejects a malformed project ID."""
        from server import checkout

        outcome = await checkout.fn(
            project_id="inv@lid",
            agent_id="agent-1",
            ref="main",
        )

        assert outcome["success"] is False
|
||||||
|
|
||||||
|
|
||||||
|
class TestCommitOperations:
    """Tests for commit operation tools."""

    @pytest.mark.asyncio
    async def test_commit_invalid_project(self):
        """Commit rejects a malformed project ID."""
        from server import commit

        outcome = await commit.fn(
            project_id="inv@lid",
            agent_id="agent-1",
            message="Test commit",
        )

        assert outcome["success"] is False
|
||||||
|
|
||||||
|
|
||||||
|
class TestPushPullOperations:
    """Tests for push/pull operation tools."""

    @pytest.mark.asyncio
    async def test_push_workspace_not_found(self):
        """Push fails when the workspace does not exist."""
        from server import push

        manager = AsyncMock()
        manager.get_workspace = AsyncMock(return_value=None)

        with patch("server._workspace_manager", manager):
            outcome = await push.fn(
                project_id="nonexistent",
                agent_id="agent-1",
            )

        assert outcome["success"] is False

    @pytest.mark.asyncio
    async def test_pull_workspace_not_found(self):
        """Pull fails when the workspace does not exist."""
        from server import pull

        manager = AsyncMock()
        manager.get_workspace = AsyncMock(return_value=None)

        with patch("server._workspace_manager", manager):
            outcome = await pull.fn(
                project_id="nonexistent",
                agent_id="agent-1",
            )

        assert outcome["success"] is False
|
||||||
|
|
||||||
|
|
||||||
|
class TestDiffLogOperations:
    """Tests for diff and log operation tools."""

    @pytest.mark.asyncio
    async def test_diff_workspace_not_found(self):
        """Diff fails when the workspace does not exist."""
        from server import diff

        manager = AsyncMock()
        manager.get_workspace = AsyncMock(return_value=None)

        with patch("server._workspace_manager", manager):
            outcome = await diff.fn(
                project_id="nonexistent",
                agent_id="agent-1",
            )

        assert outcome["success"] is False

    @pytest.mark.asyncio
    async def test_log_workspace_not_found(self):
        """Log fails when the workspace does not exist."""
        from server import log

        manager = AsyncMock()
        manager.get_workspace = AsyncMock(return_value=None)

        with patch("server._workspace_manager", manager):
            outcome = await log.fn(
                project_id="nonexistent",
                agent_id="agent-1",
            )

        assert outcome["success"] is False
|
||||||
|
|
||||||
|
|
||||||
|
class TestPROperations:
    """Tests for pull request operation tools."""

    @pytest.mark.asyncio
    async def test_create_pr_no_repo_url(self):
        """Creating a PR fails when the workspace records no repository URL."""
        from models import WorkspaceInfo, WorkspaceState
        from server import create_pull_request

        workspace = WorkspaceInfo(
            project_id="test-project",
            path="/tmp/test",
            state=WorkspaceState.READY,
            repo_url=None,  # the missing piece this test exercises
        )
        manager = AsyncMock()
        manager.get_workspace = AsyncMock(return_value=workspace)

        with patch("server._workspace_manager", manager):
            outcome = await create_pull_request.fn(
                project_id="test-project",
                agent_id="agent-1",
                title="Test PR",
                source_branch="feature",
                target_branch="main",
            )

        assert outcome["success"] is False
        assert "repository URL" in outcome["error"]

    @pytest.mark.asyncio
    async def test_list_prs_invalid_state(self):
        """Listing PRs rejects an unknown state filter."""
        from models import WorkspaceInfo, WorkspaceState
        from server import list_pull_requests

        workspace = WorkspaceInfo(
            project_id="test-project",
            path="/tmp/test",
            state=WorkspaceState.READY,
            repo_url="https://gitea.test.com/owner/repo.git",
        )
        manager = AsyncMock()
        manager.get_workspace = AsyncMock(return_value=workspace)

        provider = AsyncMock()
        provider.parse_repo_url = MagicMock(return_value=("owner", "repo"))

        with (
            patch("server._workspace_manager", manager),
            patch("server._get_provider_for_url", return_value=provider),
        ):
            outcome = await list_pull_requests.fn(
                project_id="test-project",
                agent_id="agent-1",
                state="invalid-state",
            )

        assert outcome["success"] is False
        assert "Invalid state" in outcome["error"]

    @pytest.mark.asyncio
    async def test_merge_pr_invalid_strategy(self):
        """Merging a PR rejects an unknown merge strategy."""
        from models import WorkspaceInfo, WorkspaceState
        from server import merge_pull_request

        workspace = WorkspaceInfo(
            project_id="test-project",
            path="/tmp/test",
            state=WorkspaceState.READY,
            repo_url="https://gitea.test.com/owner/repo.git",
        )
        manager = AsyncMock()
        manager.get_workspace = AsyncMock(return_value=workspace)

        provider = AsyncMock()
        provider.parse_repo_url = MagicMock(return_value=("owner", "repo"))

        with (
            patch("server._workspace_manager", manager),
            patch("server._get_provider_for_url", return_value=provider),
        ):
            outcome = await merge_pull_request.fn(
                project_id="test-project",
                agent_id="agent-1",
                pr_number=42,
                merge_strategy="invalid-strategy",
            )

        assert outcome["success"] is False
        assert "Invalid strategy" in outcome["error"]
|
||||||
|
|
||||||
|
|
||||||
|
class TestWorkspaceOperations:
    """Tests for workspace operation tools."""

    @pytest.mark.asyncio
    async def test_get_workspace_not_found(self):
        """Fetching a missing workspace reports failure."""
        from server import get_workspace

        manager = AsyncMock()
        manager.get_workspace = AsyncMock(return_value=None)

        with patch("server._workspace_manager", manager):
            outcome = await get_workspace.fn(
                project_id="nonexistent",
                agent_id="agent-1",
            )

        assert outcome["success"] is False

    @pytest.mark.asyncio
    async def test_lock_workspace_success(self):
        """A successful lock reports the agent that acquired it."""
        from datetime import UTC, datetime, timedelta

        from models import WorkspaceInfo, WorkspaceState
        from server import lock_workspace

        locked = WorkspaceInfo(
            project_id="test-project",
            path="/tmp/test",
            state=WorkspaceState.LOCKED,
            lock_holder="agent-1",
            lock_expires=datetime.now(UTC) + timedelta(seconds=300),
        )
        manager = AsyncMock()
        manager.lock_workspace = AsyncMock(return_value=True)
        manager.get_workspace = AsyncMock(return_value=locked)

        with patch("server._workspace_manager", manager):
            outcome = await lock_workspace.fn(
                project_id="test-project",
                agent_id="agent-1",
                timeout=300,
            )

        assert outcome["success"] is True
        assert outcome["lock_holder"] == "agent-1"

    @pytest.mark.asyncio
    async def test_unlock_workspace_success(self):
        """Unlocking reports success when the manager releases the lock."""
        from server import unlock_workspace

        manager = AsyncMock()
        manager.unlock_workspace = AsyncMock(return_value=True)

        with patch("server._workspace_manager", manager):
            outcome = await unlock_workspace.fn(
                project_id="test-project",
                agent_id="agent-1",
            )

        assert outcome["success"] is True
|
||||||
|
|
||||||
|
|
||||||
|
class TestJSONRPCEndpoint:
    """Tests for the JSON-RPC endpoint."""

    def test_python_type_to_json_schema_str(self):
        """str maps to a JSON-Schema string."""
        from server import _python_type_to_json_schema

        assert _python_type_to_json_schema(str)["type"] == "string"

    def test_python_type_to_json_schema_int(self):
        """int maps to a JSON-Schema integer."""
        from server import _python_type_to_json_schema

        assert _python_type_to_json_schema(int)["type"] == "integer"

    def test_python_type_to_json_schema_bool(self):
        """bool maps to a JSON-Schema boolean."""
        from server import _python_type_to_json_schema

        assert _python_type_to_json_schema(bool)["type"] == "boolean"

    def test_python_type_to_json_schema_list(self):
        """list[str] maps to an array whose items are strings."""
        from server import _python_type_to_json_schema

        schema = _python_type_to_json_schema(list[str])
        assert schema["type"] == "array"
        assert schema["items"]["type"] == "string"
|
||||||
334
mcp-servers/git-ops/tests/test_workspace.py
Normal file
334
mcp-servers/git-ops/tests/test_workspace.py
Normal file
@@ -0,0 +1,334 @@
|
|||||||
|
"""
|
||||||
|
Tests for the workspace management module.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from datetime import UTC, datetime, timedelta
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from exceptions import WorkspaceLockedError, WorkspaceNotFoundError
|
||||||
|
from models import WorkspaceState
|
||||||
|
from workspace import FileLockManager, WorkspaceLock
|
||||||
|
|
||||||
|
|
||||||
|
class TestWorkspaceManager:
    """Tests for WorkspaceManager."""

    @pytest.mark.asyncio
    async def test_create_workspace(self, workspace_manager, valid_project_id):
        """A new workspace is INITIALIZING and backed by a real directory."""
        ws = await workspace_manager.create_workspace(valid_project_id)

        assert ws.project_id == valid_project_id
        assert ws.state == WorkspaceState.INITIALIZING
        assert Path(ws.path).exists()

    @pytest.mark.asyncio
    async def test_create_workspace_with_repo_url(self, workspace_manager, valid_project_id, sample_repo_url):
        """The repository URL is recorded on the workspace."""
        ws = await workspace_manager.create_workspace(
            valid_project_id, repo_url=sample_repo_url
        )

        assert ws.repo_url == sample_repo_url

    @pytest.mark.asyncio
    async def test_get_workspace(self, workspace_manager, valid_project_id):
        """A created workspace can be fetched back by project ID."""
        await workspace_manager.create_workspace(valid_project_id)

        ws = await workspace_manager.get_workspace(valid_project_id)

        assert ws is not None
        assert ws.project_id == valid_project_id

    @pytest.mark.asyncio
    async def test_get_workspace_not_found(self, workspace_manager):
        """Fetching an unknown project yields None."""
        assert await workspace_manager.get_workspace("nonexistent") is None

    @pytest.mark.asyncio
    async def test_delete_workspace(self, workspace_manager, valid_project_id):
        """Deleting removes both the record and the on-disk directory."""
        ws = await workspace_manager.create_workspace(valid_project_id)
        ws_dir = Path(ws.path)
        assert ws_dir.exists()

        deleted = await workspace_manager.delete_workspace(valid_project_id)

        assert deleted is True
        assert not ws_dir.exists()

    @pytest.mark.asyncio
    async def test_delete_nonexistent_workspace(self, workspace_manager):
        """Deleting an unknown workspace is a no-op that still succeeds."""
        assert await workspace_manager.delete_workspace("nonexistent") is True

    @pytest.mark.asyncio
    async def test_list_workspaces(self, workspace_manager):
        """All created workspaces appear in the listing."""
        for pid in ("project-1", "project-2", "project-3"):
            await workspace_manager.create_workspace(pid)

        listed = await workspace_manager.list_workspaces()

        assert len(listed) >= 3
        ids = [w.project_id for w in listed]
        assert "project-1" in ids
        assert "project-2" in ids
        assert "project-3" in ids
|
||||||
|
|
||||||
|
|
||||||
|
class TestWorkspaceLocking:
    """Tests for workspace locking."""

    @pytest.mark.asyncio
    async def test_lock_workspace(self, workspace_manager, valid_project_id, valid_agent_id):
        """Locking succeeds and records the LOCKED state plus the holder."""
        await workspace_manager.create_workspace(valid_project_id)

        acquired = await workspace_manager.lock_workspace(
            valid_project_id, valid_agent_id, timeout=60
        )

        assert acquired is True
        ws = await workspace_manager.get_workspace(valid_project_id)
        assert ws.state == WorkspaceState.LOCKED
        assert ws.lock_holder == valid_agent_id

    @pytest.mark.asyncio
    async def test_lock_already_locked(self, workspace_manager, valid_project_id):
        """A second agent cannot take an active lock held by another."""
        await workspace_manager.create_workspace(valid_project_id)
        await workspace_manager.lock_workspace(valid_project_id, "agent-1", timeout=60)

        with pytest.raises(WorkspaceLockedError):
            await workspace_manager.lock_workspace(valid_project_id, "agent-2", timeout=60)

    @pytest.mark.asyncio
    async def test_lock_same_holder(self, workspace_manager, valid_project_id, valid_agent_id):
        """The current holder may re-lock, extending the lease."""
        await workspace_manager.create_workspace(valid_project_id)
        await workspace_manager.lock_workspace(valid_project_id, valid_agent_id, timeout=60)

        relocked = await workspace_manager.lock_workspace(
            valid_project_id, valid_agent_id, timeout=120
        )

        assert relocked is True

    @pytest.mark.asyncio
    async def test_unlock_workspace(self, workspace_manager, valid_project_id, valid_agent_id):
        """Unlocking restores READY state and clears the holder."""
        await workspace_manager.create_workspace(valid_project_id)
        await workspace_manager.lock_workspace(valid_project_id, valid_agent_id)

        released = await workspace_manager.unlock_workspace(valid_project_id, valid_agent_id)

        assert released is True
        ws = await workspace_manager.get_workspace(valid_project_id)
        assert ws.state == WorkspaceState.READY
        assert ws.lock_holder is None

    @pytest.mark.asyncio
    async def test_unlock_wrong_holder(self, workspace_manager, valid_project_id):
        """Only the holding agent may unlock without force."""
        await workspace_manager.create_workspace(valid_project_id)
        await workspace_manager.lock_workspace(valid_project_id, "agent-1")

        with pytest.raises(WorkspaceLockedError):
            await workspace_manager.unlock_workspace(valid_project_id, "agent-2")

    @pytest.mark.asyncio
    async def test_force_unlock(self, workspace_manager, valid_project_id):
        """force=True bypasses the holder check entirely."""
        await workspace_manager.create_workspace(valid_project_id)
        await workspace_manager.lock_workspace(valid_project_id, "agent-1")

        forced = await workspace_manager.unlock_workspace(
            valid_project_id, "admin", force=True
        )

        assert forced is True

    @pytest.mark.asyncio
    async def test_lock_nonexistent_workspace(self, workspace_manager, valid_agent_id):
        """Locking an unknown workspace raises WorkspaceNotFoundError."""
        with pytest.raises(WorkspaceNotFoundError):
            await workspace_manager.lock_workspace("nonexistent", valid_agent_id)
|
||||||
|
|
||||||
|
|
||||||
|
class TestWorkspaceLockContextManager:
    """Tests for WorkspaceLock context manager."""

    @pytest.mark.asyncio
    async def test_lock_context_manager(self, workspace_manager, valid_project_id, valid_agent_id):
        """The workspace is locked inside the context and released on exit."""
        await workspace_manager.create_workspace(valid_project_id)

        # No `as lock:` binding — the original bound the lock object but
        # never used it.
        async with WorkspaceLock(
            workspace_manager, valid_project_id, valid_agent_id
        ):
            workspace = await workspace_manager.get_workspace(valid_project_id)
            assert workspace.state == WorkspaceState.LOCKED

        # After exiting the context, the lock must have been released.
        workspace = await workspace_manager.get_workspace(valid_project_id)
        assert workspace.lock_holder is None

    @pytest.mark.asyncio
    async def test_lock_context_manager_error(self, workspace_manager, valid_project_id, valid_agent_id):
        """The lock is released even when the context body raises."""
        await workspace_manager.create_workspace(valid_project_id)

        # pytest.raises replaces the original try/except-pass: it also
        # verifies the exception actually propagates out of the context.
        with pytest.raises(ValueError, match="Test error"):
            async with WorkspaceLock(
                workspace_manager, valid_project_id, valid_agent_id
            ):
                raise ValueError("Test error")

        workspace = await workspace_manager.get_workspace(valid_project_id)
        assert workspace.lock_holder is None
|
||||||
|
|
||||||
|
|
||||||
|
class TestWorkspaceMetadata:
    """Tests for workspace metadata operations."""

    @pytest.mark.asyncio
    async def test_touch_workspace(self, workspace_manager, valid_project_id):
        """touch_workspace must not move last_accessed backwards."""
        created = await workspace_manager.create_workspace(valid_project_id)
        before = created.last_accessed

        await workspace_manager.touch_workspace(valid_project_id)

        refreshed = await workspace_manager.get_workspace(valid_project_id)
        # Timestamps may land on the same clock tick, hence >= rather than >.
        assert refreshed.last_accessed >= before

    @pytest.mark.asyncio
    async def test_update_workspace_branch(self, workspace_manager, valid_project_id):
        """update_workspace_branch persists the new branch name."""
        await workspace_manager.create_workspace(valid_project_id)
        await workspace_manager.update_workspace_branch(valid_project_id, "feature-branch")

        ws = await workspace_manager.get_workspace(valid_project_id)
        assert ws.current_branch == "feature-branch"
|
||||||
|
|
||||||
|
|
||||||
|
class TestWorkspaceSize:
    """Tests for workspace size management."""

    @pytest.mark.asyncio
    async def test_check_size_within_limit(self, workspace_manager, valid_project_id):
        """A freshly created (near-empty) workspace passes the size check."""
        await workspace_manager.create_workspace(valid_project_id)

        # Must not raise WorkspaceSizeExceededError.
        assert await workspace_manager.check_size_limit(valid_project_id) is True

    @pytest.mark.asyncio
    async def test_get_total_size(self, workspace_manager, valid_project_id):
        """get_total_size accounts for file content written into a workspace."""
        ws = await workspace_manager.create_workspace(valid_project_id)

        # Write 1000 bytes so the aggregate size has a known lower bound.
        (Path(ws.path) / "content.txt").write_text("x" * 1000)

        assert await workspace_manager.get_total_size() >= 1000
|
||||||
|
|
||||||
|
|
||||||
|
class TestFileLockManager:
    """Tests for file-based locking."""

    def test_acquire_lock(self, temp_dir):
        """A free lock can be acquired."""
        mgr = FileLockManager(temp_dir / "locks")

        assert mgr.acquire("test-key") is True

        # Release so the lock file does not leak into other tests.
        mgr.release("test-key")

    def test_release_lock(self, temp_dir):
        """An acquired lock can be released."""
        mgr = FileLockManager(temp_dir / "locks")
        mgr.acquire("test-key")

        assert mgr.release("test-key") is True

    def test_is_locked(self, temp_dir):
        """is_locked reflects acquire/release transitions."""
        mgr = FileLockManager(temp_dir / "locks")

        assert mgr.is_locked("test-key") is False

        mgr.acquire("test-key")
        assert mgr.is_locked("test-key") is True

        mgr.release("test-key")

    def test_release_nonexistent_lock(self, temp_dir):
        """Releasing an unknown key is a no-op that reports False."""
        mgr = FileLockManager(temp_dir / "locks")

        # Should not raise — just report failure.
        assert mgr.release("nonexistent") is False
|
||||||
|
|
||||||
|
|
||||||
|
class TestWorkspaceCleanup:
    """Tests for workspace cleanup operations."""

    @pytest.mark.asyncio
    async def test_cleanup_stale_workspaces(self, workspace_manager, test_settings):
        """Workspaces idle past the stale threshold get cleaned up."""
        await workspace_manager.create_workspace("stale-project")

        # Backdate last_accessed 30 days so the workspace reads as stale.
        stale_time = datetime.now(UTC) - timedelta(days=30)
        await workspace_manager._update_metadata(
            "stale-project",
            last_accessed=stale_time.isoformat(),
        )

        assert await workspace_manager.cleanup_stale_workspaces() >= 1

    @pytest.mark.asyncio
    async def test_delete_locked_workspace_blocked(self, workspace_manager, valid_project_id, valid_agent_id):
        """Deleting a locked workspace without force must raise."""
        mgr = workspace_manager
        await mgr.create_workspace(valid_project_id)
        await mgr.lock_workspace(valid_project_id, valid_agent_id)

        with pytest.raises(WorkspaceLockedError):
            await mgr.delete_workspace(valid_project_id)

    @pytest.mark.asyncio
    async def test_delete_locked_workspace_force(self, workspace_manager, valid_project_id, valid_agent_id):
        """force=True overrides the lock and deletes the workspace."""
        mgr = workspace_manager
        await mgr.create_workspace(valid_project_id)
        await mgr.lock_workspace(valid_project_id, valid_agent_id)

        assert await mgr.delete_workspace(valid_project_id, force=True) is True
|
||||||
1853
mcp-servers/git-ops/uv.lock
generated
Normal file
1853
mcp-servers/git-ops/uv.lock
generated
Normal file
File diff suppressed because it is too large
Load Diff
614
mcp-servers/git-ops/workspace.py
Normal file
614
mcp-servers/git-ops/workspace.py
Normal file
@@ -0,0 +1,614 @@
|
|||||||
|
"""
|
||||||
|
Workspace management for Git Operations MCP Server.
|
||||||
|
|
||||||
|
Handles isolated workspaces for each project, including creation,
|
||||||
|
locking, cleanup, and size management.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import shutil
|
||||||
|
from datetime import UTC, datetime, timedelta
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import aiofiles # type: ignore[import-untyped]
|
||||||
|
from filelock import FileLock, Timeout
|
||||||
|
|
||||||
|
from config import Settings, get_settings
|
||||||
|
from exceptions import (
|
||||||
|
WorkspaceLockedError,
|
||||||
|
WorkspaceNotFoundError,
|
||||||
|
WorkspaceSizeExceededError,
|
||||||
|
)
|
||||||
|
from models import WorkspaceInfo, WorkspaceState
|
||||||
|
|
||||||
|
# Module-level logger for workspace operations.
logger = logging.getLogger(__name__)

# Name of the per-workspace JSON metadata file stored inside each
# workspace directory (holds repo_url, branch, lock and timestamp info).
WORKSPACE_METADATA_FILE = ".syndarix-workspace.json"
|
||||||
|
|
||||||
|
|
||||||
|
class WorkspaceManager:
    """
    Manages git workspaces for projects.

    Each project gets an isolated workspace directory for git operations.
    Supports distributed locking via Redis or local file locks.

    NOTE(review): lock state lives in a JSON metadata file and locking is a
    read-modify-write without an OS-level lock around it, so two concurrent
    lockers could race — confirm single-writer usage is guaranteed upstream.
    """

    def __init__(self, settings: Settings | None = None) -> None:
        """
        Initialize WorkspaceManager.

        Args:
            settings: Optional settings override (defaults to get_settings())
        """
        self.settings = settings or get_settings()
        self.base_path = self.settings.workspace_base_path
        self._ensure_base_path()

    def _ensure_base_path(self) -> None:
        """Ensure the base workspace directory exists."""
        self.base_path.mkdir(parents=True, exist_ok=True)

    def _get_workspace_path(self, project_id: str) -> Path:
        """
        Get the path for a project workspace with path traversal protection.

        Raises:
            ValueError: If the sanitized ID is a reserved name or resolves
                outside the base workspace directory.
        """
        # Sanitize project ID for filesystem: anything that is not
        # alphanumeric, '-' or '_' becomes '_'.
        safe_id = "".join(c if c.isalnum() or c in "-_" else "_" for c in project_id)

        # Reject reserved names ('.', '..' and Windows device names).
        reserved_names = {".", "..", "con", "prn", "aux", "nul"}
        if safe_id.lower() in reserved_names:
            raise ValueError(f"Invalid project ID: reserved name '{project_id}'")

        # Construct path and verify it's within base_path (prevent path traversal)
        workspace_path = (self.base_path / safe_id).resolve()
        base_resolved = self.base_path.resolve()

        if not workspace_path.is_relative_to(base_resolved):
            raise ValueError(
                f"Invalid project ID: path traversal detected '{project_id}'"
            )

        return workspace_path

    def _get_lock_path(self, project_id: str) -> Path:
        """Get the lock file path for a workspace."""
        return self._get_workspace_path(project_id) / ".lock"

    def _get_metadata_path(self, project_id: str) -> Path:
        """Get the metadata file path for a workspace."""
        return self._get_workspace_path(project_id) / WORKSPACE_METADATA_FILE

    async def get_workspace(self, project_id: str) -> WorkspaceInfo | None:
        """
        Get workspace info for a project.

        Args:
            project_id: Project identifier

        Returns:
            WorkspaceInfo or None if the workspace directory does not exist
        """
        workspace_path = self._get_workspace_path(project_id)

        if not workspace_path.exists():
            return None

        # Load metadata (None if the file is missing or unreadable)
        metadata = await self._load_metadata(project_id)

        # Calculate current on-disk size
        size_bytes = await self._calculate_size(workspace_path)

        # Check lock status; an expired lock is reported as not held
        # (the stale entry remains in the metadata file until rewritten).
        lock_holder = None
        lock_expires = None
        if metadata:
            lock_holder = metadata.get("lock_holder")
            if metadata.get("lock_expires"):
                lock_expires = datetime.fromisoformat(metadata["lock_expires"])
                # Clear expired locks
                if lock_expires < datetime.now(UTC):
                    lock_holder = None
                    lock_expires = None

        # Determine state: LOCKED if held, and staleness takes precedence.
        state = WorkspaceState.READY
        if lock_holder:
            state = WorkspaceState.LOCKED

        # Check if stale (not accessed within workspace_stale_days)
        last_accessed = datetime.now(UTC)
        if metadata and metadata.get("last_accessed"):
            last_accessed = datetime.fromisoformat(metadata["last_accessed"])
        stale_threshold = datetime.now(UTC) - timedelta(
            days=self.settings.workspace_stale_days
        )
        if last_accessed < stale_threshold:
            state = WorkspaceState.STALE

        return WorkspaceInfo(
            project_id=project_id,
            path=str(workspace_path),
            state=state,
            repo_url=metadata.get("repo_url") if metadata else None,
            current_branch=metadata.get("current_branch") if metadata else None,
            last_accessed=last_accessed,
            size_bytes=size_bytes,
            lock_holder=lock_holder,
            lock_expires=lock_expires,
        )

    async def create_workspace(
        self,
        project_id: str,
        repo_url: str | None = None,
    ) -> WorkspaceInfo:
        """
        Create or get a workspace for a project (idempotent).

        Args:
            project_id: Project identifier
            repo_url: Optional repository URL recorded in metadata

        Returns:
            WorkspaceInfo for the (new or existing) workspace
        """
        workspace_path = self._get_workspace_path(project_id)

        if workspace_path.exists():
            # Workspace already exists: refresh metadata and return it.
            await self._update_metadata(project_id, repo_url=repo_url)
            workspace = await self.get_workspace(project_id)
            if workspace:
                return workspace

        # Create workspace directory
        workspace_path.mkdir(parents=True, exist_ok=True)

        # Create initial metadata
        metadata = {
            "project_id": project_id,
            "repo_url": repo_url,
            "created_at": datetime.now(UTC).isoformat(),
            "last_accessed": datetime.now(UTC).isoformat(),
        }
        await self._save_metadata(project_id, metadata)

        return WorkspaceInfo(
            project_id=project_id,
            path=str(workspace_path),
            state=WorkspaceState.INITIALIZING,
            repo_url=repo_url,
            last_accessed=datetime.now(UTC),
            size_bytes=0,
        )

    async def delete_workspace(self, project_id: str, force: bool = False) -> bool:
        """
        Delete a workspace.

        Args:
            project_id: Project identifier
            force: Force delete even if locked

        Returns:
            True if deleted (or already absent); False if deletion failed

        Raises:
            WorkspaceLockedError: If the workspace is locked and force is False
        """
        workspace_path = self._get_workspace_path(project_id)

        if not workspace_path.exists():
            return True

        # Check lock
        if not force:
            workspace = await self.get_workspace(project_id)
            if workspace and workspace.state == WorkspaceState.LOCKED:
                raise WorkspaceLockedError(project_id, workspace.lock_holder)

        try:
            # Use shutil.rmtree for robust deletion
            shutil.rmtree(workspace_path)
            logger.info(f"Deleted workspace for project: {project_id}")
            return True
        except Exception as e:
            # Best-effort: report failure rather than crash callers.
            logger.error(f"Failed to delete workspace {project_id}: {e}")
            return False

    async def lock_workspace(
        self,
        project_id: str,
        holder: str,
        timeout: int | None = None,
    ) -> bool:
        """
        Acquire a lock on a workspace.

        Args:
            project_id: Project identifier
            holder: Lock holder identifier (agent_id)
            timeout: Lock timeout in seconds (defaults to settings value)

        Returns:
            True if lock acquired

        Raises:
            WorkspaceNotFoundError: If workspace doesn't exist
            WorkspaceLockedError: If already locked by another holder
        """
        workspace = await self.get_workspace(project_id)

        if workspace is None:
            raise WorkspaceNotFoundError(project_id)

        # Check if already locked by someone else (re-locking by the
        # same holder is allowed and refreshes the expiry).
        if workspace.state == WorkspaceState.LOCKED and workspace.lock_holder != holder:
            # Only an unexpired lock blocks acquisition
            if workspace.lock_expires and workspace.lock_expires > datetime.now(UTC):
                raise WorkspaceLockedError(project_id, workspace.lock_holder)

        # Calculate lock expiry
        lock_timeout = timeout or self.settings.workspace_lock_timeout
        lock_expires = datetime.now(UTC) + timedelta(seconds=lock_timeout)

        # Update metadata with lock info
        await self._update_metadata(
            project_id,
            lock_holder=holder,
            lock_expires=lock_expires.isoformat(),
        )

        logger.info(f"Workspace {project_id} locked by {holder}")
        return True

    async def unlock_workspace(
        self,
        project_id: str,
        holder: str,
        force: bool = False,
    ) -> bool:
        """
        Release a lock on a workspace.

        Args:
            project_id: Project identifier
            holder: Lock holder identifier
            force: Force unlock regardless of holder

        Returns:
            True if unlocked

        Raises:
            WorkspaceNotFoundError: If workspace doesn't exist
            WorkspaceLockedError: If held by a different holder and not forced
        """
        workspace = await self.get_workspace(project_id)

        if workspace is None:
            raise WorkspaceNotFoundError(project_id)

        # Verify holder
        if not force and workspace.lock_holder and workspace.lock_holder != holder:
            raise WorkspaceLockedError(project_id, workspace.lock_holder)

        # Clear lock (None values remove the keys from metadata)
        await self._update_metadata(
            project_id,
            lock_holder=None,
            lock_expires=None,
        )

        logger.info(f"Workspace {project_id} unlocked by {holder}")
        return True

    async def touch_workspace(self, project_id: str) -> None:
        """
        Update last accessed time for a workspace.

        Args:
            project_id: Project identifier
        """
        await self._update_metadata(
            project_id,
            last_accessed=datetime.now(UTC).isoformat(),
        )

    async def update_workspace_branch(
        self,
        project_id: str,
        branch: str,
    ) -> None:
        """
        Update the current branch in workspace metadata (also touches
        last_accessed).

        Args:
            project_id: Project identifier
            branch: Current branch name
        """
        await self._update_metadata(
            project_id,
            current_branch=branch,
            last_accessed=datetime.now(UTC).isoformat(),
        )

    async def check_size_limit(self, project_id: str) -> bool:
        """
        Check if workspace exceeds size limit.

        Args:
            project_id: Project identifier

        Returns:
            True if within limits (or workspace absent)

        Raises:
            WorkspaceSizeExceededError: If size exceeds the configured limit
        """
        workspace_path = self._get_workspace_path(project_id)

        if not workspace_path.exists():
            return True

        size_bytes = await self._calculate_size(workspace_path)
        size_gb = size_bytes / (1024**3)
        max_size_gb = self.settings.workspace_max_size_gb

        if size_gb > max_size_gb:
            raise WorkspaceSizeExceededError(project_id, size_gb, max_size_gb)

        return True

    async def list_workspaces(
        self,
        include_stale: bool = False,
    ) -> list[WorkspaceInfo]:
        """
        List all workspaces under the base directory.

        Args:
            include_stale: Include stale workspaces

        Returns:
            List of WorkspaceInfo
        """
        workspaces: list[WorkspaceInfo] = []

        if not self.base_path.exists():
            return workspaces

        for entry in self.base_path.iterdir():
            # Skip files and hidden entries (e.g. lock directories)
            if entry.is_dir() and not entry.name.startswith("."):
                # Directory name doubles as the (sanitized) project_id
                workspace = await self.get_workspace(entry.name)
                if workspace:
                    if not include_stale and workspace.state == WorkspaceState.STALE:
                        continue
                    workspaces.append(workspace)

        return workspaces

    async def cleanup_stale_workspaces(self) -> int:
        """
        Clean up stale workspaces (force-deletes them).

        Returns:
            Number of workspaces cleaned up
        """
        cleaned = 0
        workspaces = await self.list_workspaces(include_stale=True)

        for workspace in workspaces:
            if workspace.state == WorkspaceState.STALE:
                try:
                    await self.delete_workspace(workspace.project_id, force=True)
                    cleaned += 1
                except Exception as e:
                    # Keep going: one bad workspace must not stop the sweep.
                    logger.error(
                        f"Failed to cleanup stale workspace {workspace.project_id}: {e}"
                    )

        if cleaned > 0:
            logger.info(f"Cleaned up {cleaned} stale workspaces")

        return cleaned

    async def get_total_size(self) -> int:
        """
        Get total size of all workspaces.

        Returns:
            Total size in bytes
        """
        return await self._calculate_size(self.base_path)

    # Private methods

    async def _load_metadata(self, project_id: str) -> dict[str, Any] | None:
        """Load workspace metadata from file; None if missing or unreadable."""
        metadata_path = self._get_metadata_path(project_id)

        if not metadata_path.exists():
            return None

        try:
            async with aiofiles.open(metadata_path) as f:
                content = await f.read()
            return json.loads(content)
        except Exception as e:
            # Corrupt/unreadable metadata degrades to "no metadata".
            logger.warning(f"Failed to load metadata for {project_id}: {e}")
            return None

    async def _save_metadata(
        self,
        project_id: str,
        metadata: dict[str, Any],
    ) -> None:
        """Save workspace metadata to file (best-effort; errors are logged)."""
        metadata_path = self._get_metadata_path(project_id)

        # Ensure parent directory exists
        metadata_path.parent.mkdir(parents=True, exist_ok=True)

        try:
            async with aiofiles.open(metadata_path, "w") as f:
                await f.write(json.dumps(metadata, indent=2))
        except Exception as e:
            logger.error(f"Failed to save metadata for {project_id}: {e}")

    async def _update_metadata(
        self,
        project_id: str,
        **updates: Any,
    ) -> None:
        """Update specific fields in workspace metadata.

        A value of None removes the key (used to clear lock fields).
        """
        metadata = await self._load_metadata(project_id) or {}

        for key, value in updates.items():
            if value is None:
                metadata.pop(key, None)
            else:
                metadata[key] = value

        await self._save_metadata(project_id, metadata)

    async def _calculate_size(self, path: Path) -> int:
        """Calculate total size of a directory in bytes (best-effort)."""

        def _calc_size() -> int:
            total = 0
            try:
                for entry in path.rglob("*"):
                    if entry.is_file():
                        try:
                            total += entry.stat().st_size
                        except OSError:
                            # File vanished or is unreadable; skip it.
                            pass
            except Exception:
                # A partially unreadable tree yields a partial total.
                pass
            return total

        # Offload the blocking filesystem walk to a worker thread.
        # asyncio.get_event_loop() inside a coroutine is deprecated since
        # Python 3.10; to_thread uses the running loop's default executor.
        return await asyncio.to_thread(_calc_size)
|
||||||
|
|
||||||
|
|
||||||
|
class WorkspaceLock:
    """
    Async context manager that holds a workspace lock for its scope.

    Acquires the lock on entry and releases it on exit, including when
    the body raises.
    """

    def __init__(
        self,
        manager: WorkspaceManager,
        project_id: str,
        holder: str,
        timeout: int | None = None,
    ) -> None:
        """
        Initialize workspace lock.

        Args:
            manager: WorkspaceManager instance
            project_id: Project identifier
            holder: Lock holder identifier
            timeout: Lock timeout in seconds
        """
        self.manager = manager
        self.project_id = project_id
        self.holder = holder
        self.timeout = timeout
        # Tracks whether the lock was actually taken, so __aexit__
        # never releases a lock that was never held.
        self._acquired = False

    async def __aenter__(self) -> "WorkspaceLock":
        """Acquire the workspace lock on context entry."""
        await self.manager.lock_workspace(self.project_id, self.holder, self.timeout)
        self._acquired = True
        return self

    async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
        """Release the workspace lock on context exit (even on error)."""
        if not self._acquired:
            return
        try:
            await self.manager.unlock_workspace(self.project_id, self.holder)
        except Exception as e:
            # Best-effort release: a failed unlock is logged, not raised,
            # so it never masks an exception from the context body.
            logger.warning(f"Failed to release lock for {self.project_id}: {e}")
|
||||||
|
|
||||||
|
|
||||||
|
class FileLockManager:
    """
    File-based locking for single-instance deployments.

    Uses filelock for local locking when Redis is not available.
    """

    def __init__(self, lock_dir: Path) -> None:
        """
        Initialize file lock manager.

        Args:
            lock_dir: Directory for lock files (created if missing)
        """
        self.lock_dir = lock_dir
        self.lock_dir.mkdir(parents=True, exist_ok=True)
        # One cached FileLock object per key.
        self._locks: dict[str, FileLock] = {}

    def _get_lock(self, key: str) -> FileLock:
        """Return the cached FileLock for *key*, creating it on first use."""
        lock = self._locks.get(key)
        if lock is None:
            lock = FileLock(self.lock_dir / f"{key}.lock")
            self._locks[key] = lock
        return lock

    def acquire(
        self,
        key: str,
        timeout: float = 10.0,
    ) -> bool:
        """
        Acquire a lock.

        Args:
            key: Lock key
            timeout: Timeout in seconds

        Returns:
            True if acquired, False on timeout
        """
        try:
            self._get_lock(key).acquire(timeout=timeout)
        except Timeout:
            return False
        return True

    def release(self, key: str) -> bool:
        """
        Release a lock.

        Args:
            key: Lock key

        Returns:
            True if released; False if the key is unknown or release failed
        """
        lock = self._locks.get(key)
        if lock is None:
            return False
        try:
            lock.release()
        except Exception:
            return False
        return True

    def is_locked(self, key: str) -> bool:
        """Check whether *key* is currently locked."""
        return self._get_lock(key).is_locked
|
||||||
Reference in New Issue
Block a user