3 Commits

Author SHA1 Message Date
Felipe Cardoso
76d7de5334 **feat(git-ops): enhance MCP server with Git provider updates and SSRF protection**
- Added `mcp-git-ops` service to `docker-compose.dev.yml` with health checks and configurations.
- Integrated SSRF protection in repository URL validation for enhanced security.
- Expanded `pyproject.toml` mypy settings and adjusted code to meet stricter type checking.
- Improved workspace management and GitWrapper operations with error handling refinements.
- Updated input validation, branching, and repository operations to align with new error structure.
- Shut down thread pool executor gracefully during server cleanup.
2026-01-07 09:17:00 +01:00
Felipe Cardoso
1779239c07 feat(git-ops): add GitHub provider with auto-detection
Implements GitHub API provider following the same pattern as Gitea:
- Full PR operations (create, get, list, merge, update, close)
- Branch operations via API
- Comment and label management
- Reviewer request support
- Rate limit error handling

Server enhancements:
- Auto-detect provider from repository URL (github.com vs custom Gitea)
- Initialize GitHub provider when token is configured
- Health check includes both provider statuses
- Token selection based on repo URL for clone/push operations

Refs: #110

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-06 20:55:22 +01:00
Felipe Cardoso
9dfa76aa41 feat(mcp): implement Git Operations MCP server with Gitea provider
Implements the Git Operations MCP server (Issue #58) providing:

Core features:
- GitPython wrapper for local repository operations (clone, commit, push, pull, diff, log)
- Branch management (create, delete, list, checkout)
- Workspace isolation per project with file-based locking
- Gitea provider for remote PR operations

MCP Tools (17 registered):
- clone_repository, git_status, create_branch, list_branches
- checkout, commit, push, pull, diff, log
- create_pull_request, get_pull_request, list_pull_requests
- merge_pull_request, get_workspace, lock_workspace, unlock_workspace

Technical details:
- FastMCP + FastAPI with JSON-RPC 2.0 protocol
- pydantic-settings for configuration (env prefix: GIT_OPS_)
- Comprehensive error hierarchy with structured codes
- 131 tests passing with 67% coverage
- Async operations via ThreadPoolExecutor

Closes: #105, #106, #107, #108, #109

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-06 20:48:20 +01:00
24 changed files with 11458 additions and 0 deletions

View File

@@ -47,6 +47,7 @@ help:
@echo " cd backend && make help - Backend-specific commands"
@echo " cd mcp-servers/llm-gateway && make - LLM Gateway commands"
@echo " cd mcp-servers/knowledge-base && make - Knowledge Base commands"
@echo " cd mcp-servers/git-ops && make - Git Operations commands"
@echo " cd frontend && npm run - Frontend-specific commands"
# ============================================================================
@@ -138,6 +139,9 @@ test-mcp:
@echo ""
@echo "=== Knowledge Base ==="
@cd mcp-servers/knowledge-base && uv run pytest tests/ -v
@echo ""
@echo "=== Git Operations ==="
@cd mcp-servers/git-ops && IS_TEST=True uv run pytest tests/ -v
test-frontend:
@echo "Running frontend tests..."
@@ -158,6 +162,9 @@ test-cov:
@echo ""
@echo "=== Knowledge Base Coverage ==="
@cd mcp-servers/knowledge-base && uv run pytest tests/ -v --cov=. --cov-report=term-missing
@echo ""
@echo "=== Git Operations Coverage ==="
@cd mcp-servers/git-ops && IS_TEST=True uv run pytest tests/ -v --cov=. --cov-report=term-missing
test-integration:
@echo "Running MCP integration tests..."
@@ -178,6 +185,9 @@ format-all:
@echo "Formatting Knowledge Base..."
@cd mcp-servers/knowledge-base && make format
@echo ""
@echo "Formatting Git Operations..."
@cd mcp-servers/git-ops && make format
@echo ""
@echo "Formatting frontend..."
@cd frontend && npm run format
@echo ""
@@ -197,6 +207,9 @@ validate:
@echo "Validating Knowledge Base..."
@cd mcp-servers/knowledge-base && make validate
@echo ""
@echo "Validating Git Operations..."
@cd mcp-servers/git-ops && make validate
@echo ""
@echo "All validations passed!"
validate-all: validate

View File

@@ -96,6 +96,38 @@ services:
- app-network
restart: unless-stopped
mcp-git-ops:
build:
context: ./mcp-servers/git-ops
dockerfile: Dockerfile
ports:
- "8003:8003"
env_file:
- .env
environment:
# GIT_OPS_ prefix required by pydantic-settings config
- GIT_OPS_HOST=0.0.0.0
- GIT_OPS_PORT=8003
- GIT_OPS_REDIS_URL=redis://redis:6379/3
- GIT_OPS_GITEA_BASE_URL=${GITEA_BASE_URL}
- GIT_OPS_GITEA_TOKEN=${GITEA_TOKEN}
- GIT_OPS_GITHUB_TOKEN=${GITHUB_TOKEN}
- ENVIRONMENT=development
volumes:
- git_workspaces_dev:/workspaces
depends_on:
redis:
condition: service_healthy
healthcheck:
test: ["CMD", "python", "-c", "import httpx; httpx.get('http://localhost:8003/health').raise_for_status()"]
interval: 30s
timeout: 10s
retries: 3
start_period: 10s
networks:
- app-network
restart: unless-stopped
backend:
build:
context: ./backend
@@ -119,6 +151,7 @@ services:
# MCP Server URLs
- LLM_GATEWAY_URL=http://mcp-llm-gateway:8001
- KNOWLEDGE_BASE_URL=http://mcp-knowledge-base:8002
- GIT_OPS_URL=http://mcp-git-ops:8003
depends_on:
db:
condition: service_healthy
@@ -128,6 +161,8 @@ services:
condition: service_healthy
mcp-knowledge-base:
condition: service_healthy
mcp-git-ops:
condition: service_healthy
healthcheck:
test: ["CMD", "curl", "-f", "http://localhost:8000/health"]
interval: 10s
@@ -155,6 +190,7 @@ services:
# MCP Server URLs (agents need access to MCP)
- LLM_GATEWAY_URL=http://mcp-llm-gateway:8001
- KNOWLEDGE_BASE_URL=http://mcp-knowledge-base:8002
- GIT_OPS_URL=http://mcp-git-ops:8003
depends_on:
db:
condition: service_healthy
@@ -164,6 +200,8 @@ services:
condition: service_healthy
mcp-knowledge-base:
condition: service_healthy
mcp-git-ops:
condition: service_healthy
networks:
- app-network
command: ["celery", "-A", "app.celery_app", "worker", "-Q", "agent", "-l", "info", "-c", "4"]
@@ -181,11 +219,14 @@ services:
- DATABASE_URL=${DATABASE_URL}
- REDIS_URL=redis://redis:6379/0
- CELERY_QUEUE=git
- GIT_OPS_URL=http://mcp-git-ops:8003
depends_on:
db:
condition: service_healthy
redis:
condition: service_healthy
mcp-git-ops:
condition: service_healthy
networks:
- app-network
command: ["celery", "-A", "app.celery_app", "worker", "-Q", "git", "-l", "info", "-c", "2"]
@@ -260,6 +301,7 @@ services:
volumes:
postgres_data_dev:
redis_data_dev:
git_workspaces_dev:
frontend_dev_modules:
frontend_dev_next:

View File

@@ -0,0 +1,67 @@
# Git Operations MCP Server Dockerfile
# Multi-stage build for smaller production image
FROM python:3.12-slim AS builder
# Install build dependencies
RUN apt-get update && apt-get install -y --no-install-recommends \
build-essential \
git \
&& rm -rf /var/lib/apt/lists/*
# Install uv for fast package management
RUN pip install --no-cache-dir uv
# Create app directory
WORKDIR /app
# Copy dependency files
COPY pyproject.toml .
# Install dependencies with uv
RUN uv pip install --system --no-cache .
# Production stage
FROM python:3.12-slim
# Install runtime dependencies
RUN apt-get update && apt-get install -y --no-install-recommends \
git \
openssh-client \
&& rm -rf /var/lib/apt/lists/*
# Create non-root user
RUN useradd --create-home --shell /bin/bash syndarix
# Create workspace directory
RUN mkdir -p /var/syndarix/workspaces && chown -R syndarix:syndarix /var/syndarix
# Create app directory
WORKDIR /app
# Copy installed packages from builder
COPY --from=builder /usr/local/lib/python3.12/site-packages /usr/local/lib/python3.12/site-packages
COPY --from=builder /usr/local/bin /usr/local/bin
# Copy application code
COPY --chown=syndarix:syndarix . .
# Set Python path
ENV PYTHONPATH=/app
ENV PYTHONUNBUFFERED=1
# Configure git for the container
RUN git config --global --add safe.directory '*'
# Switch to non-root user
USER syndarix
# Expose port
EXPOSE 8003
# Health check
HEALTHCHECK --interval=30s --timeout=10s --start-period=5s --retries=3 \
CMD python -c "import httpx; httpx.get('http://localhost:8003/health').raise_for_status()" || exit 1
# Run the server
CMD ["python", "server.py"]

View File

@@ -0,0 +1,88 @@
.PHONY: help install install-dev lint lint-fix format format-check type-check test test-cov validate clean run
# Ensure commands in this project don't inherit an external Python virtualenv
# (prevents uv warnings about mismatched VIRTUAL_ENV when running from repo root)
unexport VIRTUAL_ENV
# Default target
help:
@echo "Git Operations MCP Server - Development Commands"
@echo ""
@echo "Setup:"
@echo " make install - Install production dependencies"
@echo " make install-dev - Install development dependencies"
@echo ""
@echo "Quality Checks:"
@echo " make lint - Run Ruff linter"
@echo " make lint-fix - Run Ruff linter with auto-fix"
@echo " make format - Format code with Ruff"
@echo " make format-check - Check if code is formatted"
@echo " make type-check - Run mypy type checker"
@echo ""
@echo "Testing:"
@echo " make test - Run pytest"
@echo " make test-cov - Run pytest with coverage"
@echo ""
@echo "All-in-one:"
@echo " make validate - Run all checks (lint + format + types)"
@echo ""
@echo "Running:"
@echo " make run - Run the server locally"
@echo ""
@echo "Cleanup:"
@echo " make clean - Remove cache and build artifacts"
# Setup
install:
@echo "Installing production dependencies..."
@uv pip install -e .
install-dev:
@echo "Installing development dependencies..."
@uv pip install -e ".[dev]"
# Quality checks
lint:
@echo "Running Ruff linter..."
@uv run ruff check .
lint-fix:
@echo "Running Ruff linter with auto-fix..."
@uv run ruff check --fix .
format:
@echo "Formatting code..."
@uv run ruff format .
format-check:
@echo "Checking code formatting..."
@uv run ruff format --check .
type-check:
@echo "Running mypy..."
@uv run python -m mypy server.py config.py models.py exceptions.py git_wrapper.py workspace.py providers/ --explicit-package-bases
# Testing
test:
@echo "Running tests..."
@IS_TEST=True uv run pytest tests/ -v
test-cov:
@echo "Running tests with coverage..."
@IS_TEST=True uv run pytest tests/ -v --cov=. --cov-report=term-missing --cov-report=html
# All-in-one validation
validate: lint format-check type-check
@echo "All validations passed!"
# Running
run:
@echo "Starting Git Operations server..."
@uv run python server.py
# Cleanup
clean:
@echo "Cleaning up..."
@rm -rf __pycache__ .pytest_cache .mypy_cache .ruff_cache .coverage htmlcov
@find . -type d -name "__pycache__" -exec rm -rf {} + 2>/dev/null || true
@find . -type f -name "*.pyc" -delete 2>/dev/null || true

View File

@@ -0,0 +1,179 @@
"""
Git Operations MCP Server.
Provides git repository management, branching, commits, and PR workflows
for Syndarix AI agents.
"""
__version__ = "0.1.0"
from config import Settings, get_settings, is_test_mode, reset_settings
from exceptions import (
APIError,
AuthenticationError,
BranchExistsError,
BranchNotFoundError,
CheckoutError,
CloneError,
CommitError,
CredentialError,
CredentialNotFoundError,
DirtyWorkspaceError,
ErrorCode,
GitError,
GitOpsError,
InvalidRefError,
MergeConflictError,
PRError,
PRNotFoundError,
ProviderError,
ProviderNotFoundError,
PullError,
PushError,
WorkspaceError,
WorkspaceLockedError,
WorkspaceNotFoundError,
WorkspaceSizeExceededError,
)
from models import (
BranchInfo,
BranchRequest,
BranchResult,
CheckoutRequest,
CheckoutResult,
CloneRequest,
CloneResult,
CommitInfo,
CommitRequest,
CommitResult,
CreatePRRequest,
CreatePRResult,
DiffHunk,
DiffRequest,
DiffResult,
FileChange,
FileChangeType,
FileDiff,
GetPRRequest,
GetPRResult,
GetWorkspaceRequest,
GetWorkspaceResult,
HealthStatus,
ListBranchesRequest,
ListBranchesResult,
ListPRsRequest,
ListPRsResult,
LockWorkspaceRequest,
LockWorkspaceResult,
LogRequest,
LogResult,
MergePRRequest,
MergePRResult,
MergeStrategy,
PRInfo,
ProviderStatus,
ProviderType,
PRState,
PullRequest,
PullResult,
PushRequest,
PushResult,
StatusRequest,
StatusResult,
UnlockWorkspaceRequest,
UnlockWorkspaceResult,
UpdatePRRequest,
UpdatePRResult,
WorkspaceInfo,
WorkspaceState,
)
__all__ = [
# Version
"__version__",
# Config
"Settings",
"get_settings",
"reset_settings",
"is_test_mode",
# Error codes
"ErrorCode",
# Exceptions
"GitOpsError",
"WorkspaceError",
"WorkspaceNotFoundError",
"WorkspaceLockedError",
"WorkspaceSizeExceededError",
"GitError",
"CloneError",
"CheckoutError",
"CommitError",
"PushError",
"PullError",
"MergeConflictError",
"BranchExistsError",
"BranchNotFoundError",
"InvalidRefError",
"DirtyWorkspaceError",
"ProviderError",
"AuthenticationError",
"ProviderNotFoundError",
"PRError",
"PRNotFoundError",
"APIError",
"CredentialError",
"CredentialNotFoundError",
# Enums
"FileChangeType",
"MergeStrategy",
"PRState",
"ProviderType",
"WorkspaceState",
# Dataclasses
"FileChange",
"BranchInfo",
"CommitInfo",
"DiffHunk",
"FileDiff",
"PRInfo",
"WorkspaceInfo",
# Request/Response models
"CloneRequest",
"CloneResult",
"StatusRequest",
"StatusResult",
"BranchRequest",
"BranchResult",
"ListBranchesRequest",
"ListBranchesResult",
"CheckoutRequest",
"CheckoutResult",
"CommitRequest",
"CommitResult",
"PushRequest",
"PushResult",
"PullRequest",
"PullResult",
"DiffRequest",
"DiffResult",
"LogRequest",
"LogResult",
"CreatePRRequest",
"CreatePRResult",
"GetPRRequest",
"GetPRResult",
"ListPRsRequest",
"ListPRsResult",
"MergePRRequest",
"MergePRResult",
"UpdatePRRequest",
"UpdatePRResult",
"GetWorkspaceRequest",
"GetWorkspaceResult",
"LockWorkspaceRequest",
"LockWorkspaceResult",
"UnlockWorkspaceRequest",
"UnlockWorkspaceResult",
"HealthStatus",
"ProviderStatus",
]

View File

@@ -0,0 +1,155 @@
"""
Configuration for Git Operations MCP Server.
Uses pydantic-settings for environment variable loading.
"""
import os
from pathlib import Path
from pydantic import Field
from pydantic_settings import BaseSettings
class Settings(BaseSettings):
"""Application settings loaded from environment."""
# Server settings
host: str = Field(default="0.0.0.0", description="Server host")
port: int = Field(default=8003, description="Server port")
debug: bool = Field(default=False, description="Debug mode")
# Workspace settings
workspace_base_path: Path = Field(
default=Path("/var/syndarix/workspaces"),
description="Base path for git workspaces",
)
workspace_max_size_gb: float = Field(
default=10.0,
description="Maximum size per workspace in GB",
)
workspace_stale_days: int = Field(
default=7,
description="Days after which unused workspace is considered stale",
)
workspace_lock_timeout: int = Field(
default=300,
description="Workspace lock timeout in seconds",
)
# Git settings
git_timeout: int = Field(
default=120,
description="Default timeout for git operations in seconds",
)
git_clone_timeout: int = Field(
default=600,
description="Timeout for clone operations in seconds",
)
git_author_name: str = Field(
default="Syndarix Agent",
description="Default author name for commits",
)
git_author_email: str = Field(
default="agent@syndarix.ai",
description="Default author email for commits",
)
git_max_diff_lines: int = Field(
default=10000,
description="Maximum lines in diff output",
)
# Redis settings (for distributed locking)
redis_url: str = Field(
default="redis://localhost:6379/0",
description="Redis connection URL",
)
# Provider settings
gitea_base_url: str = Field(
default="",
description="Gitea API base URL (e.g., https://gitea.example.com)",
)
gitea_token: str = Field(
default="",
description="Gitea API token",
)
github_token: str = Field(
default="",
description="GitHub API token",
)
github_api_url: str = Field(
default="https://api.github.com",
description="GitHub API URL (for Enterprise)",
)
gitlab_token: str = Field(
default="",
description="GitLab API token",
)
gitlab_url: str = Field(
default="https://gitlab.com",
description="GitLab URL (for self-hosted)",
)
# Rate limiting
rate_limit_requests: int = Field(
default=100,
description="Max API requests per minute per provider",
)
rate_limit_window: int = Field(
default=60,
description="Rate limit window in seconds",
)
# Retry settings
retry_attempts: int = Field(
default=3,
description="Number of retry attempts for failed operations",
)
retry_delay: float = Field(
default=1.0,
description="Initial retry delay in seconds",
)
retry_max_delay: float = Field(
default=30.0,
description="Maximum retry delay in seconds",
)
# Security settings
allowed_hosts: list[str] = Field(
default_factory=list,
description="Allowed git host domains (empty = all)",
)
max_clone_size_mb: int = Field(
default=500,
description="Maximum repository size for clone in MB",
)
enable_force_push: bool = Field(
default=False,
description="Allow force push operations",
)
model_config = {"env_prefix": "GIT_OPS_", "env_file": ".env", "extra": "ignore"}
# Global settings instance (lazy initialization)
_settings: Settings | None = None
def get_settings() -> Settings:
"""Get the global settings instance."""
global _settings
if _settings is None:
_settings = Settings()
return _settings
def reset_settings() -> None:
"""Reset the global settings (for testing)."""
global _settings
_settings = None
def is_test_mode() -> bool:
"""Check if running in test mode."""
return os.getenv("IS_TEST", "").lower() in ("true", "1", "yes")

View File

@@ -0,0 +1,359 @@
"""
Exception hierarchy for Git Operations MCP Server.
Provides structured error handling with error codes for MCP responses.
"""
from enum import Enum
from typing import Any
class ErrorCode(str, Enum):
"""Error codes for Git Operations errors."""
# General errors (1xxx)
INTERNAL_ERROR = "GIT_1000"
INVALID_REQUEST = "GIT_1001"
NOT_FOUND = "GIT_1002"
PERMISSION_DENIED = "GIT_1003"
TIMEOUT = "GIT_1004"
RATE_LIMITED = "GIT_1005"
# Workspace errors (2xxx)
WORKSPACE_NOT_FOUND = "GIT_2000"
WORKSPACE_LOCKED = "GIT_2001"
WORKSPACE_SIZE_EXCEEDED = "GIT_2002"
WORKSPACE_CREATE_FAILED = "GIT_2003"
WORKSPACE_DELETE_FAILED = "GIT_2004"
# Git operation errors (3xxx)
CLONE_FAILED = "GIT_3000"
CHECKOUT_FAILED = "GIT_3001"
COMMIT_FAILED = "GIT_3002"
PUSH_FAILED = "GIT_3003"
PULL_FAILED = "GIT_3004"
MERGE_CONFLICT = "GIT_3005"
BRANCH_EXISTS = "GIT_3006"
BRANCH_NOT_FOUND = "GIT_3007"
INVALID_REF = "GIT_3008"
DIRTY_WORKSPACE = "GIT_3009"
UNCOMMITTED_CHANGES = "GIT_3010"
FETCH_FAILED = "GIT_3011"
RESET_FAILED = "GIT_3012"
# Provider errors (4xxx)
PROVIDER_ERROR = "GIT_4000"
PROVIDER_AUTH_FAILED = "GIT_4001"
PROVIDER_NOT_FOUND = "GIT_4002"
PR_CREATE_FAILED = "GIT_4003"
PR_MERGE_FAILED = "GIT_4004"
PR_NOT_FOUND = "GIT_4005"
API_ERROR = "GIT_4006"
# Credential errors (5xxx)
CREDENTIAL_ERROR = "GIT_5000"
CREDENTIAL_NOT_FOUND = "GIT_5001"
CREDENTIAL_INVALID = "GIT_5002"
SSH_KEY_ERROR = "GIT_5003"
class GitOpsError(Exception):
"""Base exception for Git Operations errors."""
def __init__(
self,
message: str,
code: ErrorCode = ErrorCode.INTERNAL_ERROR,
details: dict[str, Any] | None = None,
) -> None:
super().__init__(message)
self.message = message
self.code = code
self.details = details or {}
def to_dict(self) -> dict[str, Any]:
"""Convert to dictionary for MCP response."""
result: dict[str, Any] = {
"error": self.message,
"code": self.code.value,
}
if self.details:
result["details"] = self.details
return result
# Workspace Errors
class WorkspaceError(GitOpsError):
"""Base exception for workspace-related errors."""
def __init__(
self,
message: str,
code: ErrorCode = ErrorCode.WORKSPACE_NOT_FOUND,
details: dict[str, Any] | None = None,
) -> None:
super().__init__(message, code, details)
class WorkspaceNotFoundError(WorkspaceError):
"""Workspace does not exist."""
def __init__(self, project_id: str) -> None:
super().__init__(
f"Workspace not found for project: {project_id}",
ErrorCode.WORKSPACE_NOT_FOUND,
{"project_id": project_id},
)
class WorkspaceLockedError(WorkspaceError):
"""Workspace is locked by another operation."""
def __init__(self, project_id: str, holder: str | None = None) -> None:
details: dict[str, Any] = {"project_id": project_id}
if holder:
details["locked_by"] = holder
super().__init__(
f"Workspace is locked for project: {project_id}",
ErrorCode.WORKSPACE_LOCKED,
details,
)
class WorkspaceSizeExceededError(WorkspaceError):
"""Workspace size limit exceeded."""
def __init__(self, project_id: str, current_size: float, max_size: float) -> None:
super().__init__(
f"Workspace size limit exceeded for project: {project_id}",
ErrorCode.WORKSPACE_SIZE_EXCEEDED,
{
"project_id": project_id,
"current_size_gb": current_size,
"max_size_gb": max_size,
},
)
# Git Operation Errors
class GitError(GitOpsError):
"""Base exception for git operation errors."""
def __init__(
self,
message: str,
code: ErrorCode = ErrorCode.INTERNAL_ERROR,
details: dict[str, Any] | None = None,
) -> None:
super().__init__(message, code, details)
class CloneError(GitError):
"""Failed to clone repository."""
def __init__(self, repo_url: str, reason: str) -> None:
super().__init__(
f"Failed to clone repository: {reason}",
ErrorCode.CLONE_FAILED,
{"repo_url": repo_url, "reason": reason},
)
class CheckoutError(GitError):
"""Failed to checkout branch or ref."""
def __init__(self, ref: str, reason: str) -> None:
super().__init__(
f"Failed to checkout '{ref}': {reason}",
ErrorCode.CHECKOUT_FAILED,
{"ref": ref, "reason": reason},
)
class CommitError(GitError):
"""Failed to commit changes."""
def __init__(self, reason: str) -> None:
super().__init__(
f"Failed to commit: {reason}",
ErrorCode.COMMIT_FAILED,
{"reason": reason},
)
class PushError(GitError):
"""Failed to push to remote."""
def __init__(self, branch: str, reason: str) -> None:
super().__init__(
f"Failed to push branch '{branch}': {reason}",
ErrorCode.PUSH_FAILED,
{"branch": branch, "reason": reason},
)
class PullError(GitError):
"""Failed to pull from remote."""
def __init__(self, branch: str, reason: str) -> None:
super().__init__(
f"Failed to pull branch '{branch}': {reason}",
ErrorCode.PULL_FAILED,
{"branch": branch, "reason": reason},
)
class MergeConflictError(GitError):
"""Merge conflict detected."""
def __init__(self, conflicting_files: list[str]) -> None:
super().__init__(
f"Merge conflict detected in {len(conflicting_files)} files",
ErrorCode.MERGE_CONFLICT,
{"conflicting_files": conflicting_files},
)
class BranchExistsError(GitError):
"""Branch already exists."""
def __init__(self, branch_name: str) -> None:
super().__init__(
f"Branch already exists: {branch_name}",
ErrorCode.BRANCH_EXISTS,
{"branch": branch_name},
)
class BranchNotFoundError(GitError):
"""Branch does not exist."""
def __init__(self, branch_name: str) -> None:
super().__init__(
f"Branch not found: {branch_name}",
ErrorCode.BRANCH_NOT_FOUND,
{"branch": branch_name},
)
class InvalidRefError(GitError):
"""Invalid git reference."""
def __init__(self, ref: str) -> None:
super().__init__(
f"Invalid git reference: {ref}",
ErrorCode.INVALID_REF,
{"ref": ref},
)
class DirtyWorkspaceError(GitError):
"""Workspace has uncommitted changes."""
def __init__(self, modified_files: list[str]) -> None:
super().__init__(
f"Workspace has {len(modified_files)} uncommitted changes",
ErrorCode.DIRTY_WORKSPACE,
{"modified_files": modified_files[:10]}, # Limit to first 10
)
# Provider Errors
class ProviderError(GitOpsError):
"""Base exception for provider-related errors."""
def __init__(
self,
message: str,
code: ErrorCode = ErrorCode.PROVIDER_ERROR,
details: dict[str, Any] | None = None,
) -> None:
super().__init__(message, code, details)
class AuthenticationError(ProviderError):
"""Authentication with provider failed."""
def __init__(self, provider: str, reason: str) -> None:
super().__init__(
f"Authentication failed with {provider}: {reason}",
ErrorCode.PROVIDER_AUTH_FAILED,
{"provider": provider, "reason": reason},
)
class ProviderNotFoundError(ProviderError):
"""Provider not configured or recognized."""
def __init__(self, provider: str) -> None:
super().__init__(
f"Provider not found or not configured: {provider}",
ErrorCode.PROVIDER_NOT_FOUND,
{"provider": provider},
)
class PRError(ProviderError):
"""Pull request operation failed."""
def __init__(
self,
message: str,
code: ErrorCode = ErrorCode.PR_CREATE_FAILED,
details: dict[str, Any] | None = None,
) -> None:
super().__init__(message, code, details)
class PRNotFoundError(PRError):
"""Pull request not found."""
def __init__(self, pr_number: int, repo: str) -> None:
super().__init__(
f"Pull request #{pr_number} not found in {repo}",
ErrorCode.PR_NOT_FOUND,
{"pr_number": pr_number, "repo": repo},
)
class APIError(ProviderError):
"""Provider API error."""
def __init__(self, provider: str, status_code: int, message: str) -> None:
super().__init__(
f"{provider} API error ({status_code}): {message}",
ErrorCode.API_ERROR,
{"provider": provider, "status_code": status_code, "message": message},
)
# Credential Errors
class CredentialError(GitOpsError):
"""Base exception for credential-related errors."""
def __init__(
self,
message: str,
code: ErrorCode = ErrorCode.CREDENTIAL_ERROR,
details: dict[str, Any] | None = None,
) -> None:
super().__init__(message, code, details)
class CredentialNotFoundError(CredentialError):
"""Credential not found."""
def __init__(self, credential_type: str, identifier: str) -> None:
super().__init__(
f"{credential_type} credential not found: {identifier}",
ErrorCode.CREDENTIAL_NOT_FOUND,
{"type": credential_type, "identifier": identifier},
)

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,690 @@
"""
Data models for Git Operations MCP Server.
Defines data structures for git operations, workspace management,
and provider interactions.
"""
from dataclasses import dataclass, field
from datetime import UTC, datetime
from enum import Enum
from typing import Any
from pydantic import BaseModel, Field
class FileChangeType(str, Enum):
"""Types of file changes in git."""
ADDED = "added"
MODIFIED = "modified"
DELETED = "deleted"
RENAMED = "renamed"
COPIED = "copied"
UNTRACKED = "untracked"
IGNORED = "ignored"
class MergeStrategy(str, Enum):
"""Merge strategies for pull requests."""
MERGE = "merge" # Create a merge commit
SQUASH = "squash" # Squash and merge
REBASE = "rebase" # Rebase and merge
class PRState(str, Enum):
"""Pull request states."""
OPEN = "open"
CLOSED = "closed"
MERGED = "merged"
class ProviderType(str, Enum):
"""Supported git providers."""
GITEA = "gitea"
GITHUB = "github"
GITLAB = "gitlab"
class WorkspaceState(str, Enum):
"""Workspace lifecycle states."""
INITIALIZING = "initializing"
READY = "ready"
LOCKED = "locked"
STALE = "stale"
DELETED = "deleted"
# Dataclasses for internal data structures
@dataclass
class FileChange:
"""A file change in git status."""
path: str
change_type: FileChangeType
old_path: str | None = None # For renames
additions: int = 0
deletions: int = 0
def to_dict(self) -> dict[str, Any]:
"""Convert to dictionary."""
return {
"path": self.path,
"change_type": self.change_type.value,
"old_path": self.old_path,
"additions": self.additions,
"deletions": self.deletions,
}
@dataclass
class BranchInfo:
"""Information about a git branch."""
name: str
is_current: bool = False
is_remote: bool = False
tracking_branch: str | None = None
commit_sha: str | None = None
commit_message: str | None = None
ahead: int = 0
behind: int = 0
def to_dict(self) -> dict[str, Any]:
"""Convert to dictionary."""
return {
"name": self.name,
"is_current": self.is_current,
"is_remote": self.is_remote,
"tracking_branch": self.tracking_branch,
"commit_sha": self.commit_sha,
"commit_message": self.commit_message,
"ahead": self.ahead,
"behind": self.behind,
}
@dataclass
class CommitInfo:
"""Information about a git commit."""
sha: str
short_sha: str
message: str
author_name: str
author_email: str
authored_date: datetime
committer_name: str
committer_email: str
committed_date: datetime
parents: list[str] = field(default_factory=list)
def to_dict(self) -> dict[str, Any]:
"""Convert to dictionary."""
return {
"sha": self.sha,
"short_sha": self.short_sha,
"message": self.message,
"author_name": self.author_name,
"author_email": self.author_email,
"authored_date": self.authored_date.isoformat(),
"committer_name": self.committer_name,
"committer_email": self.committer_email,
"committed_date": self.committed_date.isoformat(),
"parents": self.parents,
}
@dataclass
class DiffHunk:
"""A hunk of diff content."""
old_start: int
old_lines: int
new_start: int
new_lines: int
content: str
def to_dict(self) -> dict[str, Any]:
"""Convert to dictionary."""
return {
"old_start": self.old_start,
"old_lines": self.old_lines,
"new_start": self.new_start,
"new_lines": self.new_lines,
"content": self.content,
}
@dataclass
class FileDiff:
"""Diff for a single file."""
path: str
change_type: FileChangeType
old_path: str | None = None
hunks: list[DiffHunk] = field(default_factory=list)
additions: int = 0
deletions: int = 0
is_binary: bool = False
def to_dict(self) -> dict[str, Any]:
"""Convert to dictionary."""
return {
"path": self.path,
"change_type": self.change_type.value,
"old_path": self.old_path,
"hunks": [h.to_dict() for h in self.hunks],
"additions": self.additions,
"deletions": self.deletions,
"is_binary": self.is_binary,
}
@dataclass
class PRInfo:
"""Information about a pull request."""
number: int
title: str
body: str
state: PRState
source_branch: str
target_branch: str
author: str
created_at: datetime
updated_at: datetime
merged_at: datetime | None = None
closed_at: datetime | None = None
url: str | None = None
labels: list[str] = field(default_factory=list)
assignees: list[str] = field(default_factory=list)
reviewers: list[str] = field(default_factory=list)
mergeable: bool | None = None
draft: bool = False
def to_dict(self) -> dict[str, Any]:
"""Convert to dictionary."""
return {
"number": self.number,
"title": self.title,
"body": self.body,
"state": self.state.value,
"source_branch": self.source_branch,
"target_branch": self.target_branch,
"author": self.author,
"created_at": self.created_at.isoformat(),
"updated_at": self.updated_at.isoformat(),
"merged_at": self.merged_at.isoformat() if self.merged_at else None,
"closed_at": self.closed_at.isoformat() if self.closed_at else None,
"url": self.url,
"labels": self.labels,
"assignees": self.assignees,
"reviewers": self.reviewers,
"mergeable": self.mergeable,
"draft": self.draft,
}
@dataclass
class WorkspaceInfo:
"""Information about a project workspace."""
project_id: str
path: str
state: WorkspaceState
repo_url: str | None = None
current_branch: str | None = None
last_accessed: datetime = field(default_factory=lambda: datetime.now(UTC))
size_bytes: int = 0
lock_holder: str | None = None
lock_expires: datetime | None = None
def to_dict(self) -> dict[str, Any]:
"""Convert to dictionary."""
return {
"project_id": self.project_id,
"path": self.path,
"state": self.state.value,
"repo_url": self.repo_url,
"current_branch": self.current_branch,
"last_accessed": self.last_accessed.isoformat(),
"size_bytes": self.size_bytes,
"lock_holder": self.lock_holder,
"lock_expires": self.lock_expires.isoformat()
if self.lock_expires
else None,
}
# Pydantic Request/Response Models
class CloneRequest(BaseModel):
"""Request to clone a repository."""
project_id: str = Field(..., description="Project ID for scoping")
agent_id: str = Field(..., description="Agent ID making the request")
repo_url: str = Field(..., description="Repository URL to clone")
branch: str | None = Field(
default=None, description="Branch to checkout after clone"
)
depth: int | None = Field(
default=None, ge=1, description="Shallow clone depth (None = full clone)"
)
class CloneResult(BaseModel):
"""Result of a clone operation."""
success: bool = Field(..., description="Whether clone succeeded")
project_id: str = Field(..., description="Project ID")
workspace_path: str = Field(..., description="Path to cloned workspace")
branch: str = Field(..., description="Current branch after clone")
commit_sha: str = Field(..., description="HEAD commit SHA")
error: str | None = Field(default=None, description="Error message if failed")
class StatusRequest(BaseModel):
"""Request for git status."""
project_id: str = Field(..., description="Project ID for scoping")
agent_id: str = Field(..., description="Agent ID making the request")
include_untracked: bool = Field(default=True, description="Include untracked files")
class StatusResult(BaseModel):
"""Result of a status operation."""
project_id: str = Field(..., description="Project ID")
branch: str = Field(..., description="Current branch")
commit_sha: str = Field(..., description="HEAD commit SHA")
is_clean: bool = Field(..., description="Whether working tree is clean")
staged: list[dict[str, Any]] = Field(
default_factory=list, description="Staged changes"
)
unstaged: list[dict[str, Any]] = Field(
default_factory=list, description="Unstaged changes"
)
untracked: list[str] = Field(default_factory=list, description="Untracked files")
ahead: int = Field(default=0, description="Commits ahead of upstream")
behind: int = Field(default=0, description="Commits behind upstream")
class BranchRequest(BaseModel):
"""Request for branch operations."""
project_id: str = Field(..., description="Project ID for scoping")
agent_id: str = Field(..., description="Agent ID making the request")
branch_name: str = Field(..., description="Branch name")
from_ref: str | None = Field(
default=None, description="Reference to create branch from"
)
checkout: bool = Field(default=True, description="Checkout after creation")
class BranchResult(BaseModel):
"""Result of a branch operation."""
success: bool = Field(..., description="Whether operation succeeded")
branch: str = Field(..., description="Branch name")
commit_sha: str | None = Field(default=None, description="HEAD commit SHA")
is_current: bool = Field(default=False, description="Whether branch is checked out")
error: str | None = Field(default=None, description="Error message if failed")
class ListBranchesRequest(BaseModel):
"""Request to list branches."""
project_id: str = Field(..., description="Project ID for scoping")
agent_id: str = Field(..., description="Agent ID making the request")
include_remote: bool = Field(default=False, description="Include remote branches")
class ListBranchesResult(BaseModel):
"""Result of listing branches."""
project_id: str = Field(..., description="Project ID")
current_branch: str = Field(..., description="Currently checked out branch")
local_branches: list[dict[str, Any]] = Field(
default_factory=list, description="Local branches"
)
remote_branches: list[dict[str, Any]] = Field(
default_factory=list, description="Remote branches"
)
class CheckoutRequest(BaseModel):
"""Request to checkout a branch or ref."""
project_id: str = Field(..., description="Project ID for scoping")
agent_id: str = Field(..., description="Agent ID making the request")
ref: str = Field(..., description="Branch, tag, or commit to checkout")
create_branch: bool = Field(default=False, description="Create new branch")
force: bool = Field(default=False, description="Force checkout (discard changes)")
class CheckoutResult(BaseModel):
"""Result of a checkout operation."""
success: bool = Field(..., description="Whether checkout succeeded")
ref: str = Field(..., description="Checked out reference")
commit_sha: str | None = Field(default=None, description="HEAD commit SHA")
error: str | None = Field(default=None, description="Error message if failed")
class CommitRequest(BaseModel):
"""Request to create a commit."""
project_id: str = Field(..., description="Project ID for scoping")
agent_id: str = Field(..., description="Agent ID making the request")
message: str = Field(..., description="Commit message")
files: list[str] | None = Field(
default=None, description="Files to commit (None = all staged)"
)
author_name: str | None = Field(default=None, description="Author name override")
author_email: str | None = Field(default=None, description="Author email override")
allow_empty: bool = Field(default=False, description="Allow empty commit")
class CommitResult(BaseModel):
"""Result of a commit operation."""
success: bool = Field(..., description="Whether commit succeeded")
commit_sha: str | None = Field(default=None, description="New commit SHA")
short_sha: str | None = Field(default=None, description="Short commit SHA")
message: str | None = Field(default=None, description="Commit message")
files_changed: int = Field(default=0, description="Number of files changed")
insertions: int = Field(default=0, description="Lines added")
deletions: int = Field(default=0, description="Lines removed")
error: str | None = Field(default=None, description="Error message if failed")
class PushRequest(BaseModel):
"""Request to push to remote."""
project_id: str = Field(..., description="Project ID for scoping")
agent_id: str = Field(..., description="Agent ID making the request")
branch: str | None = Field(
default=None, description="Branch to push (None = current)"
)
remote: str = Field(default="origin", description="Remote name")
force: bool = Field(default=False, description="Force push")
set_upstream: bool = Field(default=True, description="Set upstream tracking")
class PushResult(BaseModel):
"""Result of a push operation."""
success: bool = Field(..., description="Whether push succeeded")
branch: str = Field(..., description="Pushed branch")
remote: str = Field(..., description="Remote name")
commits_pushed: int = Field(default=0, description="Number of commits pushed")
error: str | None = Field(default=None, description="Error message if failed")
class PullRequest(BaseModel):
"""Request to pull from remote."""
project_id: str = Field(..., description="Project ID for scoping")
agent_id: str = Field(..., description="Agent ID making the request")
branch: str | None = Field(
default=None, description="Branch to pull (None = current)"
)
remote: str = Field(default="origin", description="Remote name")
rebase: bool = Field(default=False, description="Rebase instead of merge")
class PullResult(BaseModel):
"""Result of a pull operation."""
success: bool = Field(..., description="Whether pull succeeded")
branch: str = Field(..., description="Pulled branch")
commits_received: int = Field(default=0, description="New commits received")
fast_forward: bool = Field(default=False, description="Was fast-forward")
conflicts: list[str] = Field(
default_factory=list, description="Conflicting files if any"
)
error: str | None = Field(default=None, description="Error message if failed")
class DiffRequest(BaseModel):
"""Request for diff."""
project_id: str = Field(..., description="Project ID for scoping")
agent_id: str = Field(..., description="Agent ID making the request")
base: str | None = Field(
default=None, description="Base reference (None = working tree)"
)
head: str | None = Field(default=None, description="Head reference (None = HEAD)")
files: list[str] | None = Field(default=None, description="Specific files to diff")
context_lines: int = Field(default=3, ge=0, description="Context lines")
class DiffResult(BaseModel):
"""Result of a diff operation."""
project_id: str = Field(..., description="Project ID")
base: str | None = Field(default=None, description="Base reference")
head: str | None = Field(default=None, description="Head reference")
files: list[dict[str, Any]] = Field(default_factory=list, description="File diffs")
total_additions: int = Field(default=0, description="Total lines added")
total_deletions: int = Field(default=0, description="Total lines removed")
files_changed: int = Field(default=0, description="Number of files changed")
class LogRequest(BaseModel):
"""Request for commit log."""
project_id: str = Field(..., description="Project ID for scoping")
agent_id: str = Field(..., description="Agent ID making the request")
ref: str | None = Field(default=None, description="Reference to start from")
limit: int = Field(default=20, ge=1, le=100, description="Max commits to return")
skip: int = Field(default=0, ge=0, description="Commits to skip")
path: str | None = Field(default=None, description="Filter by path")
class LogResult(BaseModel):
"""Result of a log operation."""
project_id: str = Field(..., description="Project ID")
commits: list[dict[str, Any]] = Field(
default_factory=list, description="Commit history"
)
total_commits: int = Field(default=0, description="Total commits in range")
# PR Operations
class CreatePRRequest(BaseModel):
"""Request to create a pull request."""
project_id: str = Field(..., description="Project ID for scoping")
agent_id: str = Field(..., description="Agent ID making the request")
title: str = Field(..., description="PR title")
body: str = Field(default="", description="PR description")
source_branch: str = Field(..., description="Source branch")
target_branch: str = Field(default="main", description="Target branch")
draft: bool = Field(default=False, description="Create as draft")
labels: list[str] = Field(default_factory=list, description="Labels to add")
assignees: list[str] = Field(default_factory=list, description="Assignees")
reviewers: list[str] = Field(default_factory=list, description="Reviewers")
class CreatePRResult(BaseModel):
"""Result of creating a pull request."""
success: bool = Field(..., description="Whether creation succeeded")
pr_number: int | None = Field(default=None, description="PR number")
pr_url: str | None = Field(default=None, description="PR URL")
error: str | None = Field(default=None, description="Error message if failed")
class GetPRRequest(BaseModel):
"""Request to get a pull request."""
project_id: str = Field(..., description="Project ID for scoping")
agent_id: str = Field(..., description="Agent ID making the request")
pr_number: int = Field(..., description="PR number")
class GetPRResult(BaseModel):
"""Result of getting a pull request."""
success: bool = Field(..., description="Whether fetch succeeded")
pr: dict[str, Any] | None = Field(default=None, description="PR info")
error: str | None = Field(default=None, description="Error message if failed")
class ListPRsRequest(BaseModel):
"""Request to list pull requests."""
project_id: str = Field(..., description="Project ID for scoping")
agent_id: str = Field(..., description="Agent ID making the request")
state: PRState | None = Field(default=None, description="Filter by state")
author: str | None = Field(default=None, description="Filter by author")
limit: int = Field(default=20, ge=1, le=100, description="Max PRs to return")
class ListPRsResult(BaseModel):
"""Result of listing pull requests."""
success: bool = Field(..., description="Whether list succeeded")
pull_requests: list[dict[str, Any]] = Field(default_factory=list, description="PRs")
total_count: int = Field(default=0, description="Total matching PRs")
error: str | None = Field(default=None, description="Error message if failed")
class MergePRRequest(BaseModel):
"""Request to merge a pull request."""
project_id: str = Field(..., description="Project ID for scoping")
agent_id: str = Field(..., description="Agent ID making the request")
pr_number: int = Field(..., description="PR number")
merge_strategy: MergeStrategy = Field(
default=MergeStrategy.MERGE, description="Merge strategy"
)
commit_message: str | None = Field(
default=None, description="Custom merge commit message"
)
delete_branch: bool = Field(
default=True, description="Delete source branch after merge"
)
class MergePRResult(BaseModel):
"""Result of merging a pull request."""
success: bool = Field(..., description="Whether merge succeeded")
merge_commit_sha: str | None = Field(default=None, description="Merge commit SHA")
branch_deleted: bool = Field(
default=False, description="Whether branch was deleted"
)
error: str | None = Field(default=None, description="Error message if failed")
class UpdatePRRequest(BaseModel):
"""Request to update a pull request."""
project_id: str = Field(..., description="Project ID for scoping")
agent_id: str = Field(..., description="Agent ID making the request")
pr_number: int = Field(..., description="PR number")
title: str | None = Field(default=None, description="New title")
body: str | None = Field(default=None, description="New description")
state: PRState | None = Field(default=None, description="New state")
labels: list[str] | None = Field(default=None, description="Replace labels")
assignees: list[str] | None = Field(default=None, description="Replace assignees")
class UpdatePRResult(BaseModel):
"""Result of updating a pull request."""
success: bool = Field(..., description="Whether update succeeded")
pr: dict[str, Any] | None = Field(default=None, description="Updated PR info")
error: str | None = Field(default=None, description="Error message if failed")
# Workspace Operations
class GetWorkspaceRequest(BaseModel):
"""Request to get or create workspace."""
project_id: str = Field(..., description="Project ID")
agent_id: str = Field(..., description="Agent ID making the request")
class GetWorkspaceResult(BaseModel):
"""Result of getting workspace."""
success: bool = Field(..., description="Whether operation succeeded")
workspace: dict[str, Any] | None = Field(default=None, description="Workspace info")
error: str | None = Field(default=None, description="Error message if failed")
class LockWorkspaceRequest(BaseModel):
"""Request to lock a workspace."""
project_id: str = Field(..., description="Project ID")
agent_id: str = Field(..., description="Agent ID requesting lock")
timeout: int = Field(
default=300, ge=10, le=3600, description="Lock timeout seconds"
)
class LockWorkspaceResult(BaseModel):
"""Result of locking workspace."""
success: bool = Field(..., description="Whether lock acquired")
lock_holder: str | None = Field(default=None, description="Current lock holder")
lock_expires: str | None = Field(
default=None, description="Lock expiry ISO timestamp"
)
error: str | None = Field(default=None, description="Error message if failed")
class UnlockWorkspaceRequest(BaseModel):
"""Request to unlock a workspace."""
project_id: str = Field(..., description="Project ID")
agent_id: str = Field(..., description="Agent ID releasing lock")
force: bool = Field(default=False, description="Force unlock (admin only)")
class UnlockWorkspaceResult(BaseModel):
"""Result of unlocking workspace."""
success: bool = Field(..., description="Whether unlock succeeded")
error: str | None = Field(default=None, description="Error message if failed")
# Health and Status
class HealthStatus(BaseModel):
"""Health status response."""
status: str = Field(..., description="Health status")
version: str = Field(..., description="Server version")
workspace_count: int = Field(default=0, description="Active workspaces")
gitea_connected: bool = Field(default=False, description="Gitea connectivity")
github_connected: bool = Field(default=False, description="GitHub connectivity")
gitlab_connected: bool = Field(default=False, description="GitLab connectivity")
redis_connected: bool = Field(default=False, description="Redis connectivity")
class ProviderStatus(BaseModel):
"""Provider connection status."""
provider: str = Field(..., description="Provider name")
connected: bool = Field(..., description="Connection status")
url: str | None = Field(default=None, description="Provider URL")
user: str | None = Field(default=None, description="Authenticated user")
error: str | None = Field(default=None, description="Error if not connected")

View File

@@ -0,0 +1,11 @@
"""
Git provider implementations.
Provides adapters for different git hosting platforms (Gitea, GitHub, GitLab).
"""
from .base import BaseProvider
from .gitea import GiteaProvider
from .github import GitHubProvider
__all__ = ["BaseProvider", "GiteaProvider", "GitHubProvider"]

View File

@@ -0,0 +1,376 @@
"""
Base provider interface for git hosting platforms.
Defines the abstract interface that all git providers must implement.
"""
from abc import ABC, abstractmethod
from typing import Any
from models import (
CreatePRResult,
GetPRResult,
ListPRsResult,
MergePRResult,
MergeStrategy,
PRState,
UpdatePRResult,
)
class BaseProvider(ABC):
"""
Abstract base class for git hosting providers.
All providers (Gitea, GitHub, GitLab) must implement this interface.
"""
@property
@abstractmethod
def name(self) -> str:
"""Return the provider name (e.g., 'gitea', 'github')."""
...
@abstractmethod
async def is_connected(self) -> bool:
"""Check if the provider is connected and authenticated."""
...
@abstractmethod
async def get_authenticated_user(self) -> str | None:
"""Get the username of the authenticated user."""
...
# Repository operations
@abstractmethod
async def get_repo_info(self, owner: str, repo: str) -> dict[str, Any]:
"""
Get repository information.
Args:
owner: Repository owner/organization
repo: Repository name
Returns:
Repository info dict
"""
...
@abstractmethod
async def get_default_branch(self, owner: str, repo: str) -> str:
"""
Get the default branch for a repository.
Args:
owner: Repository owner/organization
repo: Repository name
Returns:
Default branch name
"""
...
# Pull Request operations
@abstractmethod
async def create_pr(
self,
owner: str,
repo: str,
title: str,
body: str,
source_branch: str,
target_branch: str,
draft: bool = False,
labels: list[str] | None = None,
assignees: list[str] | None = None,
reviewers: list[str] | None = None,
) -> CreatePRResult:
"""
Create a pull request.
Args:
owner: Repository owner
repo: Repository name
title: PR title
body: PR description
source_branch: Source branch name
target_branch: Target branch name
draft: Whether to create as draft
labels: Labels to add
assignees: Users to assign
reviewers: Users to request review from
Returns:
CreatePRResult with PR number and URL
"""
...
@abstractmethod
async def get_pr(self, owner: str, repo: str, pr_number: int) -> GetPRResult:
"""
Get a pull request by number.
Args:
owner: Repository owner
repo: Repository name
pr_number: Pull request number
Returns:
GetPRResult with PR details
"""
...
@abstractmethod
async def list_prs(
self,
owner: str,
repo: str,
state: PRState | None = None,
author: str | None = None,
limit: int = 20,
) -> ListPRsResult:
"""
List pull requests.
Args:
owner: Repository owner
repo: Repository name
state: Filter by state (open, closed, merged)
author: Filter by author
limit: Maximum PRs to return
Returns:
ListPRsResult with list of PRs
"""
...
@abstractmethod
async def merge_pr(
self,
owner: str,
repo: str,
pr_number: int,
merge_strategy: MergeStrategy = MergeStrategy.MERGE,
commit_message: str | None = None,
delete_branch: bool = True,
) -> MergePRResult:
"""
Merge a pull request.
Args:
owner: Repository owner
repo: Repository name
pr_number: Pull request number
merge_strategy: Merge strategy to use
commit_message: Custom merge commit message
delete_branch: Whether to delete source branch
Returns:
MergePRResult with merge status
"""
...
@abstractmethod
async def update_pr(
self,
owner: str,
repo: str,
pr_number: int,
title: str | None = None,
body: str | None = None,
state: PRState | None = None,
labels: list[str] | None = None,
assignees: list[str] | None = None,
) -> UpdatePRResult:
"""
Update a pull request.
Args:
owner: Repository owner
repo: Repository name
pr_number: Pull request number
title: New title
body: New description
state: New state (open, closed)
labels: Replace labels
assignees: Replace assignees
Returns:
UpdatePRResult with updated PR info
"""
...
@abstractmethod
async def close_pr(self, owner: str, repo: str, pr_number: int) -> UpdatePRResult:
"""
Close a pull request without merging.
Args:
owner: Repository owner
repo: Repository name
pr_number: Pull request number
Returns:
UpdatePRResult with updated PR info
"""
...
# Branch operations via API (for operations that need to bypass local git)
@abstractmethod
async def delete_remote_branch(self, owner: str, repo: str, branch: str) -> bool:
"""
Delete a remote branch via API.
Args:
owner: Repository owner
repo: Repository name
branch: Branch name to delete
Returns:
True if deleted, False otherwise
"""
...
@abstractmethod
async def get_branch(
self, owner: str, repo: str, branch: str
) -> dict[str, Any] | None:
"""
Get branch information via API.
Args:
owner: Repository owner
repo: Repository name
branch: Branch name
Returns:
Branch info dict or None if not found
"""
...
# Comment operations
@abstractmethod
async def add_pr_comment(
self, owner: str, repo: str, pr_number: int, body: str
) -> dict[str, Any]:
"""
Add a comment to a pull request.
Args:
owner: Repository owner
repo: Repository name
pr_number: Pull request number
body: Comment body
Returns:
Created comment info
"""
...
@abstractmethod
async def list_pr_comments(
self, owner: str, repo: str, pr_number: int
) -> list[dict[str, Any]]:
"""
List comments on a pull request.
Args:
owner: Repository owner
repo: Repository name
pr_number: Pull request number
Returns:
List of comments
"""
...
# Label operations
@abstractmethod
async def add_labels(
self, owner: str, repo: str, pr_number: int, labels: list[str]
) -> list[str]:
"""
Add labels to a pull request.
Args:
owner: Repository owner
repo: Repository name
pr_number: Pull request number
labels: Labels to add
Returns:
Updated list of labels
"""
...
@abstractmethod
async def remove_label(
self, owner: str, repo: str, pr_number: int, label: str
) -> list[str]:
"""
Remove a label from a pull request.
Args:
owner: Repository owner
repo: Repository name
pr_number: Pull request number
label: Label to remove
Returns:
Updated list of labels
"""
...
# Reviewer operations
@abstractmethod
async def request_review(
self, owner: str, repo: str, pr_number: int, reviewers: list[str]
) -> list[str]:
"""
Request review from users.
Args:
owner: Repository owner
repo: Repository name
pr_number: Pull request number
reviewers: Usernames to request review from
Returns:
List of reviewers requested
"""
...
# Utility methods
def parse_repo_url(self, repo_url: str) -> tuple[str, str]:
"""
Parse repository URL to extract owner and repo name.
Args:
repo_url: Repository URL (HTTPS or SSH)
Returns:
Tuple of (owner, repo)
Raises:
ValueError: If URL cannot be parsed
"""
import re
# Handle SSH URLs: git@host:owner/repo.git
ssh_match = re.match(r"git@[^:]+:([^/]+)/([^/]+?)(?:\.git)?$", repo_url)
if ssh_match:
return ssh_match.group(1), ssh_match.group(2)
# Handle HTTPS URLs: https://host/owner/repo.git
https_match = re.match(r"https?://[^/]+/([^/]+)/([^/]+?)(?:\.git)?$", repo_url)
if https_match:
return https_match.group(1), https_match.group(2)
raise ValueError(f"Unable to parse repository URL: {repo_url}")

View File

@@ -0,0 +1,723 @@
"""
Gitea provider implementation.
Implements the BaseProvider interface for Gitea API operations.
"""
import logging
from datetime import UTC, datetime
from typing import Any
import httpx
from config import Settings, get_settings
from exceptions import (
APIError,
AuthenticationError,
PRNotFoundError,
)
from models import (
CreatePRResult,
GetPRResult,
ListPRsResult,
MergePRResult,
MergeStrategy,
PRInfo,
PRState,
UpdatePRResult,
)
from .base import BaseProvider
logger = logging.getLogger(__name__)
class GiteaProvider(BaseProvider):
"""
Gitea API provider implementation.
Supports all PR operations, branch operations, and repository queries.
"""
def __init__(
self,
base_url: str | None = None,
token: str | None = None,
settings: Settings | None = None,
) -> None:
"""
Initialize Gitea provider.
Args:
base_url: Gitea server URL (e.g., https://gitea.example.com)
token: API token
settings: Optional settings override
"""
self.settings = settings or get_settings()
self.base_url = (base_url or self.settings.gitea_base_url).rstrip("/")
self.token = token or self.settings.gitea_token
self._client: httpx.AsyncClient | None = None
self._user: str | None = None
@property
def name(self) -> str:
"""Return the provider name."""
return "gitea"
async def _get_client(self) -> httpx.AsyncClient:
"""Get or create HTTP client."""
if self._client is None:
headers = {
"Accept": "application/json",
"Content-Type": "application/json",
}
if self.token:
headers["Authorization"] = f"token {self.token}"
self._client = httpx.AsyncClient(
base_url=f"{self.base_url}/api/v1",
headers=headers,
timeout=30.0,
)
return self._client
async def close(self) -> None:
"""Close the HTTP client."""
if self._client:
await self._client.aclose()
self._client = None
async def _request(
self,
method: str,
path: str,
**kwargs: Any,
) -> Any:
"""
Make an API request.
Args:
method: HTTP method
path: API path
**kwargs: Additional request arguments
Returns:
Parsed JSON response
Raises:
APIError: On API errors
AuthenticationError: On auth failures
"""
client = await self._get_client()
try:
response = await client.request(method, path, **kwargs)
if response.status_code == 401:
raise AuthenticationError("gitea", "Invalid or expired token")
if response.status_code == 403:
raise AuthenticationError(
"gitea", "Insufficient permissions for this operation"
)
if response.status_code == 404:
return None
if response.status_code >= 400:
error_msg = response.text
try:
error_data = response.json()
error_msg = error_data.get("message", error_msg)
except Exception:
pass
raise APIError("gitea", response.status_code, error_msg)
if response.status_code == 204:
return None
return response.json()
except httpx.RequestError as e:
raise APIError("gitea", 0, f"Request failed: {e}")
async def is_connected(self) -> bool:
"""Check if connected to Gitea."""
if not self.base_url or not self.token:
return False
try:
result = await self._request("GET", "/user")
return result is not None
except Exception:
return False
async def get_authenticated_user(self) -> str | None:
"""Get the authenticated user's username."""
if self._user:
return self._user
try:
result = await self._request("GET", "/user")
if result:
self._user = result.get("login") or result.get("username")
return self._user
except Exception:
pass
return None
# Repository operations
async def get_repo_info(self, owner: str, repo: str) -> dict[str, Any]:
"""Get repository information."""
result = await self._request("GET", f"/repos/{owner}/{repo}")
if result is None:
raise APIError("gitea", 404, f"Repository not found: {owner}/{repo}")
return result
async def get_default_branch(self, owner: str, repo: str) -> str:
"""Get the default branch for a repository."""
repo_info = await self.get_repo_info(owner, repo)
return repo_info.get("default_branch", "main")
# Pull Request operations
async def create_pr(
self,
owner: str,
repo: str,
title: str,
body: str,
source_branch: str,
target_branch: str,
draft: bool = False,
labels: list[str] | None = None,
assignees: list[str] | None = None,
reviewers: list[str] | None = None,
) -> CreatePRResult:
"""Create a pull request."""
try:
data: dict[str, Any] = {
"title": title,
"body": body,
"head": source_branch,
"base": target_branch,
}
# Note: Gitea doesn't have draft PR support in all versions
# Draft support was added in Gitea 1.14+
result = await self._request(
"POST",
f"/repos/{owner}/{repo}/pulls",
json=data,
)
if result is None:
return CreatePRResult(
success=False,
error="Failed to create pull request",
)
pr_number = result["number"]
# Add labels if specified
if labels:
await self.add_labels(owner, repo, pr_number, labels)
# Add assignees if specified (via issue update)
if assignees:
await self._request(
"PATCH",
f"/repos/{owner}/{repo}/issues/{pr_number}",
json={"assignees": assignees},
)
# Request reviewers if specified
if reviewers:
await self.request_review(owner, repo, pr_number, reviewers)
return CreatePRResult(
success=True,
pr_number=pr_number,
pr_url=result.get("html_url"),
)
except APIError as e:
return CreatePRResult(
success=False,
error=str(e),
)
async def get_pr(self, owner: str, repo: str, pr_number: int) -> GetPRResult:
"""Get a pull request by number."""
try:
result = await self._request(
"GET",
f"/repos/{owner}/{repo}/pulls/{pr_number}",
)
if result is None:
raise PRNotFoundError(pr_number, f"{owner}/{repo}")
pr_info = self._parse_pr(result)
return GetPRResult(
success=True,
pr=pr_info.to_dict(),
)
except PRNotFoundError:
return GetPRResult(
success=False,
error=f"Pull request #{pr_number} not found",
)
except APIError as e:
return GetPRResult(
success=False,
error=str(e),
)
async def list_prs(
self,
owner: str,
repo: str,
state: PRState | None = None,
author: str | None = None,
limit: int = 20,
) -> ListPRsResult:
"""List pull requests."""
try:
params: dict[str, Any] = {
"limit": limit,
}
if state:
# Gitea uses different state names
if state == PRState.OPEN:
params["state"] = "open"
elif state == PRState.CLOSED or state == PRState.MERGED:
params["state"] = "closed"
else:
params["state"] = "all"
result = await self._request(
"GET",
f"/repos/{owner}/{repo}/pulls",
params=params,
)
if result is None:
return ListPRsResult(
success=True,
pull_requests=[],
total_count=0,
)
prs = []
for pr_data in result:
# Filter by author if specified
if author:
pr_author = pr_data.get("user", {}).get("login", "")
if pr_author.lower() != author.lower():
continue
# Filter merged PRs if looking specifically for merged
if state == PRState.MERGED:
if not pr_data.get("merged"):
continue
pr_info = self._parse_pr(pr_data)
prs.append(pr_info.to_dict())
return ListPRsResult(
success=True,
pull_requests=prs,
total_count=len(prs),
)
except APIError as e:
return ListPRsResult(
success=False,
error=str(e),
)
async def merge_pr(
self,
owner: str,
repo: str,
pr_number: int,
merge_strategy: MergeStrategy = MergeStrategy.MERGE,
commit_message: str | None = None,
delete_branch: bool = True,
) -> MergePRResult:
"""Merge a pull request."""
try:
# Map merge strategy to Gitea's "Do" values
do_map = {
MergeStrategy.MERGE: "merge",
MergeStrategy.SQUASH: "squash",
MergeStrategy.REBASE: "rebase",
}
data: dict[str, Any] = {
"Do": do_map[merge_strategy],
"delete_branch_after_merge": delete_branch,
}
if commit_message:
data["MergeTitleField"] = commit_message.split("\n")[0]
if "\n" in commit_message:
data["MergeMessageField"] = "\n".join(
commit_message.split("\n")[1:]
)
result = await self._request(
"POST",
f"/repos/{owner}/{repo}/pulls/{pr_number}/merge",
json=data,
)
if result is None:
# Check if PR was actually merged
pr_result = await self.get_pr(owner, repo, pr_number)
if pr_result.success and pr_result.pr:
if pr_result.pr.get("state") == "merged":
return MergePRResult(
success=True,
branch_deleted=delete_branch,
)
return MergePRResult(
success=False,
error="Failed to merge pull request",
)
return MergePRResult(
success=True,
merge_commit_sha=result.get("sha"),
branch_deleted=delete_branch,
)
except APIError as e:
return MergePRResult(
success=False,
error=str(e),
)
async def update_pr(
self,
owner: str,
repo: str,
pr_number: int,
title: str | None = None,
body: str | None = None,
state: PRState | None = None,
labels: list[str] | None = None,
assignees: list[str] | None = None,
) -> UpdatePRResult:
"""Update a pull request."""
try:
data: dict[str, Any] = {}
if title is not None:
data["title"] = title
if body is not None:
data["body"] = body
if state is not None:
if state == PRState.OPEN:
data["state"] = "open"
elif state == PRState.CLOSED:
data["state"] = "closed"
# Update PR if there's data
if data:
await self._request(
"PATCH",
f"/repos/{owner}/{repo}/pulls/{pr_number}",
json=data,
)
# Update labels via issue endpoint
if labels is not None:
# First clear existing labels
await self._request(
"DELETE",
f"/repos/{owner}/{repo}/issues/{pr_number}/labels",
)
# Then add new labels
if labels:
await self.add_labels(owner, repo, pr_number, labels)
# Update assignees via issue endpoint
if assignees is not None:
await self._request(
"PATCH",
f"/repos/{owner}/{repo}/issues/{pr_number}",
json={"assignees": assignees},
)
# Fetch updated PR
result = await self.get_pr(owner, repo, pr_number)
return UpdatePRResult(
success=result.success,
pr=result.pr,
error=result.error,
)
except APIError as e:
return UpdatePRResult(
success=False,
error=str(e),
)
async def close_pr(
self,
owner: str,
repo: str,
pr_number: int,
) -> UpdatePRResult:
"""Close a pull request without merging."""
return await self.update_pr(
owner,
repo,
pr_number,
state=PRState.CLOSED,
)
# Branch operations
async def delete_remote_branch(
self,
owner: str,
repo: str,
branch: str,
) -> bool:
"""Delete a remote branch."""
try:
await self._request(
"DELETE",
f"/repos/{owner}/{repo}/branches/{branch}",
)
return True
except APIError:
return False
async def get_branch(
self,
owner: str,
repo: str,
branch: str,
) -> dict[str, Any] | None:
"""Get branch information."""
return await self._request(
"GET",
f"/repos/{owner}/{repo}/branches/{branch}",
)
# Comment operations
async def add_pr_comment(
self,
owner: str,
repo: str,
pr_number: int,
body: str,
) -> dict[str, Any]:
"""Add a comment to a pull request."""
result = await self._request(
"POST",
f"/repos/{owner}/{repo}/issues/{pr_number}/comments",
json={"body": body},
)
return result or {}
async def list_pr_comments(
self,
owner: str,
repo: str,
pr_number: int,
) -> list[dict[str, Any]]:
"""List comments on a pull request."""
result = await self._request(
"GET",
f"/repos/{owner}/{repo}/issues/{pr_number}/comments",
)
return result or []
# Label operations
async def add_labels(
self,
owner: str,
repo: str,
pr_number: int,
labels: list[str],
) -> list[str]:
"""Add labels to a pull request."""
# First, get or create label IDs
label_ids = []
for label_name in labels:
label_id = await self._get_or_create_label(owner, repo, label_name)
if label_id:
label_ids.append(label_id)
if label_ids:
await self._request(
"POST",
f"/repos/{owner}/{repo}/issues/{pr_number}/labels",
json={"labels": label_ids},
)
# Return current labels
issue = await self._request(
"GET",
f"/repos/{owner}/{repo}/issues/{pr_number}",
)
if issue:
return [lbl["name"] for lbl in issue.get("labels", [])]
return labels
async def remove_label(
self,
owner: str,
repo: str,
pr_number: int,
label: str,
) -> list[str]:
"""Remove a label from a pull request."""
# Get label ID
label_info = await self._request(
"GET",
f"/repos/{owner}/{repo}/labels?name={label}",
)
if label_info and len(label_info) > 0:
label_id = label_info[0]["id"]
await self._request(
"DELETE",
f"/repos/{owner}/{repo}/issues/{pr_number}/labels/{label_id}",
)
# Return remaining labels
issue = await self._request(
"GET",
f"/repos/{owner}/{repo}/issues/{pr_number}",
)
if issue:
return [lbl["name"] for lbl in issue.get("labels", [])]
return []
async def _get_or_create_label(
self,
owner: str,
repo: str,
label_name: str,
) -> int | None:
"""Get or create a label and return its ID."""
# Try to find existing label
labels = await self._request(
"GET",
f"/repos/{owner}/{repo}/labels",
)
if labels:
for label in labels:
if label["name"].lower() == label_name.lower():
return label["id"]
# Create new label with default color
try:
result = await self._request(
"POST",
f"/repos/{owner}/{repo}/labels",
json={
"name": label_name,
"color": "#3B82F6", # Default blue
},
)
if result:
return result["id"]
except APIError:
pass
return None
# Reviewer operations
async def request_review(
self,
owner: str,
repo: str,
pr_number: int,
reviewers: list[str],
) -> list[str]:
"""Request review from users."""
await self._request(
"POST",
f"/repos/{owner}/{repo}/pulls/{pr_number}/requested_reviewers",
json={"reviewers": reviewers},
)
return reviewers
# Helper methods
def _parse_pr(self, data: dict[str, Any]) -> PRInfo:
"""Parse PR API response into PRInfo."""
# Parse dates
created_at = self._parse_datetime(data.get("created_at"))
updated_at = self._parse_datetime(data.get("updated_at"))
merged_at = self._parse_datetime(data.get("merged_at"))
closed_at = self._parse_datetime(data.get("closed_at"))
# Determine state
if data.get("merged"):
state = PRState.MERGED
elif data.get("state") == "closed":
state = PRState.CLOSED
else:
state = PRState.OPEN
# Extract labels
labels = [lbl["name"] for lbl in data.get("labels", [])]
# Extract assignees
assignees = [a["login"] for a in data.get("assignees", [])]
# Extract reviewers
reviewers = []
if "requested_reviewers" in data:
reviewers = [r["login"] for r in data["requested_reviewers"]]
return PRInfo(
number=data["number"],
title=data["title"],
body=data.get("body", ""),
state=state,
source_branch=data.get("head", {}).get("ref", ""),
target_branch=data.get("base", {}).get("ref", ""),
author=data.get("user", {}).get("login", ""),
created_at=created_at,
updated_at=updated_at,
merged_at=merged_at,
closed_at=closed_at,
url=data.get("html_url"),
labels=labels,
assignees=assignees,
reviewers=reviewers,
mergeable=data.get("mergeable"),
draft=data.get("draft", False),
)
def _parse_datetime(self, value: str | None) -> datetime:
"""Parse datetime string from API."""
if not value:
return datetime.now(UTC)
try:
# Handle Gitea's datetime format
if value.endswith("Z"):
value = value[:-1] + "+00:00"
return datetime.fromisoformat(value)
except ValueError:
return datetime.now(UTC)

View File

@@ -0,0 +1,675 @@
"""
GitHub provider implementation.
Implements the BaseProvider interface for GitHub API operations.
"""
import logging
from datetime import UTC, datetime
from typing import Any
import httpx
from config import Settings, get_settings
from exceptions import (
APIError,
AuthenticationError,
PRNotFoundError,
)
from models import (
CreatePRResult,
GetPRResult,
ListPRsResult,
MergePRResult,
MergeStrategy,
PRInfo,
PRState,
UpdatePRResult,
)
from .base import BaseProvider
logger = logging.getLogger(__name__)
class GitHubProvider(BaseProvider):
"""
GitHub API provider implementation.
Supports all PR operations, branch operations, and repository queries.
"""
def __init__(
self,
token: str | None = None,
settings: Settings | None = None,
) -> None:
"""
Initialize GitHub provider.
Args:
token: GitHub personal access token or fine-grained token
settings: Optional settings override
"""
self.settings = settings or get_settings()
self.token = token or self.settings.github_token
self._client: httpx.AsyncClient | None = None
self._user: str | None = None
@property
def name(self) -> str:
"""Return the provider name."""
return "github"
async def _get_client(self) -> httpx.AsyncClient:
"""Get or create HTTP client."""
if self._client is None:
headers = {
"Accept": "application/vnd.github+json",
"X-GitHub-Api-Version": "2022-11-28",
}
if self.token:
headers["Authorization"] = f"Bearer {self.token}"
self._client = httpx.AsyncClient(
base_url="https://api.github.com",
headers=headers,
timeout=30.0,
)
return self._client
async def close(self) -> None:
"""Close the HTTP client."""
if self._client:
await self._client.aclose()
self._client = None
async def _request(
self,
method: str,
path: str,
**kwargs: Any,
) -> Any:
"""
Make an API request.
Args:
method: HTTP method
path: API path
**kwargs: Additional request arguments
Returns:
Parsed JSON response
Raises:
APIError: On API errors
AuthenticationError: On auth failures
"""
client = await self._get_client()
try:
response = await client.request(method, path, **kwargs)
if response.status_code == 401:
raise AuthenticationError("github", "Invalid or expired token")
if response.status_code == 403:
# Check for rate limiting
if "rate limit" in response.text.lower():
raise APIError("github", 403, "GitHub API rate limit exceeded")
raise AuthenticationError(
"github", "Insufficient permissions for this operation"
)
if response.status_code == 404:
return None
if response.status_code >= 400:
error_msg = response.text
try:
error_data = response.json()
error_msg = error_data.get("message", error_msg)
except Exception:
pass
raise APIError("github", response.status_code, error_msg)
if response.status_code == 204:
return None
return response.json()
except httpx.RequestError as e:
raise APIError("github", 0, f"Request failed: {e}")
async def is_connected(self) -> bool:
"""Check if connected to GitHub."""
if not self.token:
return False
try:
result = await self._request("GET", "/user")
return result is not None
except Exception:
return False
async def get_authenticated_user(self) -> str | None:
"""Get the authenticated user's username."""
if self._user:
return self._user
try:
result = await self._request("GET", "/user")
if result:
self._user = result.get("login")
return self._user
except Exception:
pass
return None
# Repository operations
async def get_repo_info(self, owner: str, repo: str) -> dict[str, Any]:
"""Get repository information."""
result = await self._request("GET", f"/repos/{owner}/{repo}")
if result is None:
raise APIError("github", 404, f"Repository not found: {owner}/{repo}")
return result
async def get_default_branch(self, owner: str, repo: str) -> str:
"""Get the default branch for a repository."""
repo_info = await self.get_repo_info(owner, repo)
return repo_info.get("default_branch", "main")
# Pull Request operations
async def create_pr(
self,
owner: str,
repo: str,
title: str,
body: str,
source_branch: str,
target_branch: str,
draft: bool = False,
labels: list[str] | None = None,
assignees: list[str] | None = None,
reviewers: list[str] | None = None,
) -> CreatePRResult:
"""Create a pull request."""
try:
data: dict[str, Any] = {
"title": title,
"body": body,
"head": source_branch,
"base": target_branch,
"draft": draft,
}
result = await self._request(
"POST",
f"/repos/{owner}/{repo}/pulls",
json=data,
)
if result is None:
return CreatePRResult(
success=False,
error="Failed to create pull request",
)
pr_number = result["number"]
# Add labels if specified
if labels:
await self.add_labels(owner, repo, pr_number, labels)
# Add assignees if specified
if assignees:
await self._request(
"POST",
f"/repos/{owner}/{repo}/issues/{pr_number}/assignees",
json={"assignees": assignees},
)
# Request reviewers if specified
if reviewers:
await self.request_review(owner, repo, pr_number, reviewers)
return CreatePRResult(
success=True,
pr_number=pr_number,
pr_url=result.get("html_url"),
)
except APIError as e:
return CreatePRResult(
success=False,
error=str(e),
)
async def get_pr(self, owner: str, repo: str, pr_number: int) -> GetPRResult:
"""Get a pull request by number."""
try:
result = await self._request(
"GET",
f"/repos/{owner}/{repo}/pulls/{pr_number}",
)
if result is None:
raise PRNotFoundError(pr_number, f"{owner}/{repo}")
pr_info = self._parse_pr(result)
return GetPRResult(
success=True,
pr=pr_info.to_dict(),
)
except PRNotFoundError:
return GetPRResult(
success=False,
error=f"Pull request #{pr_number} not found",
)
except APIError as e:
return GetPRResult(
success=False,
error=str(e),
)
async def list_prs(
self,
owner: str,
repo: str,
state: PRState | None = None,
author: str | None = None,
limit: int = 20,
) -> ListPRsResult:
"""List pull requests."""
try:
params: dict[str, Any] = {
"per_page": min(limit, 100), # GitHub max is 100
}
if state:
# GitHub uses 'state' for open/closed only
# Merged PRs are closed PRs with merged_at set
if state == PRState.OPEN:
params["state"] = "open"
elif state in (PRState.CLOSED, PRState.MERGED):
params["state"] = "closed"
else:
params["state"] = "all"
result = await self._request(
"GET",
f"/repos/{owner}/{repo}/pulls",
params=params,
)
if result is None:
return ListPRsResult(
success=True,
pull_requests=[],
total_count=0,
)
prs = []
for pr_data in result:
# Filter by author if specified
if author:
pr_author = pr_data.get("user", {}).get("login", "")
if pr_author.lower() != author.lower():
continue
# Filter merged PRs if looking specifically for merged
if state == PRState.MERGED:
if not pr_data.get("merged_at"):
continue
pr_info = self._parse_pr(pr_data)
prs.append(pr_info.to_dict())
return ListPRsResult(
success=True,
pull_requests=prs,
total_count=len(prs),
)
except APIError as e:
return ListPRsResult(
success=False,
error=str(e),
)
async def merge_pr(
self,
owner: str,
repo: str,
pr_number: int,
merge_strategy: MergeStrategy = MergeStrategy.MERGE,
commit_message: str | None = None,
delete_branch: bool = True,
) -> MergePRResult:
"""Merge a pull request."""
try:
# Map merge strategy to GitHub's merge_method values
method_map = {
MergeStrategy.MERGE: "merge",
MergeStrategy.SQUASH: "squash",
MergeStrategy.REBASE: "rebase",
}
data: dict[str, Any] = {
"merge_method": method_map[merge_strategy],
}
if commit_message:
# For squash, commit_title and commit_message
# For merge, commit_title and commit_message
parts = commit_message.split("\n", 1)
data["commit_title"] = parts[0]
if len(parts) > 1:
data["commit_message"] = parts[1]
result = await self._request(
"PUT",
f"/repos/{owner}/{repo}/pulls/{pr_number}/merge",
json=data,
)
if result is None:
return MergePRResult(
success=False,
error="Failed to merge pull request",
)
branch_deleted = False
# Delete branch if requested
if delete_branch and result.get("merged"):
# Get PR to find the branch name
pr_result = await self.get_pr(owner, repo, pr_number)
if pr_result.success and pr_result.pr:
source_branch = pr_result.pr.get("source_branch")
if source_branch:
branch_deleted = await self.delete_remote_branch(
owner, repo, source_branch
)
return MergePRResult(
success=True,
merge_commit_sha=result.get("sha"),
branch_deleted=branch_deleted,
)
except APIError as e:
return MergePRResult(
success=False,
error=str(e),
)
async def update_pr(
self,
owner: str,
repo: str,
pr_number: int,
title: str | None = None,
body: str | None = None,
state: PRState | None = None,
labels: list[str] | None = None,
assignees: list[str] | None = None,
) -> UpdatePRResult:
"""Update a pull request."""
try:
data: dict[str, Any] = {}
if title is not None:
data["title"] = title
if body is not None:
data["body"] = body
if state is not None:
if state == PRState.OPEN:
data["state"] = "open"
elif state == PRState.CLOSED:
data["state"] = "closed"
# Update PR if there's data
if data:
await self._request(
"PATCH",
f"/repos/{owner}/{repo}/pulls/{pr_number}",
json=data,
)
# Update labels via issue endpoint
if labels is not None:
await self._request(
"PUT",
f"/repos/{owner}/{repo}/issues/{pr_number}/labels",
json={"labels": labels},
)
# Update assignees via issue endpoint
if assignees is not None:
# First remove all assignees
await self._request(
"DELETE",
f"/repos/{owner}/{repo}/issues/{pr_number}/assignees",
json={"assignees": []},
)
# Then add new ones
if assignees:
await self._request(
"POST",
f"/repos/{owner}/{repo}/issues/{pr_number}/assignees",
json={"assignees": assignees},
)
# Fetch updated PR
result = await self.get_pr(owner, repo, pr_number)
return UpdatePRResult(
success=result.success,
pr=result.pr,
error=result.error,
)
except APIError as e:
return UpdatePRResult(
success=False,
error=str(e),
)
async def close_pr(
self,
owner: str,
repo: str,
pr_number: int,
) -> UpdatePRResult:
"""Close a pull request without merging."""
return await self.update_pr(
owner,
repo,
pr_number,
state=PRState.CLOSED,
)
# Branch operations
async def delete_remote_branch(
self,
owner: str,
repo: str,
branch: str,
) -> bool:
"""Delete a remote branch."""
try:
await self._request(
"DELETE",
f"/repos/{owner}/{repo}/git/refs/heads/{branch}",
)
return True
except APIError:
return False
async def get_branch(
self,
owner: str,
repo: str,
branch: str,
) -> dict[str, Any] | None:
"""Get branch information."""
return await self._request(
"GET",
f"/repos/{owner}/{repo}/branches/{branch}",
)
# Comment operations
async def add_pr_comment(
self,
owner: str,
repo: str,
pr_number: int,
body: str,
) -> dict[str, Any]:
"""Add a comment to a pull request."""
result = await self._request(
"POST",
f"/repos/{owner}/{repo}/issues/{pr_number}/comments",
json={"body": body},
)
return result or {}
async def list_pr_comments(
self,
owner: str,
repo: str,
pr_number: int,
) -> list[dict[str, Any]]:
"""List comments on a pull request."""
result = await self._request(
"GET",
f"/repos/{owner}/{repo}/issues/{pr_number}/comments",
)
return result or []
# Label operations
async def add_labels(
self,
owner: str,
repo: str,
pr_number: int,
labels: list[str],
) -> list[str]:
"""Add labels to a pull request."""
# GitHub creates labels automatically if they don't exist (unlike Gitea)
result = await self._request(
"POST",
f"/repos/{owner}/{repo}/issues/{pr_number}/labels",
json={"labels": labels},
)
if result:
return [lbl["name"] for lbl in result]
return labels
async def remove_label(
self,
owner: str,
repo: str,
pr_number: int,
label: str,
) -> list[str]:
"""Remove a label from a pull request."""
await self._request(
"DELETE",
f"/repos/{owner}/{repo}/issues/{pr_number}/labels/{label}",
)
# Return remaining labels
issue = await self._request(
"GET",
f"/repos/{owner}/{repo}/issues/{pr_number}",
)
if issue:
return [lbl["name"] for lbl in issue.get("labels", [])]
return []
# Reviewer operations
async def request_review(
self,
owner: str,
repo: str,
pr_number: int,
reviewers: list[str],
) -> list[str]:
"""Request review from users."""
await self._request(
"POST",
f"/repos/{owner}/{repo}/pulls/{pr_number}/requested_reviewers",
json={"reviewers": reviewers},
)
return reviewers
# Helper methods
def _parse_pr(self, data: dict[str, Any]) -> PRInfo:
"""Parse PR API response into PRInfo."""
# Parse dates
created_at = self._parse_datetime(data.get("created_at"))
updated_at = self._parse_datetime(data.get("updated_at"))
merged_at = self._parse_datetime(data.get("merged_at"))
closed_at = self._parse_datetime(data.get("closed_at"))
# Determine state
if data.get("merged_at"):
state = PRState.MERGED
elif data.get("state") == "closed":
state = PRState.CLOSED
else:
state = PRState.OPEN
# Extract labels
labels = [lbl["name"] for lbl in data.get("labels", [])]
# Extract assignees
assignees = [a["login"] for a in data.get("assignees", [])]
# Extract reviewers
reviewers = []
if "requested_reviewers" in data:
reviewers = [r["login"] for r in data["requested_reviewers"]]
return PRInfo(
number=data["number"],
title=data["title"],
body=data.get("body", "") or "",
state=state,
source_branch=data.get("head", {}).get("ref", ""),
target_branch=data.get("base", {}).get("ref", ""),
author=data.get("user", {}).get("login", ""),
created_at=created_at,
updated_at=updated_at,
merged_at=merged_at,
closed_at=closed_at,
url=data.get("html_url"),
labels=labels,
assignees=assignees,
reviewers=reviewers,
mergeable=data.get("mergeable"),
draft=data.get("draft", False),
)
def _parse_datetime(self, value: str | None) -> datetime:
"""Parse datetime string from API."""
if not value:
return datetime.now(UTC)
try:
# GitHub uses ISO 8601 format with Z suffix
if value.endswith("Z"):
value = value[:-1] + "+00:00"
return datetime.fromisoformat(value)
except ValueError:
return datetime.now(UTC)

View File

@@ -0,0 +1,120 @@
[project]
name = "syndarix-mcp-git-ops"
version = "0.1.0"
description = "Syndarix Git Operations MCP Server - Repository management, branching, commits, and PR workflows"
requires-python = ">=3.12"
dependencies = [
"fastmcp>=2.0.0",
"gitpython>=3.1.0",
"httpx>=0.27.0",
"redis>=5.0.0",
"pydantic>=2.0.0",
"pydantic-settings>=2.0.0",
"uvicorn>=0.30.0",
"fastapi>=0.115.0",
"filelock>=3.15.0",
"aiofiles>=24.1.0",
]
[project.optional-dependencies]
dev = [
"pytest>=8.0.0",
"pytest-asyncio>=0.24.0",
"pytest-cov>=5.0.0",
"fakeredis>=2.25.0",
"ruff>=0.8.0",
"mypy>=1.11.0",
"respx>=0.21.0",
]
[project.scripts]
git-ops = "server:main"
[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"
[tool.hatch.build.targets.wheel]
packages = ["."]
exclude = ["tests/", "*.md", "Dockerfile"]
[tool.hatch.build.targets.sdist]
include = ["*.py", "pyproject.toml"]
[tool.ruff]
target-version = "py312"
line-length = 88
[tool.ruff.lint]
select = [
"E", # pycodestyle errors
"W", # pycodestyle warnings
"F", # pyflakes
"I", # isort
"B", # flake8-bugbear
"C4", # flake8-comprehensions
"UP", # pyupgrade
"ARG", # flake8-unused-arguments
"SIM", # flake8-simplify
"S", # flake8-bandit (security)
]
ignore = [
"E501", # line too long (handled by formatter)
"B008", # do not perform function calls in argument defaults
"B904", # raise from in except (too noisy)
"S104", # possible binding to all interfaces
"S110", # try-except-pass (intentional for optional operations)
"S603", # subprocess without shell=True (safe usage in git wrapper)
"S607", # starting a process with a partial path (git CLI)
"ARG002", # unused method arguments (for API compatibility)
"SIM102", # nested if statements (sometimes more readable)
"SIM105", # contextlib.suppress (sometimes more readable)
"SIM108", # ternary operator (sometimes more readable)
"SIM118", # dict.keys() (explicit is fine)
]
[tool.ruff.lint.isort]
known-first-party = ["config", "models", "exceptions", "git_wrapper", "workspace", "providers"]
[tool.ruff.lint.per-file-ignores]
"tests/**/*.py" = ["S101", "ARG001", "S105", "S106", "S108", "F841", "B007"]
[tool.pytest.ini_options]
asyncio_mode = "auto"
asyncio_default_fixture_loop_scope = "function"
testpaths = ["tests"]
addopts = "-v --tb=short"
filterwarnings = [
"ignore::DeprecationWarning",
]
[tool.coverage.run]
source = ["."]
omit = ["tests/*", "conftest.py"]
branch = true
[tool.coverage.report]
exclude_lines = [
"pragma: no cover",
"def __repr__",
"raise NotImplementedError",
"if TYPE_CHECKING:",
"if __name__ == .__main__.:",
]
fail_under = 78
show_missing = true
[tool.mypy]
python_version = "3.12"
warn_return_any = false
warn_unused_ignores = false
disallow_untyped_defs = true
ignore_missing_imports = true
plugins = ["pydantic.mypy"]
files = ["server.py", "config.py", "models.py", "exceptions.py", "git_wrapper.py", "workspace.py", "providers/"]
exclude = ["tests/"]
[[tool.mypy.overrides]]
module = "tests.*"
disallow_untyped_defs = false
ignore_errors = true

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1 @@
"""Tests for Git Operations MCP Server."""

View File

@@ -0,0 +1,299 @@
"""
Test configuration and fixtures for Git Operations MCP Server.
"""
import os
import shutil
import tempfile
from collections.abc import AsyncIterator, Iterator
from pathlib import Path
from unittest.mock import AsyncMock, MagicMock
import pytest
from git import Repo as GitRepo
# Set test environment
os.environ["IS_TEST"] = "true"
os.environ["GIT_OPS_WORKSPACE_BASE_PATH"] = "/tmp/test-workspaces"
os.environ["GIT_OPS_GITEA_BASE_URL"] = "https://gitea.test.com"
os.environ["GIT_OPS_GITEA_TOKEN"] = "test-token"
@pytest.fixture(scope="session", autouse=True)
def reset_settings_session():
"""Reset settings at start and end of test session."""
from config import reset_settings
reset_settings()
yield
reset_settings()
@pytest.fixture
def reset_settings():
"""Reset settings before each test that needs it."""
from config import reset_settings
reset_settings()
yield
reset_settings()
@pytest.fixture
def test_settings():
"""Get test settings."""
from config import Settings
return Settings(
workspace_base_path=Path("/tmp/test-workspaces"),
gitea_base_url="https://gitea.test.com",
gitea_token="test-token",
github_token="github-test-token",
git_author_name="Test Agent",
git_author_email="test@syndarix.ai",
enable_force_push=False,
debug=True,
)
@pytest.fixture
def temp_dir() -> Iterator[Path]:
"""Create a temporary directory for tests."""
temp_path = Path(tempfile.mkdtemp())
yield temp_path
if temp_path.exists():
shutil.rmtree(temp_path)
@pytest.fixture
def temp_workspace(temp_dir: Path) -> Path:
"""Create a temporary workspace directory."""
workspace = temp_dir / "workspace"
workspace.mkdir(parents=True, exist_ok=True)
return workspace
@pytest.fixture
def git_repo(temp_workspace: Path) -> GitRepo:
"""Create a git repository in the temp workspace."""
# Initialize with main branch (Git 2.28+)
repo = GitRepo.init(temp_workspace, initial_branch="main")
# Configure git
with repo.config_writer() as cw:
cw.set_value("user", "name", "Test User")
cw.set_value("user", "email", "test@example.com")
# Create initial commit
test_file = temp_workspace / "README.md"
test_file.write_text("# Test Repository\n")
repo.index.add(["README.md"])
repo.index.commit("Initial commit")
return repo
@pytest.fixture
def git_repo_with_remote(git_repo: GitRepo, temp_dir: Path) -> tuple[GitRepo, GitRepo]:
"""Create a git repository with a 'remote' (bare repo)."""
# Create bare repo as remote
remote_path = temp_dir / "remote.git"
remote_repo = GitRepo.init(remote_path, bare=True)
# Add remote to main repo
git_repo.create_remote("origin", str(remote_path))
# Push initial commit
git_repo.remotes.origin.push("main:main")
# Set up tracking
git_repo.heads.main.set_tracking_branch(git_repo.remotes.origin.refs.main)
return git_repo, remote_repo
@pytest.fixture
def workspace_manager(temp_dir: Path, test_settings):
"""Create a WorkspaceManager with test settings."""
from workspace import WorkspaceManager
test_settings.workspace_base_path = temp_dir / "workspaces"
return WorkspaceManager(test_settings)
@pytest.fixture
def git_wrapper(temp_workspace: Path, test_settings):
"""Create a GitWrapper for the temp workspace."""
from git_wrapper import GitWrapper
return GitWrapper(temp_workspace, test_settings)
@pytest.fixture
def git_wrapper_with_repo(git_repo: GitRepo, test_settings):
"""Create a GitWrapper for a repo that's already initialized."""
from git_wrapper import GitWrapper
return GitWrapper(Path(git_repo.working_dir), test_settings)
@pytest.fixture
def mock_gitea_provider():
"""Create a mock Gitea provider."""
provider = AsyncMock()
provider.name = "gitea"
provider.is_connected = AsyncMock(return_value=True)
provider.get_authenticated_user = AsyncMock(return_value="test-user")
provider.parse_repo_url = MagicMock(return_value=("owner", "repo"))
return provider
@pytest.fixture
def mock_httpx_client():
"""Create a mock httpx client for provider tests."""
from unittest.mock import AsyncMock
mock_response = AsyncMock()
mock_response.status_code = 200
mock_response.json = MagicMock(return_value={})
mock_response.text = ""
mock_client = AsyncMock()
mock_client.request = AsyncMock(return_value=mock_response)
mock_client.get = AsyncMock(return_value=mock_response)
mock_client.post = AsyncMock(return_value=mock_response)
mock_client.patch = AsyncMock(return_value=mock_response)
mock_client.delete = AsyncMock(return_value=mock_response)
return mock_client
@pytest.fixture
async def gitea_provider(test_settings, mock_httpx_client):
"""Create a GiteaProvider with mocked HTTP client."""
from providers.gitea import GiteaProvider
provider = GiteaProvider(
base_url=test_settings.gitea_base_url,
token=test_settings.gitea_token,
settings=test_settings,
)
provider._client = mock_httpx_client
yield provider
await provider.close()
@pytest.fixture
def sample_pr_data():
"""Sample PR data from Gitea API."""
return {
"number": 42,
"title": "Test PR",
"body": "This is a test pull request",
"state": "open",
"head": {"ref": "feature-branch"},
"base": {"ref": "main"},
"user": {"login": "test-user"},
"created_at": "2024-01-15T10:00:00Z",
"updated_at": "2024-01-15T12:00:00Z",
"merged_at": None,
"closed_at": None,
"html_url": "https://gitea.test.com/owner/repo/pull/42",
"labels": [{"name": "enhancement"}],
"assignees": [{"login": "assignee1"}],
"requested_reviewers": [{"login": "reviewer1"}],
"mergeable": True,
"draft": False,
}
@pytest.fixture
def sample_commit_data():
"""Sample commit data."""
return {
"sha": "abc123def456",
"short_sha": "abc123d",
"message": "Test commit message",
"author": {
"name": "Test Author",
"email": "author@test.com",
"date": "2024-01-15T10:00:00Z",
},
"committer": {
"name": "Test Committer",
"email": "committer@test.com",
"date": "2024-01-15T10:00:00Z",
},
}
@pytest.fixture
def mock_fastapi_app():
"""Create a test FastAPI app."""
from fastapi import FastAPI
from fastapi.testclient import TestClient
app = FastAPI()
@app.get("/health")
def health():
return {"status": "healthy"}
return TestClient(app)
# Async fixtures
@pytest.fixture
async def async_workspace_manager(
temp_dir: Path, test_settings
) -> AsyncIterator:
"""Async fixture for workspace manager."""
from workspace import WorkspaceManager
test_settings.workspace_base_path = temp_dir / "workspaces"
manager = WorkspaceManager(test_settings)
yield manager
# Test data fixtures
@pytest.fixture
def valid_project_id() -> str:
"""Valid project ID for tests."""
return "test-project-123"
@pytest.fixture
def valid_agent_id() -> str:
"""Valid agent ID for tests."""
return "agent-456"
@pytest.fixture
def invalid_ids() -> list[str]:
"""Invalid IDs for validation tests."""
return [
"",
" ",
"a" * 200, # Too long
"test@invalid", # Invalid character
"test!invalid",
"../path/traversal",
]
@pytest.fixture
def sample_repo_url() -> str:
"""Sample repository URL."""
return "https://gitea.test.com/owner/repo.git"
@pytest.fixture
def sample_ssh_repo_url() -> str:
"""Sample SSH repository URL."""
return "git@gitea.test.com:owner/repo.git"

View File

@@ -0,0 +1,434 @@
"""
Tests for the git_wrapper module.
"""
from pathlib import Path
import pytest
from exceptions import (
BranchExistsError,
BranchNotFoundError,
CheckoutError,
CommitError,
GitError,
)
from git_wrapper import GitWrapper
from models import FileChangeType
class TestGitWrapperInit:
"""Tests for GitWrapper initialization."""
def test_init_with_valid_path(self, temp_workspace, test_settings):
"""Test initialization with a valid path."""
wrapper = GitWrapper(temp_workspace, test_settings)
assert wrapper.workspace_path == temp_workspace
assert wrapper.settings == test_settings
def test_repo_property_raises_on_non_git(self, temp_workspace, test_settings):
"""Test that accessing repo on non-git dir raises error."""
wrapper = GitWrapper(temp_workspace, test_settings)
with pytest.raises(GitError, match="Not a git repository"):
_ = wrapper.repo
def test_repo_property_works_on_git_dir(self, git_repo, test_settings):
"""Test that repo property works for git directory."""
wrapper = GitWrapper(Path(git_repo.working_dir), test_settings)
assert wrapper.repo is not None
assert wrapper.repo.head is not None
class TestGitWrapperStatus:
"""Tests for git status operations."""
@pytest.mark.asyncio
async def test_status_clean_repo(self, git_wrapper_with_repo):
"""Test status on a clean repository."""
result = await git_wrapper_with_repo.status()
assert result.branch == "main"
assert result.is_clean is True
assert len(result.staged) == 0
assert len(result.unstaged) == 0
assert len(result.untracked) == 0
@pytest.mark.asyncio
async def test_status_with_untracked(self, git_wrapper_with_repo, git_repo):
"""Test status with untracked files."""
# Create untracked file
untracked_file = Path(git_repo.working_dir) / "untracked.txt"
untracked_file.write_text("untracked content")
result = await git_wrapper_with_repo.status()
assert result.is_clean is False
assert "untracked.txt" in result.untracked
@pytest.mark.asyncio
async def test_status_with_modified(self, git_wrapper_with_repo, git_repo):
"""Test status with modified files."""
# Modify existing file
readme = Path(git_repo.working_dir) / "README.md"
readme.write_text("# Modified content\n")
result = await git_wrapper_with_repo.status()
assert result.is_clean is False
assert len(result.unstaged) > 0
@pytest.mark.asyncio
async def test_status_with_staged(self, git_wrapper_with_repo, git_repo):
"""Test status with staged changes."""
# Create and stage a file
new_file = Path(git_repo.working_dir) / "staged.txt"
new_file.write_text("staged content")
git_repo.index.add(["staged.txt"])
result = await git_wrapper_with_repo.status()
assert result.is_clean is False
assert len(result.staged) > 0
@pytest.mark.asyncio
async def test_status_exclude_untracked(self, git_wrapper_with_repo, git_repo):
"""Test status without untracked files."""
untracked_file = Path(git_repo.working_dir) / "untracked.txt"
untracked_file.write_text("untracked")
result = await git_wrapper_with_repo.status(include_untracked=False)
assert len(result.untracked) == 0
class TestGitWrapperBranch:
"""Tests for branch operations."""
@pytest.mark.asyncio
async def test_create_branch(self, git_wrapper_with_repo):
"""Test creating a new branch."""
result = await git_wrapper_with_repo.create_branch("feature-test")
assert result.success is True
assert result.branch == "feature-test"
assert result.is_current is True
@pytest.mark.asyncio
async def test_create_branch_without_checkout(self, git_wrapper_with_repo):
"""Test creating branch without checkout."""
result = await git_wrapper_with_repo.create_branch("feature-no-checkout", checkout=False)
assert result.success is True
assert result.branch == "feature-no-checkout"
assert result.is_current is False
@pytest.mark.asyncio
async def test_create_branch_exists_error(self, git_wrapper_with_repo):
"""Test error when branch already exists."""
await git_wrapper_with_repo.create_branch("existing-branch", checkout=False)
with pytest.raises(BranchExistsError):
await git_wrapper_with_repo.create_branch("existing-branch")
@pytest.mark.asyncio
async def test_delete_branch(self, git_wrapper_with_repo):
"""Test deleting a branch."""
# Create branch first
await git_wrapper_with_repo.create_branch("to-delete", checkout=False)
# Delete it
result = await git_wrapper_with_repo.delete_branch("to-delete")
assert result.success is True
assert result.branch == "to-delete"
@pytest.mark.asyncio
async def test_delete_branch_not_found(self, git_wrapper_with_repo):
"""Test error when deleting non-existent branch."""
with pytest.raises(BranchNotFoundError):
await git_wrapper_with_repo.delete_branch("nonexistent")
@pytest.mark.asyncio
async def test_delete_current_branch_error(self, git_wrapper_with_repo):
"""Test error when deleting current branch."""
with pytest.raises(GitError, match="Cannot delete current branch"):
await git_wrapper_with_repo.delete_branch("main")
@pytest.mark.asyncio
async def test_list_branches(self, git_wrapper_with_repo):
"""Test listing branches."""
# Create some branches
await git_wrapper_with_repo.create_branch("branch-a", checkout=False)
await git_wrapper_with_repo.create_branch("branch-b", checkout=False)
result = await git_wrapper_with_repo.list_branches()
assert result.current_branch == "main"
branch_names = [b["name"] for b in result.local_branches]
assert "main" in branch_names
assert "branch-a" in branch_names
assert "branch-b" in branch_names
class TestGitWrapperCheckout:
"""Tests for checkout operations."""
@pytest.mark.asyncio
async def test_checkout_existing_branch(self, git_wrapper_with_repo):
"""Test checkout of existing branch."""
# Create branch first
await git_wrapper_with_repo.create_branch("test-branch", checkout=False)
result = await git_wrapper_with_repo.checkout("test-branch")
assert result.success is True
assert result.ref == "test-branch"
@pytest.mark.asyncio
async def test_checkout_create_new(self, git_wrapper_with_repo):
"""Test checkout with branch creation."""
result = await git_wrapper_with_repo.checkout("new-branch", create_branch=True)
assert result.success is True
assert result.ref == "new-branch"
@pytest.mark.asyncio
async def test_checkout_nonexistent_error(self, git_wrapper_with_repo):
"""Test error when checking out non-existent ref."""
with pytest.raises(CheckoutError):
await git_wrapper_with_repo.checkout("nonexistent-branch")
class TestGitWrapperCommit:
"""Tests for commit operations."""
@pytest.mark.asyncio
async def test_commit_staged_changes(self, git_wrapper_with_repo, git_repo):
"""Test committing staged changes."""
# Create and stage a file
new_file = Path(git_repo.working_dir) / "newfile.txt"
new_file.write_text("new content")
git_repo.index.add(["newfile.txt"])
result = await git_wrapper_with_repo.commit("Add new file")
assert result.success is True
assert result.message == "Add new file"
assert result.files_changed == 1
@pytest.mark.asyncio
async def test_commit_all_changes(self, git_wrapper_with_repo, git_repo):
"""Test committing all changes (auto-stage)."""
# Create a file without staging
new_file = Path(git_repo.working_dir) / "unstaged.txt"
new_file.write_text("content")
result = await git_wrapper_with_repo.commit("Commit unstaged")
assert result.success is True
@pytest.mark.asyncio
async def test_commit_nothing_to_commit(self, git_wrapper_with_repo):
"""Test error when nothing to commit."""
with pytest.raises(CommitError, match="Nothing to commit"):
await git_wrapper_with_repo.commit("Empty commit")
@pytest.mark.asyncio
async def test_commit_with_author(self, git_wrapper_with_repo, git_repo):
"""Test commit with custom author."""
new_file = Path(git_repo.working_dir) / "authored.txt"
new_file.write_text("authored content")
result = await git_wrapper_with_repo.commit(
"Custom author commit",
author_name="Custom Author",
author_email="custom@test.com",
)
assert result.success is True
class TestGitWrapperDiff:
"""Tests for diff operations."""
@pytest.mark.asyncio
async def test_diff_no_changes(self, git_wrapper_with_repo):
"""Test diff with no changes."""
result = await git_wrapper_with_repo.diff()
assert result.files_changed == 0
assert result.total_additions == 0
assert result.total_deletions == 0
@pytest.mark.asyncio
async def test_diff_with_changes(self, git_wrapper_with_repo, git_repo):
"""Test diff with modified files."""
# Modify a file
readme = Path(git_repo.working_dir) / "README.md"
readme.write_text("# Modified\nNew line\n")
result = await git_wrapper_with_repo.diff()
assert result.files_changed > 0
class TestGitWrapperLog:
"""Tests for log operations."""
@pytest.mark.asyncio
async def test_log_basic(self, git_wrapper_with_repo):
"""Test basic log."""
result = await git_wrapper_with_repo.log()
assert result.total_commits > 0
assert len(result.commits) > 0
@pytest.mark.asyncio
async def test_log_with_limit(self, git_wrapper_with_repo, git_repo):
"""Test log with limit."""
# Create more commits
for i in range(5):
file_path = Path(git_repo.working_dir) / f"file{i}.txt"
file_path.write_text(f"content {i}")
git_repo.index.add([f"file{i}.txt"])
git_repo.index.commit(f"Commit {i}")
result = await git_wrapper_with_repo.log(limit=3)
assert len(result.commits) == 3
@pytest.mark.asyncio
async def test_log_commit_info(self, git_wrapper_with_repo):
"""Test that log returns proper commit info."""
result = await git_wrapper_with_repo.log(limit=1)
commit = result.commits[0]
assert "sha" in commit
assert "message" in commit
assert "author_name" in commit
assert "author_email" in commit
class TestGitWrapperUtilities:
"""Tests for utility methods."""
@pytest.mark.asyncio
async def test_is_valid_ref_true(self, git_wrapper_with_repo):
"""Test valid ref detection."""
is_valid = await git_wrapper_with_repo.is_valid_ref("main")
assert is_valid is True
@pytest.mark.asyncio
async def test_is_valid_ref_false(self, git_wrapper_with_repo):
"""Test invalid ref detection."""
is_valid = await git_wrapper_with_repo.is_valid_ref("nonexistent")
assert is_valid is False
def test_diff_to_change_type(self, git_wrapper_with_repo):
"""Test change type conversion."""
wrapper = git_wrapper_with_repo
assert wrapper._diff_to_change_type("A") == FileChangeType.ADDED
assert wrapper._diff_to_change_type("M") == FileChangeType.MODIFIED
assert wrapper._diff_to_change_type("D") == FileChangeType.DELETED
assert wrapper._diff_to_change_type("R") == FileChangeType.RENAMED
class TestGitWrapperStage:
"""Tests for staging operations."""
@pytest.mark.asyncio
async def test_stage_specific_files(self, git_wrapper_with_repo, git_repo):
"""Test staging specific files."""
# Create files
file1 = Path(git_repo.working_dir) / "file1.txt"
file2 = Path(git_repo.working_dir) / "file2.txt"
file1.write_text("content 1")
file2.write_text("content 2")
count = await git_wrapper_with_repo.stage(["file1.txt"])
assert count == 1
@pytest.mark.asyncio
async def test_stage_all(self, git_wrapper_with_repo, git_repo):
"""Test staging all files."""
file1 = Path(git_repo.working_dir) / "all1.txt"
file2 = Path(git_repo.working_dir) / "all2.txt"
file1.write_text("content 1")
file2.write_text("content 2")
count = await git_wrapper_with_repo.stage()
assert count >= 2
@pytest.mark.asyncio
async def test_unstage_files(self, git_wrapper_with_repo, git_repo):
"""Test unstaging files."""
# Create and stage file
file1 = Path(git_repo.working_dir) / "unstage.txt"
file1.write_text("to unstage")
git_repo.index.add(["unstage.txt"])
count = await git_wrapper_with_repo.unstage()
assert count >= 1
class TestGitWrapperReset:
"""Tests for reset operations."""
@pytest.mark.asyncio
async def test_reset_soft(self, git_wrapper_with_repo, git_repo):
"""Test soft reset."""
# Create a commit to reset
file1 = Path(git_repo.working_dir) / "reset_soft.txt"
file1.write_text("content")
git_repo.index.add(["reset_soft.txt"])
git_repo.index.commit("Commit to reset")
result = await git_wrapper_with_repo.reset("HEAD~1", mode="soft")
assert result is True
@pytest.mark.asyncio
async def test_reset_mixed(self, git_wrapper_with_repo, git_repo):
"""Test mixed reset (default)."""
file1 = Path(git_repo.working_dir) / "reset_mixed.txt"
file1.write_text("content")
git_repo.index.add(["reset_mixed.txt"])
git_repo.index.commit("Commit to reset")
result = await git_wrapper_with_repo.reset("HEAD~1", mode="mixed")
assert result is True
@pytest.mark.asyncio
async def test_reset_invalid_mode(self, git_wrapper_with_repo):
"""Test error on invalid reset mode."""
with pytest.raises(GitError, match="Invalid reset mode"):
await git_wrapper_with_repo.reset("HEAD", mode="invalid")
class TestGitWrapperStash:
"""Tests for stash operations."""
@pytest.mark.asyncio
async def test_stash_changes(self, git_wrapper_with_repo, git_repo):
"""Test stashing changes."""
# Make changes
readme = Path(git_repo.working_dir) / "README.md"
readme.write_text("Modified for stash")
result = await git_wrapper_with_repo.stash("Test stash")
# Result should be stash ref or None if nothing to stash
# (depends on whether changes were already staged)
assert result is None or result.startswith("stash@")
@pytest.mark.asyncio
async def test_stash_nothing(self, git_wrapper_with_repo):
"""Test stash with no changes."""
result = await git_wrapper_with_repo.stash()
assert result is None

View File

@@ -0,0 +1,583 @@
"""
Tests for GitHub provider implementation.
"""
from unittest.mock import MagicMock
import pytest
from exceptions import APIError, AuthenticationError
from models import MergeStrategy, PRState
from providers.github import GitHubProvider
class TestGitHubProviderBasics:
"""Tests for GitHubProvider basic operations."""
def test_provider_name(self):
"""Test provider name is github."""
provider = GitHubProvider(token="test-token")
assert provider.name == "github"
def test_parse_repo_url_https(self):
"""Test parsing HTTPS repo URL."""
provider = GitHubProvider(token="test-token")
owner, repo = provider.parse_repo_url("https://github.com/owner/repo.git")
assert owner == "owner"
assert repo == "repo"
def test_parse_repo_url_https_no_git(self):
"""Test parsing HTTPS URL without .git suffix."""
provider = GitHubProvider(token="test-token")
owner, repo = provider.parse_repo_url("https://github.com/owner/repo")
assert owner == "owner"
assert repo == "repo"
def test_parse_repo_url_ssh(self):
"""Test parsing SSH repo URL."""
provider = GitHubProvider(token="test-token")
owner, repo = provider.parse_repo_url("git@github.com:owner/repo.git")
assert owner == "owner"
assert repo == "repo"
def test_parse_repo_url_invalid(self):
"""Test error on invalid URL."""
provider = GitHubProvider(token="test-token")
with pytest.raises(ValueError, match="Unable to parse"):
provider.parse_repo_url("invalid-url")
@pytest.fixture
def mock_github_httpx_client():
"""Create a mock httpx client for GitHub provider tests."""
from unittest.mock import AsyncMock
mock_response = AsyncMock()
mock_response.status_code = 200
mock_response.json = MagicMock(return_value={})
mock_response.text = ""
mock_client = AsyncMock()
mock_client.request = AsyncMock(return_value=mock_response)
mock_client.get = AsyncMock(return_value=mock_response)
mock_client.post = AsyncMock(return_value=mock_response)
mock_client.patch = AsyncMock(return_value=mock_response)
mock_client.put = AsyncMock(return_value=mock_response)
mock_client.delete = AsyncMock(return_value=mock_response)
return mock_client
@pytest.fixture
async def github_provider(test_settings, mock_github_httpx_client):
"""Create a GitHubProvider with mocked HTTP client."""
provider = GitHubProvider(
token=test_settings.github_token,
settings=test_settings,
)
provider._client = mock_github_httpx_client
yield provider
await provider.close()
@pytest.fixture
def github_pr_data():
"""Sample PR data from GitHub API."""
return {
"number": 42,
"title": "Test PR",
"body": "This is a test pull request",
"state": "open",
"head": {"ref": "feature-branch"},
"base": {"ref": "main"},
"user": {"login": "test-user"},
"created_at": "2024-01-15T10:00:00Z",
"updated_at": "2024-01-15T12:00:00Z",
"merged_at": None,
"closed_at": None,
"html_url": "https://github.com/owner/repo/pull/42",
"labels": [{"name": "enhancement"}],
"assignees": [{"login": "assignee1"}],
"requested_reviewers": [{"login": "reviewer1"}],
"mergeable": True,
"draft": False,
}
class TestGitHubProviderConnection:
"""Tests for GitHub provider connection."""
@pytest.mark.asyncio
async def test_is_connected(self, github_provider, mock_github_httpx_client):
"""Test connection check."""
mock_github_httpx_client.request.return_value.json = MagicMock(
return_value={"login": "test-user"}
)
result = await github_provider.is_connected()
assert result is True
@pytest.mark.asyncio
async def test_is_connected_no_token(self, test_settings):
"""Test connection fails without token."""
provider = GitHubProvider(
token="",
settings=test_settings,
)
result = await provider.is_connected()
assert result is False
await provider.close()
@pytest.mark.asyncio
async def test_get_authenticated_user(self, github_provider, mock_github_httpx_client):
"""Test getting authenticated user."""
mock_github_httpx_client.request.return_value.json = MagicMock(
return_value={"login": "test-user"}
)
user = await github_provider.get_authenticated_user()
assert user == "test-user"
class TestGitHubProviderRepoOperations:
"""Tests for GitHub repository operations."""
@pytest.mark.asyncio
async def test_get_repo_info(self, github_provider, mock_github_httpx_client):
"""Test getting repository info."""
mock_github_httpx_client.request.return_value.json = MagicMock(
return_value={
"name": "repo",
"full_name": "owner/repo",
"default_branch": "main",
}
)
result = await github_provider.get_repo_info("owner", "repo")
assert result["name"] == "repo"
assert result["default_branch"] == "main"
@pytest.mark.asyncio
async def test_get_default_branch(self, github_provider, mock_github_httpx_client):
"""Test getting default branch."""
mock_github_httpx_client.request.return_value.json = MagicMock(
return_value={"default_branch": "develop"}
)
branch = await github_provider.get_default_branch("owner", "repo")
assert branch == "develop"
class TestGitHubPROperations:
"""Tests for GitHub PR operations."""
@pytest.mark.asyncio
async def test_create_pr(self, github_provider, mock_github_httpx_client):
"""Test creating a pull request."""
mock_github_httpx_client.request.return_value.json = MagicMock(
return_value={
"number": 42,
"html_url": "https://github.com/owner/repo/pull/42",
}
)
result = await github_provider.create_pr(
owner="owner",
repo="repo",
title="Test PR",
body="Test body",
source_branch="feature",
target_branch="main",
)
assert result.success is True
assert result.pr_number == 42
assert result.pr_url == "https://github.com/owner/repo/pull/42"
@pytest.mark.asyncio
async def test_create_pr_with_draft(self, github_provider, mock_github_httpx_client):
"""Test creating a draft PR."""
mock_github_httpx_client.request.return_value.json = MagicMock(
return_value={
"number": 43,
"html_url": "https://github.com/owner/repo/pull/43",
}
)
result = await github_provider.create_pr(
owner="owner",
repo="repo",
title="Draft PR",
body="Draft body",
source_branch="feature",
target_branch="main",
draft=True,
)
assert result.success is True
assert result.pr_number == 43
@pytest.mark.asyncio
async def test_create_pr_with_options(self, github_provider, mock_github_httpx_client):
"""Test creating PR with labels, assignees, reviewers."""
mock_responses = [
{"number": 44, "html_url": "https://github.com/owner/repo/pull/44"}, # Create PR
[{"name": "enhancement"}], # POST add labels
{}, # POST add assignees
{}, # POST request reviewers
]
mock_github_httpx_client.request.return_value.json = MagicMock(side_effect=mock_responses)
result = await github_provider.create_pr(
owner="owner",
repo="repo",
title="Test PR",
body="Test body",
source_branch="feature",
target_branch="main",
labels=["enhancement"],
assignees=["user1"],
reviewers=["reviewer1"],
)
assert result.success is True
@pytest.mark.asyncio
async def test_get_pr(self, github_provider, mock_github_httpx_client, github_pr_data):
"""Test getting a pull request."""
mock_github_httpx_client.request.return_value.json = MagicMock(
return_value=github_pr_data
)
result = await github_provider.get_pr("owner", "repo", 42)
assert result.success is True
assert result.pr["number"] == 42
assert result.pr["title"] == "Test PR"
@pytest.mark.asyncio
async def test_get_pr_not_found(self, github_provider, mock_github_httpx_client):
"""Test getting non-existent PR."""
mock_github_httpx_client.request.return_value.status_code = 404
mock_github_httpx_client.request.return_value.json = MagicMock(return_value=None)
result = await github_provider.get_pr("owner", "repo", 999)
assert result.success is False
@pytest.mark.asyncio
async def test_list_prs(self, github_provider, mock_github_httpx_client, github_pr_data):
"""Test listing pull requests."""
mock_github_httpx_client.request.return_value.json = MagicMock(
return_value=[github_pr_data, github_pr_data]
)
result = await github_provider.list_prs("owner", "repo")
assert result.success is True
assert len(result.pull_requests) == 2
@pytest.mark.asyncio
async def test_list_prs_with_state_filter(self, github_provider, mock_github_httpx_client, github_pr_data):
"""Test listing PRs with state filter."""
mock_github_httpx_client.request.return_value.json = MagicMock(
return_value=[github_pr_data]
)
result = await github_provider.list_prs(
"owner", "repo", state=PRState.OPEN
)
assert result.success is True
@pytest.mark.asyncio
async def test_merge_pr(self, github_provider, mock_github_httpx_client, github_pr_data):
"""Test merging a pull request."""
# Merge returns sha, then get_pr returns the PR data, then delete branch
mock_responses = [
{"sha": "merge-commit-sha", "merged": True}, # PUT merge
github_pr_data, # GET PR for branch info
None, # DELETE branch
]
mock_github_httpx_client.request.return_value.json = MagicMock(
side_effect=mock_responses
)
result = await github_provider.merge_pr(
"owner", "repo", 42,
merge_strategy=MergeStrategy.SQUASH,
)
assert result.success is True
assert result.merge_commit_sha == "merge-commit-sha"
@pytest.mark.asyncio
async def test_merge_pr_rebase(self, github_provider, mock_github_httpx_client, github_pr_data):
"""Test merging with rebase strategy."""
mock_responses = [
{"sha": "rebase-commit-sha", "merged": True}, # PUT merge
github_pr_data, # GET PR for branch info
None, # DELETE branch
]
mock_github_httpx_client.request.return_value.json = MagicMock(
side_effect=mock_responses
)
result = await github_provider.merge_pr(
"owner", "repo", 42,
merge_strategy=MergeStrategy.REBASE,
)
assert result.success is True
@pytest.mark.asyncio
async def test_update_pr(self, github_provider, mock_github_httpx_client, github_pr_data):
"""Test updating a pull request."""
mock_github_httpx_client.request.return_value.json = MagicMock(
return_value=github_pr_data
)
result = await github_provider.update_pr(
"owner", "repo", 42,
title="Updated Title",
body="Updated body",
)
assert result.success is True
@pytest.mark.asyncio
async def test_close_pr(self, github_provider, mock_github_httpx_client, github_pr_data):
"""Test closing a pull request."""
github_pr_data["state"] = "closed"
mock_github_httpx_client.request.return_value.json = MagicMock(
return_value=github_pr_data
)
result = await github_provider.close_pr("owner", "repo", 42)
assert result.success is True
class TestGitHubBranchOperations:
"""Tests for GitHub branch operations."""
@pytest.mark.asyncio
async def test_get_branch(self, github_provider, mock_github_httpx_client):
"""Test getting branch info."""
mock_github_httpx_client.request.return_value.json = MagicMock(
return_value={
"name": "main",
"commit": {"sha": "abc123"},
}
)
result = await github_provider.get_branch("owner", "repo", "main")
assert result["name"] == "main"
@pytest.mark.asyncio
async def test_delete_remote_branch(self, github_provider, mock_github_httpx_client):
"""Test deleting a remote branch."""
mock_github_httpx_client.request.return_value.status_code = 204
result = await github_provider.delete_remote_branch("owner", "repo", "old-branch")
assert result is True
class TestGitHubCommentOperations:
"""Tests for GitHub comment operations."""
@pytest.mark.asyncio
async def test_add_pr_comment(self, github_provider, mock_github_httpx_client):
"""Test adding a comment to a PR."""
mock_github_httpx_client.request.return_value.json = MagicMock(
return_value={"id": 1, "body": "Test comment"}
)
result = await github_provider.add_pr_comment(
"owner", "repo", 42, "Test comment"
)
assert result["body"] == "Test comment"
@pytest.mark.asyncio
async def test_list_pr_comments(self, github_provider, mock_github_httpx_client):
"""Test listing PR comments."""
mock_github_httpx_client.request.return_value.json = MagicMock(
return_value=[
{"id": 1, "body": "Comment 1"},
{"id": 2, "body": "Comment 2"},
]
)
result = await github_provider.list_pr_comments("owner", "repo", 42)
assert len(result) == 2
class TestGitHubLabelOperations:
"""Tests for GitHub label operations."""
@pytest.mark.asyncio
async def test_add_labels(self, github_provider, mock_github_httpx_client):
"""Test adding labels to a PR."""
mock_github_httpx_client.request.return_value.json = MagicMock(
return_value=[{"name": "bug"}, {"name": "urgent"}]
)
result = await github_provider.add_labels(
"owner", "repo", 42, ["bug", "urgent"]
)
assert "bug" in result
assert "urgent" in result
@pytest.mark.asyncio
async def test_remove_label(self, github_provider, mock_github_httpx_client):
"""Test removing a label from a PR."""
mock_responses = [
None, # DELETE label
{"labels": []}, # GET issue
]
mock_github_httpx_client.request.return_value.json = MagicMock(side_effect=mock_responses)
result = await github_provider.remove_label(
"owner", "repo", 42, "bug"
)
assert isinstance(result, list)
class TestGitHubReviewerOperations:
"""Tests for GitHub reviewer operations."""
@pytest.mark.asyncio
async def test_request_review(self, github_provider, mock_github_httpx_client):
"""Test requesting review from users."""
mock_github_httpx_client.request.return_value.json = MagicMock(return_value={})
result = await github_provider.request_review(
"owner", "repo", 42, ["reviewer1", "reviewer2"]
)
assert result == ["reviewer1", "reviewer2"]
class TestGitHubErrorHandling:
"""Tests for error handling in GitHub provider."""
@pytest.mark.asyncio
async def test_authentication_error(self, github_provider, mock_github_httpx_client):
"""Test handling authentication errors."""
mock_github_httpx_client.request.return_value.status_code = 401
with pytest.raises(AuthenticationError):
await github_provider._request("GET", "/user")
@pytest.mark.asyncio
async def test_permission_denied(self, github_provider, mock_github_httpx_client):
"""Test handling permission denied errors."""
mock_github_httpx_client.request.return_value.status_code = 403
mock_github_httpx_client.request.return_value.text = "Permission denied"
with pytest.raises(AuthenticationError, match="Insufficient permissions"):
await github_provider._request("GET", "/protected")
@pytest.mark.asyncio
async def test_rate_limit_error(self, github_provider, mock_github_httpx_client):
"""Test handling rate limit errors."""
mock_github_httpx_client.request.return_value.status_code = 403
mock_github_httpx_client.request.return_value.text = "API rate limit exceeded"
with pytest.raises(APIError, match="rate limit"):
await github_provider._request("GET", "/user")
@pytest.mark.asyncio
async def test_api_error(self, github_provider, mock_github_httpx_client):
"""Test handling general API errors."""
mock_github_httpx_client.request.return_value.status_code = 500
mock_github_httpx_client.request.return_value.text = "Internal Server Error"
mock_github_httpx_client.request.return_value.json = MagicMock(
return_value={"message": "Server error"}
)
with pytest.raises(APIError):
await github_provider._request("GET", "/error")
class TestGitHubPRParsing:
"""Tests for PR data parsing."""
def test_parse_pr_open(self, github_provider, github_pr_data):
"""Test parsing open PR."""
pr_info = github_provider._parse_pr(github_pr_data)
assert pr_info.number == 42
assert pr_info.state == PRState.OPEN
assert pr_info.title == "Test PR"
assert pr_info.source_branch == "feature-branch"
assert pr_info.target_branch == "main"
def test_parse_pr_merged(self, github_provider, github_pr_data):
"""Test parsing merged PR."""
github_pr_data["merged_at"] = "2024-01-16T10:00:00Z"
pr_info = github_provider._parse_pr(github_pr_data)
assert pr_info.state == PRState.MERGED
def test_parse_pr_closed(self, github_provider, github_pr_data):
"""Test parsing closed PR."""
github_pr_data["state"] = "closed"
github_pr_data["closed_at"] = "2024-01-16T10:00:00Z"
pr_info = github_provider._parse_pr(github_pr_data)
assert pr_info.state == PRState.CLOSED
def test_parse_pr_draft(self, github_provider, github_pr_data):
"""Test parsing draft PR."""
github_pr_data["draft"] = True
pr_info = github_provider._parse_pr(github_pr_data)
assert pr_info.draft is True
def test_parse_datetime_iso(self, github_provider):
"""Test parsing ISO datetime strings."""
dt = github_provider._parse_datetime("2024-01-15T10:30:00Z")
assert dt.year == 2024
assert dt.month == 1
assert dt.day == 15
def test_parse_datetime_none(self, github_provider):
"""Test parsing None datetime returns now."""
dt = github_provider._parse_datetime(None)
assert dt is not None
assert dt.tzinfo is not None
def test_parse_pr_with_null_body(self, github_provider, github_pr_data):
"""Test parsing PR with null body."""
github_pr_data["body"] = None
pr_info = github_provider._parse_pr(github_pr_data)
assert pr_info.body == ""

View File

@@ -0,0 +1,484 @@
"""
Tests for git provider implementations.
"""
from unittest.mock import MagicMock
import pytest
from exceptions import APIError, AuthenticationError
from models import MergeStrategy, PRState
from providers.gitea import GiteaProvider
class TestBaseProvider:
"""Tests for BaseProvider interface."""
def test_parse_repo_url_https(self, mock_gitea_provider):
"""Test parsing HTTPS repo URL."""
# The mock needs parse_repo_url to work
provider = GiteaProvider(
base_url="https://gitea.test.com",
token="test-token"
)
owner, repo = provider.parse_repo_url("https://gitea.test.com/owner/repo.git")
assert owner == "owner"
assert repo == "repo"
def test_parse_repo_url_https_no_git(self):
"""Test parsing HTTPS URL without .git suffix."""
provider = GiteaProvider(
base_url="https://gitea.test.com",
token="test-token"
)
owner, repo = provider.parse_repo_url("https://gitea.test.com/owner/repo")
assert owner == "owner"
assert repo == "repo"
def test_parse_repo_url_ssh(self):
"""Test parsing SSH repo URL."""
provider = GiteaProvider(
base_url="https://gitea.test.com",
token="test-token"
)
owner, repo = provider.parse_repo_url("git@gitea.test.com:owner/repo.git")
assert owner == "owner"
assert repo == "repo"
def test_parse_repo_url_invalid(self):
"""Test error on invalid URL."""
provider = GiteaProvider(
base_url="https://gitea.test.com",
token="test-token"
)
with pytest.raises(ValueError, match="Unable to parse"):
provider.parse_repo_url("invalid-url")
class TestGiteaProvider:
"""Tests for GiteaProvider."""
@pytest.mark.asyncio
async def test_is_connected(self, gitea_provider, mock_httpx_client):
"""Test connection check."""
mock_httpx_client.request.return_value.json = MagicMock(
return_value={"login": "test-user"}
)
result = await gitea_provider.is_connected()
assert result is True
@pytest.mark.asyncio
async def test_is_connected_no_token(self, test_settings):
"""Test connection fails without token."""
provider = GiteaProvider(
base_url="https://gitea.test.com",
token="",
settings=test_settings,
)
result = await provider.is_connected()
assert result is False
await provider.close()
@pytest.mark.asyncio
async def test_get_authenticated_user(self, gitea_provider, mock_httpx_client):
"""Test getting authenticated user."""
mock_httpx_client.request.return_value.json = MagicMock(
return_value={"login": "test-user"}
)
user = await gitea_provider.get_authenticated_user()
assert user == "test-user"
@pytest.mark.asyncio
async def test_get_repo_info(self, gitea_provider, mock_httpx_client):
"""Test getting repository info."""
mock_httpx_client.request.return_value.json = MagicMock(
return_value={
"name": "repo",
"full_name": "owner/repo",
"default_branch": "main",
}
)
result = await gitea_provider.get_repo_info("owner", "repo")
assert result["name"] == "repo"
assert result["default_branch"] == "main"
@pytest.mark.asyncio
async def test_get_default_branch(self, gitea_provider, mock_httpx_client):
"""Test getting default branch."""
mock_httpx_client.request.return_value.json = MagicMock(
return_value={"default_branch": "develop"}
)
branch = await gitea_provider.get_default_branch("owner", "repo")
assert branch == "develop"
class TestGiteaPROperations:
"""Tests for Gitea PR operations."""
@pytest.mark.asyncio
async def test_create_pr(self, gitea_provider, mock_httpx_client):
"""Test creating a pull request."""
mock_httpx_client.request.return_value.json = MagicMock(
return_value={
"number": 42,
"html_url": "https://gitea.test.com/owner/repo/pull/42",
}
)
result = await gitea_provider.create_pr(
owner="owner",
repo="repo",
title="Test PR",
body="Test body",
source_branch="feature",
target_branch="main",
)
assert result.success is True
assert result.pr_number == 42
assert result.pr_url == "https://gitea.test.com/owner/repo/pull/42"
@pytest.mark.asyncio
async def test_create_pr_with_options(self, gitea_provider, mock_httpx_client):
"""Test creating PR with labels, assignees, reviewers."""
# Use side_effect for multiple API calls:
# 1. POST create PR
# 2. GET labels (for "enhancement") - in add_labels -> _get_or_create_label
# 3. POST add labels to PR - in add_labels
# 4. GET issue to return labels - in add_labels
# 5. PATCH add assignees
# 6. POST request reviewers
mock_responses = [
{"number": 43, "html_url": "https://gitea.test.com/owner/repo/pull/43"}, # Create PR
[{"id": 1, "name": "enhancement"}], # GET labels (found)
{}, # POST add labels to PR
{"labels": [{"name": "enhancement"}]}, # GET issue to return current labels
{}, # PATCH add assignees
{}, # POST request reviewers
]
mock_httpx_client.request.return_value.json = MagicMock(side_effect=mock_responses)
result = await gitea_provider.create_pr(
owner="owner",
repo="repo",
title="Test PR",
body="Test body",
source_branch="feature",
target_branch="main",
labels=["enhancement"],
assignees=["user1"],
reviewers=["reviewer1"],
)
assert result.success is True
@pytest.mark.asyncio
async def test_get_pr(self, gitea_provider, mock_httpx_client, sample_pr_data):
"""Test getting a pull request."""
mock_httpx_client.request.return_value.json = MagicMock(
return_value=sample_pr_data
)
result = await gitea_provider.get_pr("owner", "repo", 42)
assert result.success is True
assert result.pr["number"] == 42
assert result.pr["title"] == "Test PR"
@pytest.mark.asyncio
async def test_get_pr_not_found(self, gitea_provider, mock_httpx_client):
"""Test getting non-existent PR."""
mock_httpx_client.request.return_value.status_code = 404
mock_httpx_client.request.return_value.json = MagicMock(return_value=None)
result = await gitea_provider.get_pr("owner", "repo", 999)
assert result.success is False
@pytest.mark.asyncio
async def test_list_prs(self, gitea_provider, mock_httpx_client, sample_pr_data):
"""Test listing pull requests."""
mock_httpx_client.request.return_value.json = MagicMock(
return_value=[sample_pr_data, sample_pr_data]
)
result = await gitea_provider.list_prs("owner", "repo")
assert result.success is True
assert len(result.pull_requests) == 2
@pytest.mark.asyncio
async def test_list_prs_with_state_filter(self, gitea_provider, mock_httpx_client, sample_pr_data):
"""Test listing PRs with state filter."""
mock_httpx_client.request.return_value.json = MagicMock(
return_value=[sample_pr_data]
)
result = await gitea_provider.list_prs(
"owner", "repo", state=PRState.OPEN
)
assert result.success is True
@pytest.mark.asyncio
async def test_merge_pr(self, gitea_provider, mock_httpx_client):
"""Test merging a pull request."""
# First call returns merge result
mock_httpx_client.request.return_value.json = MagicMock(
return_value={"sha": "merge-commit-sha"}
)
result = await gitea_provider.merge_pr(
"owner", "repo", 42,
merge_strategy=MergeStrategy.SQUASH,
)
assert result.success is True
assert result.merge_commit_sha == "merge-commit-sha"
@pytest.mark.asyncio
async def test_update_pr(self, gitea_provider, mock_httpx_client, sample_pr_data):
"""Test updating a pull request."""
mock_httpx_client.request.return_value.json = MagicMock(
return_value=sample_pr_data
)
result = await gitea_provider.update_pr(
"owner", "repo", 42,
title="Updated Title",
body="Updated body",
)
assert result.success is True
@pytest.mark.asyncio
async def test_close_pr(self, gitea_provider, mock_httpx_client, sample_pr_data):
"""Test closing a pull request."""
sample_pr_data["state"] = "closed"
mock_httpx_client.request.return_value.json = MagicMock(
return_value=sample_pr_data
)
result = await gitea_provider.close_pr("owner", "repo", 42)
assert result.success is True
class TestGiteaBranchOperations:
"""Tests for Gitea branch operations."""
@pytest.mark.asyncio
async def test_get_branch(self, gitea_provider, mock_httpx_client):
"""Test getting branch info."""
mock_httpx_client.request.return_value.json = MagicMock(
return_value={
"name": "main",
"commit": {"sha": "abc123"},
}
)
result = await gitea_provider.get_branch("owner", "repo", "main")
assert result["name"] == "main"
@pytest.mark.asyncio
async def test_delete_remote_branch(self, gitea_provider, mock_httpx_client):
"""Test deleting a remote branch."""
mock_httpx_client.request.return_value.status_code = 204
result = await gitea_provider.delete_remote_branch("owner", "repo", "old-branch")
assert result is True
class TestGiteaCommentOperations:
"""Tests for Gitea comment operations."""
@pytest.mark.asyncio
async def test_add_pr_comment(self, gitea_provider, mock_httpx_client):
"""Test adding a comment to a PR."""
mock_httpx_client.request.return_value.json = MagicMock(
return_value={"id": 1, "body": "Test comment"}
)
result = await gitea_provider.add_pr_comment(
"owner", "repo", 42, "Test comment"
)
assert result["body"] == "Test comment"
@pytest.mark.asyncio
async def test_list_pr_comments(self, gitea_provider, mock_httpx_client):
"""Test listing PR comments."""
mock_httpx_client.request.return_value.json = MagicMock(
return_value=[
{"id": 1, "body": "Comment 1"},
{"id": 2, "body": "Comment 2"},
]
)
result = await gitea_provider.list_pr_comments("owner", "repo", 42)
assert len(result) == 2
class TestGiteaLabelOperations:
"""Tests for Gitea label operations."""
@pytest.mark.asyncio
async def test_add_labels(self, gitea_provider, mock_httpx_client):
"""Test adding labels to a PR."""
# Use side_effect to return different values for different calls
# 1. GET labels (for "bug") - returns existing labels
# 2. POST to create "bug" label
# 3. GET labels (for "urgent")
# 4. POST to create "urgent" label
# 5. POST labels to PR
# 6. GET issue to return final labels
mock_responses = [
[{"id": 1, "name": "existing"}], # GET labels (bug not found)
{"id": 2, "name": "bug"}, # POST create bug
[{"id": 1, "name": "existing"}, {"id": 2, "name": "bug"}], # GET labels (urgent not found)
{"id": 3, "name": "urgent"}, # POST create urgent
{}, # POST add labels to PR
{"labels": [{"name": "bug"}, {"name": "urgent"}]}, # GET issue
]
mock_httpx_client.request.return_value.json = MagicMock(side_effect=mock_responses)
result = await gitea_provider.add_labels(
"owner", "repo", 42, ["bug", "urgent"]
)
# Should return updated label list
assert isinstance(result, list)
@pytest.mark.asyncio
async def test_remove_label(self, gitea_provider, mock_httpx_client):
"""Test removing a label from a PR."""
# Use side_effect for multiple calls
# 1. GET labels to find the label ID
# 2. DELETE the label from the PR
# 3. GET issue to return remaining labels
mock_responses = [
[{"id": 1, "name": "bug"}], # GET labels
{}, # DELETE label
{"labels": []}, # GET issue
]
mock_httpx_client.request.return_value.json = MagicMock(side_effect=mock_responses)
result = await gitea_provider.remove_label(
"owner", "repo", 42, "bug"
)
assert isinstance(result, list)
class TestGiteaReviewerOperations:
"""Tests for Gitea reviewer operations."""
@pytest.mark.asyncio
async def test_request_review(self, gitea_provider, mock_httpx_client):
"""Test requesting review from users."""
mock_httpx_client.request.return_value.json = MagicMock(return_value={})
result = await gitea_provider.request_review(
"owner", "repo", 42, ["reviewer1", "reviewer2"]
)
assert result == ["reviewer1", "reviewer2"]
class TestGiteaErrorHandling:
"""Tests for error handling in Gitea provider."""
@pytest.mark.asyncio
async def test_authentication_error(self, gitea_provider, mock_httpx_client):
"""Test handling authentication errors."""
mock_httpx_client.request.return_value.status_code = 401
with pytest.raises(AuthenticationError):
await gitea_provider._request("GET", "/user")
@pytest.mark.asyncio
async def test_permission_denied(self, gitea_provider, mock_httpx_client):
"""Test handling permission denied errors."""
mock_httpx_client.request.return_value.status_code = 403
with pytest.raises(AuthenticationError, match="Insufficient permissions"):
await gitea_provider._request("GET", "/protected")
@pytest.mark.asyncio
async def test_api_error(self, gitea_provider, mock_httpx_client):
"""Test handling general API errors."""
mock_httpx_client.request.return_value.status_code = 500
mock_httpx_client.request.return_value.text = "Internal Server Error"
mock_httpx_client.request.return_value.json = MagicMock(
return_value={"message": "Server error"}
)
with pytest.raises(APIError):
await gitea_provider._request("GET", "/error")
class TestGiteaPRParsing:
"""Tests for PR data parsing."""
def test_parse_pr_open(self, gitea_provider, sample_pr_data):
"""Test parsing open PR."""
pr_info = gitea_provider._parse_pr(sample_pr_data)
assert pr_info.number == 42
assert pr_info.state == PRState.OPEN
assert pr_info.title == "Test PR"
assert pr_info.source_branch == "feature-branch"
assert pr_info.target_branch == "main"
def test_parse_pr_merged(self, gitea_provider, sample_pr_data):
"""Test parsing merged PR."""
sample_pr_data["merged"] = True
sample_pr_data["merged_at"] = "2024-01-16T10:00:00Z"
pr_info = gitea_provider._parse_pr(sample_pr_data)
assert pr_info.state == PRState.MERGED
def test_parse_pr_closed(self, gitea_provider, sample_pr_data):
"""Test parsing closed PR."""
sample_pr_data["state"] = "closed"
sample_pr_data["closed_at"] = "2024-01-16T10:00:00Z"
pr_info = gitea_provider._parse_pr(sample_pr_data)
assert pr_info.state == PRState.CLOSED
def test_parse_datetime_iso(self, gitea_provider):
"""Test parsing ISO datetime strings."""
dt = gitea_provider._parse_datetime("2024-01-15T10:30:00Z")
assert dt.year == 2024
assert dt.month == 1
assert dt.day == 15
def test_parse_datetime_none(self, gitea_provider):
"""Test parsing None datetime returns now."""
dt = gitea_provider._parse_datetime(None)
assert dt is not None
assert dt.tzinfo is not None

View File

@@ -0,0 +1,514 @@
"""
Tests for the MCP server and tools.
"""
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from exceptions import ErrorCode
class TestInputValidation:
"""Tests for input validation functions."""
def test_validate_id_valid(self):
"""Test valid IDs pass validation."""
from server import _validate_id
assert _validate_id("test-123", "project_id") is None
assert _validate_id("my_project", "project_id") is None
assert _validate_id("Agent-001", "agent_id") is None
def test_validate_id_empty(self):
"""Test empty ID fails validation."""
from server import _validate_id
error = _validate_id("", "project_id")
assert error is not None
assert "required" in error.lower()
def test_validate_id_too_long(self):
"""Test too-long ID fails validation."""
from server import _validate_id
error = _validate_id("a" * 200, "project_id")
assert error is not None
assert "1-128" in error
def test_validate_id_invalid_chars(self):
"""Test invalid characters fail validation."""
from server import _validate_id
assert _validate_id("test@invalid", "project_id") is not None
assert _validate_id("test!project", "project_id") is not None
assert _validate_id("test project", "project_id") is not None
def test_validate_branch_valid(self):
"""Test valid branch names."""
from server import _validate_branch
assert _validate_branch("main") is None
assert _validate_branch("feature/new-thing") is None
assert _validate_branch("release-1.0.0") is None
assert _validate_branch("hotfix.urgent") is None
def test_validate_branch_invalid(self):
"""Test invalid branch names."""
from server import _validate_branch
assert _validate_branch("") is not None
assert _validate_branch("a" * 300) is not None
def test_validate_url_valid(self):
"""Test valid repository URLs."""
from server import _validate_url
assert _validate_url("https://github.com/owner/repo.git") is None
assert _validate_url("https://gitea.example.com/owner/repo") is None
assert _validate_url("git@github.com:owner/repo.git") is None
def test_validate_url_invalid(self):
"""Test invalid repository URLs."""
from server import _validate_url
assert _validate_url("") is not None
assert _validate_url("not-a-url") is not None
assert _validate_url("ftp://invalid.com/repo") is not None
class TestHealthCheck:
"""Tests for health check endpoint."""
@pytest.mark.asyncio
async def test_health_check_structure(self):
"""Test health check returns proper structure."""
from server import health_check
with patch("server._gitea_provider", None), \
patch("server._workspace_manager", None):
result = await health_check()
assert "status" in result
assert "service" in result
assert "version" in result
assert "timestamp" in result
assert "dependencies" in result
@pytest.mark.asyncio
async def test_health_check_no_providers(self):
"""Test health check without providers configured."""
from server import health_check
with patch("server._gitea_provider", None), \
patch("server._workspace_manager", None):
result = await health_check()
assert result["dependencies"]["gitea"] == "not configured"
class TestToolRegistry:
"""Tests for tool registration."""
def test_tool_registry_populated(self):
"""Test that tools are registered."""
from server import _tool_registry
assert len(_tool_registry) > 0
assert "clone_repository" in _tool_registry
assert "git_status" in _tool_registry
assert "create_branch" in _tool_registry
assert "commit" in _tool_registry
def test_tool_schema_structure(self):
"""Test tool schemas have proper structure."""
from server import _tool_registry
for name, info in _tool_registry.items():
assert "func" in info
assert "description" in info
assert "schema" in info
assert info["schema"]["type"] == "object"
assert "properties" in info["schema"]
class TestCloneRepository:
"""Tests for clone_repository tool."""
@pytest.mark.asyncio
async def test_clone_invalid_project_id(self):
"""Test clone with invalid project ID."""
from server import clone_repository
# Access the underlying function via .fn
result = await clone_repository.fn(
project_id="invalid@id",
agent_id="agent-1",
repo_url="https://github.com/owner/repo.git",
)
assert result["success"] is False
assert "project_id" in result["error"].lower()
@pytest.mark.asyncio
async def test_clone_invalid_repo_url(self):
"""Test clone with invalid repo URL."""
from server import clone_repository
result = await clone_repository.fn(
project_id="valid-project",
agent_id="agent-1",
repo_url="not-a-valid-url",
)
assert result["success"] is False
assert "url" in result["error"].lower()
class TestGitStatus:
"""Tests for git_status tool."""
@pytest.mark.asyncio
async def test_status_workspace_not_found(self):
"""Test status when workspace doesn't exist."""
from server import git_status
mock_manager = AsyncMock()
mock_manager.get_workspace = AsyncMock(return_value=None)
with patch("server._workspace_manager", mock_manager):
result = await git_status.fn(
project_id="nonexistent",
agent_id="agent-1",
)
assert result["success"] is False
assert result["code"] == ErrorCode.WORKSPACE_NOT_FOUND.value
class TestBranchOperations:
"""Tests for branch operation tools."""
@pytest.mark.asyncio
async def test_create_branch_invalid_name(self):
"""Test creating branch with invalid name."""
from server import create_branch
result = await create_branch.fn(
project_id="test-project",
agent_id="agent-1",
branch_name="", # Invalid
)
assert result["success"] is False
@pytest.mark.asyncio
async def test_list_branches_workspace_not_found(self):
"""Test listing branches when workspace doesn't exist."""
from server import list_branches
mock_manager = AsyncMock()
mock_manager.get_workspace = AsyncMock(return_value=None)
with patch("server._workspace_manager", mock_manager):
result = await list_branches.fn(
project_id="nonexistent",
agent_id="agent-1",
)
assert result["success"] is False
@pytest.mark.asyncio
async def test_checkout_invalid_project(self):
"""Test checkout with invalid project ID."""
from server import checkout
result = await checkout.fn(
project_id="inv@lid",
agent_id="agent-1",
ref="main",
)
assert result["success"] is False
class TestCommitOperations:
"""Tests for commit operation tools."""
@pytest.mark.asyncio
async def test_commit_invalid_project(self):
"""Test commit with invalid project ID."""
from server import commit
result = await commit.fn(
project_id="inv@lid",
agent_id="agent-1",
message="Test commit",
)
assert result["success"] is False
class TestPushPullOperations:
"""Tests for push/pull operation tools."""
@pytest.mark.asyncio
async def test_push_workspace_not_found(self):
"""Test push when workspace doesn't exist."""
from server import push
mock_manager = AsyncMock()
mock_manager.get_workspace = AsyncMock(return_value=None)
with patch("server._workspace_manager", mock_manager):
result = await push.fn(
project_id="nonexistent",
agent_id="agent-1",
)
assert result["success"] is False
@pytest.mark.asyncio
async def test_pull_workspace_not_found(self):
"""Test pull when workspace doesn't exist."""
from server import pull
mock_manager = AsyncMock()
mock_manager.get_workspace = AsyncMock(return_value=None)
with patch("server._workspace_manager", mock_manager):
result = await pull.fn(
project_id="nonexistent",
agent_id="agent-1",
)
assert result["success"] is False
class TestDiffLogOperations:
"""Tests for diff and log operation tools."""
@pytest.mark.asyncio
async def test_diff_workspace_not_found(self):
"""Test diff when workspace doesn't exist."""
from server import diff
mock_manager = AsyncMock()
mock_manager.get_workspace = AsyncMock(return_value=None)
with patch("server._workspace_manager", mock_manager):
result = await diff.fn(
project_id="nonexistent",
agent_id="agent-1",
)
assert result["success"] is False
@pytest.mark.asyncio
async def test_log_workspace_not_found(self):
"""Test log when workspace doesn't exist."""
from server import log
mock_manager = AsyncMock()
mock_manager.get_workspace = AsyncMock(return_value=None)
with patch("server._workspace_manager", mock_manager):
result = await log.fn(
project_id="nonexistent",
agent_id="agent-1",
)
assert result["success"] is False
class TestPROperations:
"""Tests for pull request operation tools."""
@pytest.mark.asyncio
async def test_create_pr_no_repo_url(self):
"""Test create PR when workspace has no repo URL."""
from models import WorkspaceInfo, WorkspaceState
from server import create_pull_request
mock_workspace = WorkspaceInfo(
project_id="test-project",
path="/tmp/test",
state=WorkspaceState.READY,
repo_url=None, # No repo URL
)
mock_manager = AsyncMock()
mock_manager.get_workspace = AsyncMock(return_value=mock_workspace)
with patch("server._workspace_manager", mock_manager):
result = await create_pull_request.fn(
project_id="test-project",
agent_id="agent-1",
title="Test PR",
source_branch="feature",
target_branch="main",
)
assert result["success"] is False
assert "repository URL" in result["error"]
@pytest.mark.asyncio
async def test_list_prs_invalid_state(self):
"""Test list PRs with invalid state filter."""
from models import WorkspaceInfo, WorkspaceState
from server import list_pull_requests
mock_workspace = WorkspaceInfo(
project_id="test-project",
path="/tmp/test",
state=WorkspaceState.READY,
repo_url="https://gitea.test.com/owner/repo.git",
)
mock_manager = AsyncMock()
mock_manager.get_workspace = AsyncMock(return_value=mock_workspace)
mock_provider = AsyncMock()
mock_provider.parse_repo_url = MagicMock(return_value=("owner", "repo"))
with patch("server._workspace_manager", mock_manager), \
patch("server._get_provider_for_url", return_value=mock_provider):
result = await list_pull_requests.fn(
project_id="test-project",
agent_id="agent-1",
state="invalid-state",
)
assert result["success"] is False
assert "Invalid state" in result["error"]
@pytest.mark.asyncio
async def test_merge_pr_invalid_strategy(self):
"""Test merge PR with invalid strategy."""
from models import WorkspaceInfo, WorkspaceState
from server import merge_pull_request
mock_workspace = WorkspaceInfo(
project_id="test-project",
path="/tmp/test",
state=WorkspaceState.READY,
repo_url="https://gitea.test.com/owner/repo.git",
)
mock_manager = AsyncMock()
mock_manager.get_workspace = AsyncMock(return_value=mock_workspace)
mock_provider = AsyncMock()
mock_provider.parse_repo_url = MagicMock(return_value=("owner", "repo"))
with patch("server._workspace_manager", mock_manager), \
patch("server._get_provider_for_url", return_value=mock_provider):
result = await merge_pull_request.fn(
project_id="test-project",
agent_id="agent-1",
pr_number=42,
merge_strategy="invalid-strategy",
)
assert result["success"] is False
assert "Invalid strategy" in result["error"]
class TestWorkspaceOperations:
"""Tests for workspace operation tools."""
@pytest.mark.asyncio
async def test_get_workspace_not_found(self):
"""Test get workspace when it doesn't exist."""
from server import get_workspace
mock_manager = AsyncMock()
mock_manager.get_workspace = AsyncMock(return_value=None)
with patch("server._workspace_manager", mock_manager):
result = await get_workspace.fn(
project_id="nonexistent",
agent_id="agent-1",
)
assert result["success"] is False
@pytest.mark.asyncio
async def test_lock_workspace_success(self):
"""Test successful workspace locking."""
from datetime import UTC, datetime, timedelta
from models import WorkspaceInfo, WorkspaceState
from server import lock_workspace
mock_workspace = WorkspaceInfo(
project_id="test-project",
path="/tmp/test",
state=WorkspaceState.LOCKED,
lock_holder="agent-1",
lock_expires=datetime.now(UTC) + timedelta(seconds=300),
)
mock_manager = AsyncMock()
mock_manager.lock_workspace = AsyncMock(return_value=True)
mock_manager.get_workspace = AsyncMock(return_value=mock_workspace)
with patch("server._workspace_manager", mock_manager):
result = await lock_workspace.fn(
project_id="test-project",
agent_id="agent-1",
timeout=300,
)
assert result["success"] is True
assert result["lock_holder"] == "agent-1"
@pytest.mark.asyncio
async def test_unlock_workspace_success(self):
"""Test successful workspace unlocking."""
from server import unlock_workspace
mock_manager = AsyncMock()
mock_manager.unlock_workspace = AsyncMock(return_value=True)
with patch("server._workspace_manager", mock_manager):
result = await unlock_workspace.fn(
project_id="test-project",
agent_id="agent-1",
)
assert result["success"] is True
class TestJSONRPCEndpoint:
"""Tests for the JSON-RPC endpoint."""
def test_python_type_to_json_schema_str(self):
"""Test string type conversion."""
from server import _python_type_to_json_schema
result = _python_type_to_json_schema(str)
assert result["type"] == "string"
def test_python_type_to_json_schema_int(self):
"""Test int type conversion."""
from server import _python_type_to_json_schema
result = _python_type_to_json_schema(int)
assert result["type"] == "integer"
def test_python_type_to_json_schema_bool(self):
"""Test bool type conversion."""
from server import _python_type_to_json_schema
result = _python_type_to_json_schema(bool)
assert result["type"] == "boolean"
def test_python_type_to_json_schema_list(self):
"""Test list type conversion."""
from server import _python_type_to_json_schema
result = _python_type_to_json_schema(list[str])
assert result["type"] == "array"
assert result["items"]["type"] == "string"

View File

@@ -0,0 +1,334 @@
"""
Tests for the workspace management module.
"""
from datetime import UTC, datetime, timedelta
from pathlib import Path
import pytest
from exceptions import WorkspaceLockedError, WorkspaceNotFoundError
from models import WorkspaceState
from workspace import FileLockManager, WorkspaceLock
class TestWorkspaceManager:
"""Tests for WorkspaceManager."""
@pytest.mark.asyncio
async def test_create_workspace(self, workspace_manager, valid_project_id):
"""Test creating a new workspace."""
workspace = await workspace_manager.create_workspace(valid_project_id)
assert workspace.project_id == valid_project_id
assert workspace.state == WorkspaceState.INITIALIZING
assert Path(workspace.path).exists()
@pytest.mark.asyncio
async def test_create_workspace_with_repo_url(self, workspace_manager, valid_project_id, sample_repo_url):
"""Test creating workspace with repository URL."""
workspace = await workspace_manager.create_workspace(
valid_project_id, repo_url=sample_repo_url
)
assert workspace.repo_url == sample_repo_url
@pytest.mark.asyncio
async def test_get_workspace(self, workspace_manager, valid_project_id):
"""Test getting an existing workspace."""
# Create first
await workspace_manager.create_workspace(valid_project_id)
# Get it
workspace = await workspace_manager.get_workspace(valid_project_id)
assert workspace is not None
assert workspace.project_id == valid_project_id
@pytest.mark.asyncio
async def test_get_workspace_not_found(self, workspace_manager):
"""Test getting non-existent workspace."""
workspace = await workspace_manager.get_workspace("nonexistent")
assert workspace is None
@pytest.mark.asyncio
async def test_delete_workspace(self, workspace_manager, valid_project_id):
"""Test deleting a workspace."""
# Create first
workspace = await workspace_manager.create_workspace(valid_project_id)
workspace_path = Path(workspace.path)
assert workspace_path.exists()
# Delete
result = await workspace_manager.delete_workspace(valid_project_id)
assert result is True
assert not workspace_path.exists()
@pytest.mark.asyncio
async def test_delete_nonexistent_workspace(self, workspace_manager):
"""Test deleting non-existent workspace returns True."""
result = await workspace_manager.delete_workspace("nonexistent")
assert result is True
@pytest.mark.asyncio
async def test_list_workspaces(self, workspace_manager):
"""Test listing workspaces."""
# Create multiple workspaces
await workspace_manager.create_workspace("project-1")
await workspace_manager.create_workspace("project-2")
await workspace_manager.create_workspace("project-3")
workspaces = await workspace_manager.list_workspaces()
assert len(workspaces) >= 3
project_ids = [w.project_id for w in workspaces]
assert "project-1" in project_ids
assert "project-2" in project_ids
assert "project-3" in project_ids
class TestWorkspaceLocking:
"""Tests for workspace locking."""
@pytest.mark.asyncio
async def test_lock_workspace(self, workspace_manager, valid_project_id, valid_agent_id):
"""Test locking a workspace."""
await workspace_manager.create_workspace(valid_project_id)
result = await workspace_manager.lock_workspace(
valid_project_id, valid_agent_id, timeout=60
)
assert result is True
workspace = await workspace_manager.get_workspace(valid_project_id)
assert workspace.state == WorkspaceState.LOCKED
assert workspace.lock_holder == valid_agent_id
@pytest.mark.asyncio
async def test_lock_already_locked(self, workspace_manager, valid_project_id):
"""Test locking already-locked workspace by different holder."""
await workspace_manager.create_workspace(valid_project_id)
await workspace_manager.lock_workspace(valid_project_id, "agent-1", timeout=60)
with pytest.raises(WorkspaceLockedError):
await workspace_manager.lock_workspace(valid_project_id, "agent-2", timeout=60)
@pytest.mark.asyncio
async def test_lock_same_holder(self, workspace_manager, valid_project_id, valid_agent_id):
"""Test re-locking by same holder extends lock."""
await workspace_manager.create_workspace(valid_project_id)
await workspace_manager.lock_workspace(valid_project_id, valid_agent_id, timeout=60)
# Same holder can re-lock
result = await workspace_manager.lock_workspace(
valid_project_id, valid_agent_id, timeout=120
)
assert result is True
@pytest.mark.asyncio
async def test_unlock_workspace(self, workspace_manager, valid_project_id, valid_agent_id):
"""Test unlocking a workspace."""
await workspace_manager.create_workspace(valid_project_id)
await workspace_manager.lock_workspace(valid_project_id, valid_agent_id)
result = await workspace_manager.unlock_workspace(valid_project_id, valid_agent_id)
assert result is True
workspace = await workspace_manager.get_workspace(valid_project_id)
assert workspace.state == WorkspaceState.READY
assert workspace.lock_holder is None
@pytest.mark.asyncio
async def test_unlock_wrong_holder(self, workspace_manager, valid_project_id):
"""Test unlock fails with wrong holder."""
await workspace_manager.create_workspace(valid_project_id)
await workspace_manager.lock_workspace(valid_project_id, "agent-1")
with pytest.raises(WorkspaceLockedError):
await workspace_manager.unlock_workspace(valid_project_id, "agent-2")
@pytest.mark.asyncio
async def test_force_unlock(self, workspace_manager, valid_project_id):
"""Test force unlock works regardless of holder."""
await workspace_manager.create_workspace(valid_project_id)
await workspace_manager.lock_workspace(valid_project_id, "agent-1")
result = await workspace_manager.unlock_workspace(
valid_project_id, "admin", force=True
)
assert result is True
@pytest.mark.asyncio
async def test_lock_nonexistent_workspace(self, workspace_manager, valid_agent_id):
"""Test locking non-existent workspace raises error."""
with pytest.raises(WorkspaceNotFoundError):
await workspace_manager.lock_workspace("nonexistent", valid_agent_id)
class TestWorkspaceLockContextManager:
"""Tests for WorkspaceLock context manager."""
@pytest.mark.asyncio
async def test_lock_context_manager(self, workspace_manager, valid_project_id, valid_agent_id):
"""Test using WorkspaceLock as context manager."""
await workspace_manager.create_workspace(valid_project_id)
async with WorkspaceLock(
workspace_manager, valid_project_id, valid_agent_id
) as lock:
workspace = await workspace_manager.get_workspace(valid_project_id)
assert workspace.state == WorkspaceState.LOCKED
# After exiting context, should be unlocked
workspace = await workspace_manager.get_workspace(valid_project_id)
assert workspace.lock_holder is None
@pytest.mark.asyncio
async def test_lock_context_manager_error(self, workspace_manager, valid_project_id, valid_agent_id):
"""Test WorkspaceLock releases on exception."""
await workspace_manager.create_workspace(valid_project_id)
try:
async with WorkspaceLock(
workspace_manager, valid_project_id, valid_agent_id
):
raise ValueError("Test error")
except ValueError:
pass
workspace = await workspace_manager.get_workspace(valid_project_id)
assert workspace.lock_holder is None
class TestWorkspaceMetadata:
"""Tests for workspace metadata operations."""
@pytest.mark.asyncio
async def test_touch_workspace(self, workspace_manager, valid_project_id):
"""Test updating workspace access time."""
workspace = await workspace_manager.create_workspace(valid_project_id)
original_time = workspace.last_accessed
await workspace_manager.touch_workspace(valid_project_id)
updated = await workspace_manager.get_workspace(valid_project_id)
assert updated.last_accessed >= original_time
@pytest.mark.asyncio
async def test_update_workspace_branch(self, workspace_manager, valid_project_id):
"""Test updating workspace branch."""
await workspace_manager.create_workspace(valid_project_id)
await workspace_manager.update_workspace_branch(valid_project_id, "feature-branch")
workspace = await workspace_manager.get_workspace(valid_project_id)
assert workspace.current_branch == "feature-branch"
class TestWorkspaceSize:
"""Tests for workspace size management."""
@pytest.mark.asyncio
async def test_check_size_within_limit(self, workspace_manager, valid_project_id):
"""Test size check passes for small workspace."""
await workspace_manager.create_workspace(valid_project_id)
# Should not raise
result = await workspace_manager.check_size_limit(valid_project_id)
assert result is True
@pytest.mark.asyncio
async def test_get_total_size(self, workspace_manager, valid_project_id):
"""Test getting total workspace size."""
workspace = await workspace_manager.create_workspace(valid_project_id)
# Add some content
content_file = Path(workspace.path) / "content.txt"
content_file.write_text("x" * 1000)
total_size = await workspace_manager.get_total_size()
assert total_size >= 1000
class TestFileLockManager:
"""Tests for file-based locking."""
def test_acquire_lock(self, temp_dir):
"""Test acquiring a file lock."""
manager = FileLockManager(temp_dir / "locks")
result = manager.acquire("test-key")
assert result is True
# Cleanup
manager.release("test-key")
def test_release_lock(self, temp_dir):
"""Test releasing a file lock."""
manager = FileLockManager(temp_dir / "locks")
manager.acquire("test-key")
result = manager.release("test-key")
assert result is True
def test_is_locked(self, temp_dir):
"""Test checking if locked."""
manager = FileLockManager(temp_dir / "locks")
assert manager.is_locked("test-key") is False
manager.acquire("test-key")
assert manager.is_locked("test-key") is True
manager.release("test-key")
def test_release_nonexistent_lock(self, temp_dir):
"""Test releasing a lock that doesn't exist."""
manager = FileLockManager(temp_dir / "locks")
# Should not raise
result = manager.release("nonexistent")
assert result is False
class TestWorkspaceCleanup:
"""Tests for workspace cleanup operations."""
@pytest.mark.asyncio
async def test_cleanup_stale_workspaces(self, workspace_manager, test_settings):
"""Test cleaning up stale workspaces."""
# Create workspace
workspace = await workspace_manager.create_workspace("stale-project")
# Manually set it as stale by updating metadata
await workspace_manager._update_metadata(
"stale-project",
last_accessed=(datetime.now(UTC) - timedelta(days=30)).isoformat(),
)
# Run cleanup
cleaned = await workspace_manager.cleanup_stale_workspaces()
assert cleaned >= 1
@pytest.mark.asyncio
async def test_delete_locked_workspace_blocked(self, workspace_manager, valid_project_id, valid_agent_id):
"""Test deleting locked workspace is blocked without force."""
await workspace_manager.create_workspace(valid_project_id)
await workspace_manager.lock_workspace(valid_project_id, valid_agent_id)
with pytest.raises(WorkspaceLockedError):
await workspace_manager.delete_workspace(valid_project_id)
@pytest.mark.asyncio
async def test_delete_locked_workspace_force(self, workspace_manager, valid_project_id, valid_agent_id):
"""Test force deleting locked workspace."""
await workspace_manager.create_workspace(valid_project_id)
await workspace_manager.lock_workspace(valid_project_id, valid_agent_id)
result = await workspace_manager.delete_workspace(valid_project_id, force=True)
assert result is True

1853
mcp-servers/git-ops/uv.lock generated Normal file

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,614 @@
"""
Workspace management for Git Operations MCP Server.
Handles isolated workspaces for each project, including creation,
locking, cleanup, and size management.
"""
import asyncio
import json
import logging
import shutil
from datetime import UTC, datetime, timedelta
from pathlib import Path
from typing import Any
import aiofiles # type: ignore[import-untyped]
from filelock import FileLock, Timeout
from config import Settings, get_settings
from exceptions import (
WorkspaceLockedError,
WorkspaceNotFoundError,
WorkspaceSizeExceededError,
)
from models import WorkspaceInfo, WorkspaceState
logger = logging.getLogger(__name__)
# Metadata file name
WORKSPACE_METADATA_FILE = ".syndarix-workspace.json"
class WorkspaceManager:
"""
Manages git workspaces for projects.
Each project gets an isolated workspace directory for git operations.
Supports distributed locking via Redis or local file locks.
"""
def __init__(self, settings: Settings | None = None) -> None:
"""
Initialize WorkspaceManager.
Args:
settings: Optional settings override
"""
self.settings = settings or get_settings()
self.base_path = self.settings.workspace_base_path
self._ensure_base_path()
def _ensure_base_path(self) -> None:
"""Ensure the base workspace directory exists."""
self.base_path.mkdir(parents=True, exist_ok=True)
def _get_workspace_path(self, project_id: str) -> Path:
"""Get the path for a project workspace with path traversal protection."""
# Sanitize project ID for filesystem
safe_id = "".join(c if c.isalnum() or c in "-_" else "_" for c in project_id)
# Reject reserved names
reserved_names = {".", "..", "con", "prn", "aux", "nul"}
if safe_id.lower() in reserved_names:
raise ValueError(f"Invalid project ID: reserved name '{project_id}'")
# Construct path and verify it's within base_path (prevent path traversal)
workspace_path = (self.base_path / safe_id).resolve()
base_resolved = self.base_path.resolve()
if not workspace_path.is_relative_to(base_resolved):
raise ValueError(
f"Invalid project ID: path traversal detected '{project_id}'"
)
return workspace_path
def _get_lock_path(self, project_id: str) -> Path:
"""Get the lock file path for a workspace."""
return self._get_workspace_path(project_id) / ".lock"
def _get_metadata_path(self, project_id: str) -> Path:
"""Get the metadata file path for a workspace."""
return self._get_workspace_path(project_id) / WORKSPACE_METADATA_FILE
async def get_workspace(self, project_id: str) -> WorkspaceInfo | None:
"""
Get workspace info for a project.
Args:
project_id: Project identifier
Returns:
WorkspaceInfo or None if not found
"""
workspace_path = self._get_workspace_path(project_id)
if not workspace_path.exists():
return None
# Load metadata
metadata = await self._load_metadata(project_id)
# Calculate size
size_bytes = await self._calculate_size(workspace_path)
# Check lock status
lock_holder = None
lock_expires = None
if metadata:
lock_holder = metadata.get("lock_holder")
if metadata.get("lock_expires"):
lock_expires = datetime.fromisoformat(metadata["lock_expires"])
# Clear expired locks
if lock_expires < datetime.now(UTC):
lock_holder = None
lock_expires = None
# Determine state
state = WorkspaceState.READY
if lock_holder:
state = WorkspaceState.LOCKED
# Check if stale
last_accessed = datetime.now(UTC)
if metadata and metadata.get("last_accessed"):
last_accessed = datetime.fromisoformat(metadata["last_accessed"])
stale_threshold = datetime.now(UTC) - timedelta(
days=self.settings.workspace_stale_days
)
if last_accessed < stale_threshold:
state = WorkspaceState.STALE
return WorkspaceInfo(
project_id=project_id,
path=str(workspace_path),
state=state,
repo_url=metadata.get("repo_url") if metadata else None,
current_branch=metadata.get("current_branch") if metadata else None,
last_accessed=last_accessed,
size_bytes=size_bytes,
lock_holder=lock_holder,
lock_expires=lock_expires,
)
async def create_workspace(
self,
project_id: str,
repo_url: str | None = None,
) -> WorkspaceInfo:
"""
Create or get a workspace for a project.
Args:
project_id: Project identifier
repo_url: Optional repository URL
Returns:
WorkspaceInfo for the workspace
"""
workspace_path = self._get_workspace_path(project_id)
if workspace_path.exists():
# Workspace already exists, update metadata
await self._update_metadata(project_id, repo_url=repo_url)
workspace = await self.get_workspace(project_id)
if workspace:
return workspace
# Create workspace directory
workspace_path.mkdir(parents=True, exist_ok=True)
# Create initial metadata
metadata = {
"project_id": project_id,
"repo_url": repo_url,
"created_at": datetime.now(UTC).isoformat(),
"last_accessed": datetime.now(UTC).isoformat(),
}
await self._save_metadata(project_id, metadata)
return WorkspaceInfo(
project_id=project_id,
path=str(workspace_path),
state=WorkspaceState.INITIALIZING,
repo_url=repo_url,
last_accessed=datetime.now(UTC),
size_bytes=0,
)
async def delete_workspace(self, project_id: str, force: bool = False) -> bool:
"""
Delete a workspace.
Args:
project_id: Project identifier
force: Force delete even if locked
Returns:
True if deleted
"""
workspace_path = self._get_workspace_path(project_id)
if not workspace_path.exists():
return True
# Check lock
if not force:
workspace = await self.get_workspace(project_id)
if workspace and workspace.state == WorkspaceState.LOCKED:
raise WorkspaceLockedError(project_id, workspace.lock_holder)
try:
# Use shutil.rmtree for robust deletion
shutil.rmtree(workspace_path)
logger.info(f"Deleted workspace for project: {project_id}")
return True
except Exception as e:
logger.error(f"Failed to delete workspace {project_id}: {e}")
return False
async def lock_workspace(
self,
project_id: str,
holder: str,
timeout: int | None = None,
) -> bool:
"""
Acquire a lock on a workspace.
Args:
project_id: Project identifier
holder: Lock holder identifier (agent_id)
timeout: Lock timeout in seconds
Returns:
True if lock acquired
Raises:
WorkspaceNotFoundError: If workspace doesn't exist
WorkspaceLockedError: If already locked by another
"""
workspace = await self.get_workspace(project_id)
if workspace is None:
raise WorkspaceNotFoundError(project_id)
# Check if already locked by someone else
if workspace.state == WorkspaceState.LOCKED and workspace.lock_holder != holder:
# Check if lock expired
if workspace.lock_expires and workspace.lock_expires > datetime.now(UTC):
raise WorkspaceLockedError(project_id, workspace.lock_holder)
# Calculate lock expiry
lock_timeout = timeout or self.settings.workspace_lock_timeout
lock_expires = datetime.now(UTC) + timedelta(seconds=lock_timeout)
# Update metadata with lock info
await self._update_metadata(
project_id,
lock_holder=holder,
lock_expires=lock_expires.isoformat(),
)
logger.info(f"Workspace {project_id} locked by {holder}")
return True
async def unlock_workspace(
self,
project_id: str,
holder: str,
force: bool = False,
) -> bool:
"""
Release a lock on a workspace.
Args:
project_id: Project identifier
holder: Lock holder identifier
force: Force unlock regardless of holder
Returns:
True if unlocked
"""
workspace = await self.get_workspace(project_id)
if workspace is None:
raise WorkspaceNotFoundError(project_id)
# Verify holder
if not force and workspace.lock_holder and workspace.lock_holder != holder:
raise WorkspaceLockedError(project_id, workspace.lock_holder)
# Clear lock
await self._update_metadata(
project_id,
lock_holder=None,
lock_expires=None,
)
logger.info(f"Workspace {project_id} unlocked by {holder}")
return True
async def touch_workspace(self, project_id: str) -> None:
"""
Update last accessed time for a workspace.
Args:
project_id: Project identifier
"""
await self._update_metadata(
project_id,
last_accessed=datetime.now(UTC).isoformat(),
)
async def update_workspace_branch(
self,
project_id: str,
branch: str,
) -> None:
"""
Update the current branch in workspace metadata.
Args:
project_id: Project identifier
branch: Current branch name
"""
await self._update_metadata(
project_id,
current_branch=branch,
last_accessed=datetime.now(UTC).isoformat(),
)
async def check_size_limit(self, project_id: str) -> bool:
"""
Check if workspace exceeds size limit.
Args:
project_id: Project identifier
Returns:
True if within limits
Raises:
WorkspaceSizeExceededError: If size exceeds limit
"""
workspace_path = self._get_workspace_path(project_id)
if not workspace_path.exists():
return True
size_bytes = await self._calculate_size(workspace_path)
size_gb = size_bytes / (1024**3)
max_size_gb = self.settings.workspace_max_size_gb
if size_gb > max_size_gb:
raise WorkspaceSizeExceededError(project_id, size_gb, max_size_gb)
return True
async def list_workspaces(
self,
include_stale: bool = False,
) -> list[WorkspaceInfo]:
"""
List all workspaces.
Args:
include_stale: Include stale workspaces
Returns:
List of WorkspaceInfo
"""
workspaces: list[WorkspaceInfo] = []
if not self.base_path.exists():
return workspaces
for entry in self.base_path.iterdir():
if entry.is_dir() and not entry.name.startswith("."):
# Extract project_id from directory name
workspace = await self.get_workspace(entry.name)
if workspace:
if not include_stale and workspace.state == WorkspaceState.STALE:
continue
workspaces.append(workspace)
return workspaces
async def cleanup_stale_workspaces(self) -> int:
"""
Clean up stale workspaces.
Returns:
Number of workspaces cleaned up
"""
cleaned = 0
workspaces = await self.list_workspaces(include_stale=True)
for workspace in workspaces:
if workspace.state == WorkspaceState.STALE:
try:
await self.delete_workspace(workspace.project_id, force=True)
cleaned += 1
except Exception as e:
logger.error(
f"Failed to cleanup stale workspace {workspace.project_id}: {e}"
)
if cleaned > 0:
logger.info(f"Cleaned up {cleaned} stale workspaces")
return cleaned
async def get_total_size(self) -> int:
"""
Get total size of all workspaces.
Returns:
Total size in bytes
"""
return await self._calculate_size(self.base_path)
# Private methods
async def _load_metadata(self, project_id: str) -> dict[str, Any] | None:
"""Load workspace metadata from file."""
metadata_path = self._get_metadata_path(project_id)
if not metadata_path.exists():
return None
try:
async with aiofiles.open(metadata_path) as f:
content = await f.read()
return json.loads(content)
except Exception as e:
logger.warning(f"Failed to load metadata for {project_id}: {e}")
return None
async def _save_metadata(
self,
project_id: str,
metadata: dict[str, Any],
) -> None:
"""Save workspace metadata to file."""
metadata_path = self._get_metadata_path(project_id)
# Ensure parent directory exists
metadata_path.parent.mkdir(parents=True, exist_ok=True)
try:
async with aiofiles.open(metadata_path, "w") as f:
await f.write(json.dumps(metadata, indent=2))
except Exception as e:
logger.error(f"Failed to save metadata for {project_id}: {e}")
async def _update_metadata(
self,
project_id: str,
**updates: Any,
) -> None:
"""Update specific fields in workspace metadata."""
metadata = await self._load_metadata(project_id) or {}
# Handle None values (to clear fields)
for key, value in updates.items():
if value is None:
metadata.pop(key, None)
else:
metadata[key] = value
await self._save_metadata(project_id, metadata)
async def _calculate_size(self, path: Path) -> int:
"""Calculate total size of a directory."""
def _calc_size() -> int:
total = 0
try:
for entry in path.rglob("*"):
if entry.is_file():
try:
total += entry.stat().st_size
except OSError:
pass
except Exception:
pass
return total
# Run in executor for async compatibility
loop = asyncio.get_event_loop()
return await loop.run_in_executor(None, _calc_size)
class WorkspaceLock:
"""
Context manager for workspace locking.
Provides automatic locking/unlocking with proper cleanup.
"""
def __init__(
self,
manager: WorkspaceManager,
project_id: str,
holder: str,
timeout: int | None = None,
) -> None:
"""
Initialize workspace lock.
Args:
manager: WorkspaceManager instance
project_id: Project identifier
holder: Lock holder identifier
timeout: Lock timeout in seconds
"""
self.manager = manager
self.project_id = project_id
self.holder = holder
self.timeout = timeout
self._acquired = False
async def __aenter__(self) -> "WorkspaceLock":
"""Acquire lock on enter."""
await self.manager.lock_workspace(
self.project_id,
self.holder,
self.timeout,
)
self._acquired = True
return self
async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
"""Release lock on exit."""
if self._acquired:
try:
await self.manager.unlock_workspace(
self.project_id,
self.holder,
)
except Exception as e:
logger.warning(f"Failed to release lock for {self.project_id}: {e}")
class FileLockManager:
"""
File-based locking for single-instance deployments.
Uses filelock for local locking when Redis is not available.
"""
def __init__(self, lock_dir: Path) -> None:
"""
Initialize file lock manager.
Args:
lock_dir: Directory for lock files
"""
self.lock_dir = lock_dir
self.lock_dir.mkdir(parents=True, exist_ok=True)
self._locks: dict[str, FileLock] = {}
def _get_lock(self, key: str) -> FileLock:
"""Get or create a file lock for a key."""
if key not in self._locks:
lock_path = self.lock_dir / f"{key}.lock"
self._locks[key] = FileLock(lock_path)
return self._locks[key]
def acquire(
self,
key: str,
timeout: float = 10.0,
) -> bool:
"""
Acquire a lock.
Args:
key: Lock key
timeout: Timeout in seconds
Returns:
True if acquired
"""
lock = self._get_lock(key)
try:
lock.acquire(timeout=timeout)
return True
except Timeout:
return False
def release(self, key: str) -> bool:
"""
Release a lock.
Args:
key: Lock key
Returns:
True if released
"""
if key in self._locks:
try:
self._locks[key].release()
return True
except Exception:
pass
return False
def is_locked(self, key: str) -> bool:
"""Check if a key is locked."""
lock = self._get_lock(key)
return lock.is_locked